diff --git a/rustc_codegen_spirv/src/abi.rs b/rustc_codegen_spirv/src/abi.rs index 5987327c7b..a9b210bb3b 100644 --- a/rustc_codegen_spirv/src/abi.rs +++ b/rustc_codegen_spirv/src/abi.rs @@ -3,7 +3,6 @@ use rspirv::spirv::{StorageClass, Word}; use rustc_middle::ty::{layout::TyAndLayout, TyKind}; use rustc_target::abi::{Abi, FieldsShape, LayoutOf, Primitive, Scalar, Size}; use std::fmt; -use std::iter::empty; #[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)] pub enum SpirvType { @@ -11,9 +10,6 @@ pub enum SpirvType { Bool, Integer(u32, bool), Float(u32), - // TODO: Do we fold this into Adt? - /// Zero Sized Type - ZST, /// This uses the rustc definition of "adt", i.e. a struct, enum, or union Adt { // TODO: enums/unions @@ -40,6 +36,10 @@ pub enum SpirvType { impl SpirvType { /// Note: Builder::type_* should be called *nowhere else* but here, to ensure CodegenCx::type_defs stays up-to-date pub fn def<'spv, 'tcx>(&self, cx: &CodegenCx<'spv, 'tcx>) -> Word { + if let Some(&cached) = cx.type_cache.borrow().get(self) { + return cached; + } + //let cached = cx.type_cache.borrow_mut().entry(self); // TODO: rspirv does a linear search to dedupe, probably want to cache here. let result = match *self { SpirvType::Void => cx.emit_global().type_void(), @@ -48,7 +48,6 @@ impl SpirvType { .emit_global() .type_int(width, if signedness { 1 } else { 0 }), SpirvType::Float(width) => cx.emit_global().type_float(width), - SpirvType::ZST => cx.emit_global().type_struct(empty()), SpirvType::Adt { ref field_types } => { cx.emit_global().type_struct(field_types.iter().cloned()) } @@ -65,10 +64,23 @@ impl SpirvType { .emit_global() .type_function(return_type, arguments.iter().cloned()), }; - cx.type_defs - .borrow_mut() - .entry(result) - .or_insert_with(|| self.clone()); + // Change to expect_none if/when stabilized + assert!( + cx.type_defs + .borrow_mut() + .insert(result, self.clone()) + .is_none(), + "type_defs already had entry, caching failed? {:#?}", + self.clone().debug(cx) + ); + assert!( + cx.type_cache + .borrow_mut() + .insert(self.clone(), result) + .is_none(), + "type_cache already had entry, caching failed? {:#?}", + self.clone().debug(cx) + ); result } @@ -96,7 +108,6 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> { .field("signedness", &signedness) .finish(), SpirvType::Float(width) => f.debug_struct("Float").field("width", &width).finish(), - SpirvType::ZST => f.debug_struct("ZST").finish(), SpirvType::Adt { ref field_types } => { let fields = field_types .iter() @@ -155,7 +166,11 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> { pub fn trans_type<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>) -> Word { if ty.is_zst() { - return SpirvType::ZST.def(cx); + // An empty struct is zero-sized + return SpirvType::Adt { + field_types: Vec::new(), + } + .def(cx); } // Note: ty.abi is orthogonal to ty.variants and ty.fields, e.g. `ManuallyDrop>` diff --git a/rustc_codegen_spirv/src/builder/builder_methods.rs b/rustc_codegen_spirv/src/builder/builder_methods.rs index 7f2d3872d9..2020198413 100644 --- a/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -16,10 +16,22 @@ use rustc_target::abi::{Abi, Align, Size}; use std::iter::empty; use std::ops::Range; +macro_rules! assert_ty_eq { + ($codegen_cx:expr, $left:expr, $right:expr) => { + assert_eq!( + $left, + $right, + "Expected types to be equal:\n{:#?}\n==\n{:#?}", + $codegen_cx.debug_type($left), + $codegen_cx.debug_type($right) + ) + }; +} + macro_rules! simple_op { ($func_name:ident, $inst_name:ident) => { fn $func_name(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value { - assert_eq!(lhs.ty, rhs.ty); + assert_ty_eq!(self, lhs.ty, rhs.ty); let result_type = lhs.ty; self.emit() .$inst_name(result_type, None, lhs.def, rhs.def) @@ -313,7 +325,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> { } => pointee, ty => panic!("store called on variable that wasn't a pointer: {:?}", ty), }; - assert_eq!(ptr_elem_ty, val.ty); + assert_ty_eq!(self, ptr_elem_ty, val.ty); self.emit().store(ptr.def, val.def, None, empty()).unwrap(); val } @@ -507,7 +519,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> { // TODO: Do we want to assert signedness matches the opcode? Is it possible to have one that doesn't match? Does // spir-v allow nonmatching instructions? use IntPredicate::*; - assert_eq!(lhs.ty, rhs.ty); + assert_ty_eq!(self, lhs.ty, rhs.ty); let b = SpirvType::Bool.def(self); let mut e = self.emit(); match op { @@ -528,7 +540,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> { fn fcmp(&mut self, op: RealPredicate, lhs: Self::Value, rhs: Self::Value) -> Self::Value { use RealPredicate::*; - assert_eq!(lhs.ty, rhs.ty); + assert_ty_eq!(self, lhs.ty, rhs.ty); let b = SpirvType::Bool.def(self); let mut e = self.emit(); match op { @@ -594,7 +606,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> { then_val: Self::Value, else_val: Self::Value, ) -> Self::Value { - assert_eq!(then_val.ty, else_val.ty); + assert_ty_eq!(self, then_val.ty, else_val.ty); let result_type = then_val.ty; self.emit() .select(result_type, None, cond.def, then_val.def, else_val.def) @@ -659,7 +671,9 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> { fn insert_value(&mut self, agg_val: Self::Value, elt: Self::Value, idx: u64) -> Self::Value { match self.lookup_type(agg_val.ty) { - SpirvType::Adt { field_types } => assert_eq!(field_types[idx as usize], elt.ty), + SpirvType::Adt { field_types } => { + assert_ty_eq!(self, field_types[idx as usize], elt.ty) + } other => panic!("insert_value not implemented on type {:?}", other), }; self.emit() diff --git a/rustc_codegen_spirv/src/builder/mod.rs b/rustc_codegen_spirv/src/builder/mod.rs index 8e4563c6cf..f11d95e9f5 100644 --- a/rustc_codegen_spirv/src/builder/mod.rs +++ b/rustc_codegen_spirv/src/builder/mod.rs @@ -202,11 +202,25 @@ impl<'a, 'spv, 'tcx> ArgAbiMethods<'tcx> for Builder<'a, 'spv, 'tcx> { fn store_arg( &mut self, - _arg_abi: &ArgAbi<'tcx, Ty<'tcx>>, - _val: Self::Value, - _dst: PlaceRef<'tcx, Self::Value>, + arg_abi: &ArgAbi<'tcx, Ty<'tcx>>, + val: Self::Value, + dst: PlaceRef<'tcx, Self::Value>, ) { - todo!() + if arg_abi.is_ignore() { + return; + } + if arg_abi.is_sized_indirect() { + OperandValue::Ref(val, None, arg_abi.layout.align.abi).store(self, dst) + } else if arg_abi.is_unsized_indirect() { + panic!("unsized `ArgAbi` must be handled through `store_fn_arg`"); + } else if let PassMode::Cast(cast) = arg_abi.mode { + panic!( + "TODO: PassMode::Cast not implemented yet for store_arg: {:?}", + cast + ); + } else { + OperandValue::Immediate(val).store(self, dst); + } } fn arg_memory_ty(&self, arg_abi: &ArgAbi<'tcx, Ty<'tcx>>) -> Self::Type { diff --git a/rustc_codegen_spirv/src/builder_spirv.rs b/rustc_codegen_spirv/src/builder_spirv.rs index 4946ce6630..42c2ef50a9 100644 --- a/rustc_codegen_spirv/src/builder_spirv.rs +++ b/rustc_codegen_spirv/src/builder_spirv.rs @@ -133,6 +133,7 @@ impl BuilderSpirv { fn def_constant(&self, ty: Word, val: Operand) -> Word { let mut builder = self.builder.borrow_mut(); + // TODO: Cache these instead of doing a search. for inst in &builder.module_ref().types_global_values { if inst.class.opcode == Op::Constant && inst.result_type == Some(ty) diff --git a/rustc_codegen_spirv/src/codegen_cx.rs b/rustc_codegen_spirv/src/codegen_cx.rs index e1079d50ad..a3a4888ee7 100644 --- a/rustc_codegen_spirv/src/codegen_cx.rs +++ b/rustc_codegen_spirv/src/codegen_cx.rs @@ -33,11 +33,18 @@ use std::collections::HashMap; pub struct CodegenCx<'spv, 'tcx> { pub tcx: TyCtxt<'tcx>, pub codegen_unit: &'tcx CodegenUnit<'tcx>, + /// Not actually used much, builder is probably what you want (see comment on ModuleSpirv type) pub spirv_module: &'spv ModuleSpirv, + /// Spir-v module builder pub builder: BuilderSpirv, + /// Map from MIR function to spir-v function ID pub function_defs: RefCell, SpirvValue>>, + /// Map from function ID to parameter list pub function_parameter_values: RefCell>>, + /// Map from ID to structure pub type_defs: RefCell>, + /// Inverse of type_defs (used to cache generating types) + pub type_cache: RefCell>, } impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> { @@ -54,6 +61,7 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> { function_defs: RefCell::new(HashMap::new()), function_parameter_values: RefCell::new(HashMap::new()), type_defs: RefCell::new(HashMap::new()), + type_cache: RefCell::new(HashMap::new()), } }