Sort blocks topologically

Also lots of spirv-val work
khyperia 2020-09-10 16:37:12 +02:00
parent 89065c1db0
commit c1134788cd
9 changed files with 168 additions and 63 deletions

View File

@ -14,6 +14,7 @@ crate-type = ["dylib"]
[dependencies]
rspirv = { git = "https://github.com/gfx-rs/rspirv" }
topological-sort = "0.1"
[dev-dependencies]
pretty_assertions = "0.6"
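For context, a minimal sketch of the topological-sort crate's API as the new block_ordering_pass below uses it; the block names here are hypothetical:

use topological_sort::TopologicalSort;

fn main() {
    // Hypothetical forward-edge CFG: entry -> then -> merge, entry -> merge.
    let mut sorter = TopologicalSort::<&str>::new();
    sorter.add_dependency("entry", "then"); // "entry" must be emitted before "then"
    sorter.add_dependency("entry", "merge");
    sorter.add_dependency("then", "merge");
    // pop() yields a node once all of its predecessors have been popped.
    while let Some(block) = sorter.pop() {
        println!("{}", block); // prints entry, then, merge
    }
    // A non-empty sorter at this point would mean a cycle was left in the graph,
    // which is why block_ordering_pass strips backedges before sorting.
    assert!(sorter.is_empty());
}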

View File

@ -220,10 +220,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Ignore => SpirvType::Void.def(cx),
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.spirv_type_immediate(cx),
PassMode::Cast(cast_target) => cast_target.spirv_type(cx),
PassMode::Indirect(_arg_attributes, wide_ptr_attrs) => {
if wide_ptr_attrs.is_some() {
panic!("TODO: PassMode::Indirect wide ptr not supported for return type");
}
PassMode::Indirect(..) => {
let pointee = self.ret.layout.spirv_type(cx);
let pointer = SpirvType::Pointer {
storage_class: StorageClass::Function,
@ -300,7 +297,8 @@ fn trans_type_impl<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>, is_immedia
// Note! Do not pass through is_immediate here - they're wrapped in a struct, hence, not immediate.
let one_spirv = trans_scalar(cx, ty, one, Some(0), false);
let two_spirv = trans_scalar(cx, ty, two, Some(1), false);
// TODO: Note: We can't use auto_struct_layout here because the spirv types here might be undefined.
// Note: We can't use auto_struct_layout here because the spirv types here might be undefined due to
// recursive pointer types.
let one_offset = Size::ZERO;
let two_offset = one.value.size(cx).align_to(two.value.align(cx).abi);
let size = if ty.is_unsized() { None } else { Some(ty.size) };
@ -352,7 +350,6 @@ fn trans_scalar<'tcx>(
}
match scalar.value {
// TODO: Do we use scalar.valid_range?
Primitive::Int(width, mut signedness) => {
if cx.kernel_mode {
signedness = false;
@ -434,10 +431,6 @@ fn dig_scalar_pointee<'tcx>(
let ptr_ty = cx.layout_of(cx.tcx.mk_mut_ptr(ty.ty.boxed_ty()));
dig_scalar_pointee(cx, ptr_ty, index)
}
// TyKind::Tuple(substs) if substs.len() == 1 => {
// let item = cx.layout_of(ty.ty.tuple_fields().next().unwrap());
// trans_scalar_known_ty(cx, item, scalar, is_immediate)
// }
TyKind::Tuple(_) | TyKind::Adt(..) | TyKind::Closure(..) => {
dig_scalar_pointee_adt(cx, ty, index)
}
@ -454,6 +447,9 @@ fn dig_scalar_pointee_adt<'tcx>(
index: Option<usize>,
) -> PointeeTy<'tcx> {
match &ty.variants {
// If it's a Variants::Multiple, then we want to emit the type of the dataful variant, not the type of the
// discriminant. This is because the discriminant can e.g. have type *mut(), whereas we want the full underlying
// type, only available in the dataful variant.
Variants::Multiple {
tag_encoding,
tag_field,
@ -513,8 +509,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Word {
"FieldsShape::Primitive not supported yet in trans_type: {:?}",
ty
),
// TODO: Is this the right thing to do?
FieldsShape::Union(_field_count) => {
FieldsShape::Union(_) => {
assert_ne!(ty.size.bytes(), 0, "{:#?}", ty);
assert!(!ty.is_unsized(), "{:#?}", ty);
let byte = SpirvType::Integer(8, false).def(cx);
@ -525,7 +520,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Word {
}
.def(cx)
}
FieldsShape::Array { stride: _, count } => {
FieldsShape::Array { stride, count } => {
let element_type = trans_type_impl(cx, ty.field(cx, 0), false);
if ty.is_unsized() {
// There's a potential for this array to be sized, but the element to be unsized, e.g. `[[u8]; 5]`.
@ -547,8 +542,13 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Word {
}
.def(cx)
} else {
// TODO: Assert stride is same as spirv's stride?
let count_const = cx.constant_u32(count as u32).def;
let element_spv = cx.lookup_type(element_type);
let stride_spv = element_spv
.sizeof(cx)
.expect("Unexpected unsized type in sized FieldsShape::Array")
.align_to(element_spv.alignof(cx));
assert_eq!(stride_spv, stride);
SpirvType::Array {
element: element_type,
count: count_const,

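A quick, hypothetical check of the stride the new assertion expects; align_to here simply mirrors the round-up-to-alignment rule that Size::align_to and the SpirvType sizeof/alignof pair use:

fn main() {
    // Round a size up to the next multiple of an alignment.
    fn align_to(size: u64, align: u64) -> u64 {
        (size + align - 1) / align * align
    }
    // Hypothetical element type: 12 bytes of data, 8-byte alignment.
    // Its array stride must be padded out to 16 bytes.
    assert_eq!(align_to(12, 8), 16);
    // An already-aligned element keeps its size as its stride.
    assert_eq!(align_to(16, 8), 16);
}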
View File

@ -178,7 +178,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
let (signed, construct_case) = match self.lookup_type(v.ty) {
SpirvType::Integer(width, signed) => {
let construct_case = match width {
// TODO: How are negative values represented? sign-extended? if so, they'll be >MAX
8 => |signed, v| {
if v > u8::MAX as u128 {
panic!("Switches to values above u8::MAX not supported: {:?}", v)
@ -285,7 +284,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
match self.lookup_type(ty) {
SpirvType::Integer(_, _) => self.emit().bitwise_and(ty, None, lhs.def, rhs.def),
SpirvType::Bool => self.emit().logical_and(ty, None, lhs.def, rhs.def),
o => panic!("TODO: and() not implemented for type {}", o.debug(ty, self)),
o => panic!("and() not implemented for type {}", o.debug(ty, self)),
}
.unwrap()
.with_type(ty)
@ -296,7 +295,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
match self.lookup_type(ty) {
SpirvType::Integer(_, _) => self.emit().bitwise_or(ty, None, lhs.def, rhs.def),
SpirvType::Bool => self.emit().logical_or(ty, None, lhs.def, rhs.def),
o => panic!("TODO: or() not implemented for type {}", o.debug(ty, self)),
o => panic!("or() not implemented for type {}", o.debug(ty, self)),
}
.unwrap()
.with_type(ty)
@ -307,7 +306,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
match self.lookup_type(ty) {
SpirvType::Integer(_, _) => self.emit().bitwise_xor(ty, None, lhs.def, rhs.def),
SpirvType::Bool => self.emit().logical_not_equal(ty, None, lhs.def, rhs.def),
o => panic!("TODO: xor() not implemented for type {}", o.debug(ty, self)),
o => panic!("xor() not implemented for type {}", o.debug(ty, self)),
}
.unwrap()
.with_type(ty)
@ -316,10 +315,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
match self.lookup_type(val.ty) {
SpirvType::Integer(_, _) => self.emit().not(val.ty, None, val.def),
SpirvType::Bool => self.emit().logical_not(val.ty, None, val.def),
o => panic!(
"TODO: not() not implemented for type {}",
o.debug(val.ty, self)
),
o => panic!("not() not implemented for type {}", o.debug(val.ty, self)),
}
.unwrap()
.with_type(val.ty)
@ -834,8 +830,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
fn icmp(&mut self, op: IntPredicate, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
// TODO: Do we want to assert signedness matches the opcode? Is it possible to have one that doesn't match? Does
// spir-v allow nonmatching instructions?
// Note: the signedness of the opcode doesn't have to match the signedness of the operands.
use IntPredicate::*;
assert_ty_eq!(self, lhs.ty, rhs.ty);
let b = SpirvType::Bool.def(self);
@ -1112,7 +1107,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
.def(self);
if self.builder.lookup_const(elt.def).is_ok() {
// TODO: Cache this?
self.emit()
.constant_composite(result_type, std::iter::repeat(elt.def).take(num_elts))
} else {
@ -1362,10 +1356,11 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
unsafe fn delete_basic_block(&mut self, _bb: Self::BasicBlock) {
todo!()
// Ignore: if we were to delete the block, other builders' selected_block indices would become invalid,
// because the remaining blocks would shift.
}
fn do_not_inline(&mut self, _llret: Self::Value) {
todo!()
// Ignore
}
}

View File

@ -649,9 +649,12 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
fn abort(&mut self) {
// codegen_llvm uses call(llvm.trap) here, so it is not a block terminator
if !self.kernel_mode {
self.emit().kill().unwrap();
*self = self.build_sibling_block("abort_continue");
}
// TODO: Figure out an appropriate instruction for kernel mode.
}
fn assume(&mut self, _val: Self::Value) {
// TODO: llvm.assume

View File

@ -160,6 +160,20 @@ impl BuilderSpirv {
Err("Definition not found")
}
pub fn lookup_const_bool(&self, def: Word) -> Result<bool, &'static str> {
let builder = self.builder.borrow();
for inst in &builder.module_ref().types_global_values {
if inst.result_id == Some(def) {
return match inst.class.opcode {
Op::ConstantFalse => Ok(false),
Op::ConstantTrue => Ok(true),
_ => Err("Instruction not OpConstantTrue/False"),
};
}
}
Err("Definition not found")
}
pub fn lookup_global_constant_variable(&self, def: Word) -> Result<Word, &'static str> {
// TODO: Maybe assert that this indeed a constant?
let builder = self.builder.borrow();

View File

@ -1,7 +1,7 @@
use crate::abi::ConvSpirvType;
use crate::builder::ExtInst;
use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvValue, SpirvValueExt};
use crate::poison_pass::poison_pass;
use crate::finalizing_passes::{block_ordering_pass, poison_pass};
use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache};
use rspirv::dr::{Module, Operand};
use rspirv::spirv::{Decoration, FunctionControl, LinkageType, StorageClass, Word};
@ -144,6 +144,9 @@ impl<'tcx> CodegenCx<'tcx> {
poison_pass(&mut result, &mut self.poisoned_values.borrow_mut());
// defs go before fns
result.functions.sort_by_key(|f| !f.blocks.is_empty());
for function in &mut result.functions {
block_ordering_pass(function);
}
result
}
@ -383,9 +386,6 @@ impl<'tcx> LayoutTypeMethods<'tcx> for CodegenCx<'tcx> {
}
impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> {
// TODO: llvm types are signless, as in neither signed nor unsigned (I think?), so these are expected to be
// signless. Do we want a SpirvType::Integer(_, Signless) to indicate the sign is unknown, and to do conversions at
// appropriate places?
fn type_i1(&self) -> Self::Type {
SpirvType::Bool.def(self)
}
@ -462,10 +462,7 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> {
}
.def(self)
}
fn type_ptr_to_ext(&self, ty: Self::Type, address_space: AddressSpace) -> Self::Type {
if address_space != AddressSpace::DATA {
panic!("TODO: Unimplemented AddressSpace {:?}", address_space)
}
fn type_ptr_to_ext(&self, ty: Self::Type, _address_space: AddressSpace) -> Self::Type {
SpirvType::Pointer {
storage_class: StorageClass::Function,
pointee: ty,
@ -526,7 +523,7 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> {
.emit_global()
.variable(ty, None, StorageClass::Function, Some(cv.def))
.with_type(ty);
// TODO: These should be StorageClass::Private, so just poison for now.
// TODO: These should be StorageClass::UniformConstant, so just poison for now.
self.poison(result.def);
result
}
@ -544,23 +541,16 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> {
SpirvType::Pointer { pointee, .. } => pointee,
other => panic!("global had non-pointer type {}", other.debug(g.ty, self)),
};
let v = create_const_alloc(self, alloc, value_ty);
let mut v = create_const_alloc(self, alloc, value_ty);
if self.lookup_type(v.ty) == SpirvType::Bool {
// convert bool -> i8
todo!();
let val = self.builder.lookup_const_bool(v.def).unwrap();
let val_int = if val { 1 } else { 0 };
v = self.constant_u8(val_int);
}
// let instance = Instance::mono(self.tcx, def_id);
// let ty = instance.ty(self.tcx, ParamEnv::reveal_all());
// let llty = self.layout_of(ty).llvm_type(self);
assert_ty_eq!(self, value_ty, v.ty);
self.builder.set_global_initializer(g.def, v.def);
// if attrs.flags.contains(CodegenFnAttrFlags::USED) {
// self.add_used_global(g);
// }
}
/// Mark the given global value as "used", to prevent a backend from potentially removing a
@ -569,7 +559,7 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> {
/// Static variables in Rust can be annotated with the `#[used]` attribute to direct the `rustc`
/// compiler to mark the variable as a "used global".
fn add_used_global(&self, _global: Self::Value) {
todo!()
// TODO: Ignore for now.
}
}
@ -992,7 +982,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
res
}
Primitive::Pointer => {
panic!("TODO: scalar_to_backend Primitive::Ptr not implemented yet")
panic!("scalar_to_backend Primitive::Ptr is an invalid state")
}
},
Scalar::Ptr(ptr) => {
@ -1042,7 +1032,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
},
) => {
if a_space != b_space {
// TODO: make sure the type here is right.
// TODO: Emit the correct type that is passed into this function.
self.poison(value.def);
}
assert_ty_eq!(self, a, b);
@ -1068,8 +1058,14 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
}
fn const_ptrcast(&self, val: Self::Value, ty: Self::Type) -> Self::Value {
// TODO: hack to get things working, this *will* fail spirv-val
val.def.with_type(ty)
if val.ty == ty {
val
} else {
// constant ptrcast is not supported in spir-v
let result = val.def.with_type(ty);
self.poison(result.def);
result
}
}
}

View File

@ -1,6 +1,8 @@
use rspirv::dr::{Instruction, Module, Operand};
use rspirv::dr::{Block, Function, Instruction, Module, Operand};
use rspirv::spirv::{Op, Word};
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::iter::once;
use std::mem::replace;
fn contains_poison(inst: &Instruction, poison: &HashSet<Word>) -> bool {
inst.result_type.map_or(false, |w| poison.contains(&w))
@ -21,7 +23,7 @@ fn is_poison(inst: &Instruction, poison: &HashSet<Word>) -> bool {
}
pub fn poison_pass(module: &mut Module, poison: &mut HashSet<Word>) {
// TODO: This is O(n^2), can we speed it up?
// Note: This is O(n^2).
while spread_poison(module, poison) {}
if option_env!("PRINT_POISON").is_some() {
@ -118,3 +120,97 @@ fn spread_poison(module: &mut Module, poison: &mut HashSet<Word>) -> bool {
}
any
}
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
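// Reorder a function's blocks so that every block comes after its forward-edge predecessors (loop backedges
// ignored), keeping the entry block first. This should also place dominators ahead of the blocks they
// dominate, which is the block layout spirv-val checks for.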
pub fn block_ordering_pass(func: &mut Function) {
if func.blocks.len() < 2 {
return;
}
let mut graph = func
.blocks
.iter()
.map(|block| {
(
block.label.as_ref().unwrap().result_id.unwrap(),
outgoing_edges(block),
)
})
.collect();
let entry_label = func.blocks[0].label.as_ref().unwrap().result_id.unwrap();
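// Kahn's algorithm needs an acyclic graph, so drop loop backedges before sorting.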
delete_backedges(&mut graph, entry_label);
let mut sorter = topological_sort::TopologicalSort::<Word>::new();
for (key, values) in graph {
for value in values {
sorter.add_dependency(key, value);
}
}
let mut old_blocks = replace(&mut func.blocks, Vec::new());
while let Some(item) = sorter.pop() {
let index = old_blocks
.iter()
.position(|b| b.label.as_ref().unwrap().result_id.unwrap() == item)
.unwrap();
func.blocks.push(old_blocks.remove(index));
}
assert!(sorter.is_empty());
assert!(old_blocks.is_empty());
assert_eq!(
func.blocks[0].label.as_ref().unwrap().result_id.unwrap(),
entry_label,
"Topo sorter did something weird (unreachable blocks?)"
);
}
fn outgoing_edges(block: &Block) -> Vec<Word> {
fn unwrap_id_ref(operand: &Operand) -> Word {
match *operand {
Operand::IdRef(word) => word,
_ => panic!("Expected Operand::IdRef: {}", operand),
}
}
let terminator = block.instructions.last().unwrap();
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Termination
match terminator.class.opcode {
Op::Branch => vec![unwrap_id_ref(&terminator.operands[0])],
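// OpBranchConditional operands: condition, true label, false label (plus optional branch weights).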
Op::BranchConditional => vec![
unwrap_id_ref(&terminator.operands[1]),
unwrap_id_ref(&terminator.operands[2]),
],
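// OpSwitch operands: selector, default label, then (literal, target label) pairs, so the target
// labels sit at every other operand starting at index 3.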
Op::Switch => once(unwrap_id_ref(&terminator.operands[1]))
.chain(
terminator.operands[3..]
.iter()
.step_by(2)
.map(unwrap_id_ref),
)
.collect(),
Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => Vec::new(),
_ => panic!("Invalid block terminator: {:?}", terminator),
}
}
fn delete_backedges(graph: &mut HashMap<Word, Vec<Word>>, entry: Word) {
// TODO: This has extremely bad runtime: with no visited set, nodes can be re-walked once per path,
// which is exponential in the worst case.
let mut backedges = HashSet::new();
fn re(
graph: &HashMap<Word, Vec<Word>>,
entry: Word,
stack: &mut Vec<Word>,
backedges: &mut HashSet<(Word, Word)>,
) {
stack.push(entry);
for &item in &graph[&entry] {
if stack.contains(&item) {
backedges.insert((entry, item));
} else if !backedges.contains(&(entry, item)) {
re(graph, item, stack, backedges);
}
}
assert_eq!(stack.pop(), Some(entry));
}
re(graph, entry, &mut Vec::new(), &mut backedges);
for (from, to) in backedges {
graph.get_mut(&from).unwrap().retain(|&o| o != to);
}
}

View File

@ -19,8 +19,8 @@ mod abi;
mod builder;
mod builder_spirv;
mod codegen_cx;
mod finalizing_passes;
mod link;
mod poison_pass;
mod spirv_type;
mod things;

View File

@ -286,8 +286,8 @@ impl SpirvType {
pub fn memset_const_pattern<'tcx>(&self, cx: &CodegenCx<'tcx>, fill_byte: u8) -> Word {
match *self {
SpirvType::Void => panic!("TODO: void memset not implemented yet"),
SpirvType::Bool => panic!("TODO: bool memset not implemented yet"),
SpirvType::Void => panic!("memset invalid on void pattern"),
SpirvType::Bool => panic!("memset invalid on bool pattern"),
SpirvType::Integer(width, _signedness) => match width {
8 => cx.constant_u8(fill_byte).def,
16 => cx.constant_u16(memset_fill_u16(fill_byte)).def,
@ -333,8 +333,8 @@ impl SpirvType {
fill_var: Word,
) -> Word {
match *self {
SpirvType::Void => panic!("TODO: void memset not implemented yet"),
SpirvType::Bool => panic!("TODO: bool memset not implemented yet"),
SpirvType::Void => panic!("memset invalid on void pattern"),
SpirvType::Bool => panic!("memset invalid on bool pattern"),
SpirvType::Integer(width, _signedness) => match width {
8 => fill_var,
16 => memset_dynamic_scalar(builder, fill_var, 2, false),