diff --git a/rustc_codegen_spirv/src/ctx.rs b/rustc_codegen_spirv/src/ctx.rs index 43afc62212..2f9b232d09 100644 --- a/rustc_codegen_spirv/src/ctx.rs +++ b/rustc_codegen_spirv/src/ctx.rs @@ -80,6 +80,12 @@ impl<'tcx> Context<'tcx> { } } } + + /// rspirv doesn't cache type_pointer, so cache it ourselves here. + pub fn type_pointer(&mut self, storage_class: StorageClass, pointee_type: Word) -> Word { + self.spirv_helper + .type_pointer(&mut self.spirv, storage_class, pointee_type) + } } /// FnCtx is the "bag of random variables" used for state when compiling a particular function - i.e. variables that are @@ -128,10 +134,8 @@ impl<'ctx, 'tcx> FnCtx<'ctx, 'tcx> { } /// rspirv doesn't cache type_pointer, so cache it ourselves here. - pub fn type_pointer(&mut self, pointee_type: Word) -> Word { - self.ctx - .spirv_helper - .type_pointer(&mut self.ctx.spirv, pointee_type) + pub fn type_pointer(&mut self, storage_class: StorageClass, pointee_type: Word) -> Word { + self.ctx.type_pointer(storage_class, pointee_type) } // copied from rustc_codegen_cranelift @@ -168,7 +172,7 @@ impl ForwardReference { } struct SpirvHelper { - pointer: HashMap, + pointer: HashMap<(Word, StorageClass), Word>, } impl SpirvHelper { @@ -178,11 +182,15 @@ impl SpirvHelper { } } - fn type_pointer(&mut self, spirv: &mut Builder, pointee_type: Word) -> Word { - // TODO: StorageClass + fn type_pointer( + &mut self, + spirv: &mut Builder, + storage_class: StorageClass, + pointee_type: Word, + ) -> Word { *self .pointer - .entry(pointee_type) - .or_insert_with(|| spirv.type_pointer(None, StorageClass::Generic, pointee_type)) + .entry((pointee_type, storage_class)) + .or_insert_with(|| spirv.type_pointer(None, storage_class, pointee_type)) } } diff --git a/rustc_codegen_spirv/src/trans.rs b/rustc_codegen_spirv/src/trans.rs index f060af929e..c8fec1b3fd 100644 --- a/rustc_codegen_spirv/src/trans.rs +++ b/rustc_codegen_spirv/src/trans.rs @@ -85,9 +85,7 @@ pub fn trans_locals<'tcx>(ctx: &mut FnCtx<'_, 'tcx>, parameters_spirv: Vec for (local, decl) in ctx.body.local_decls.iter_enumerated() { // TODO: ZSTs let local_type = trans_type(ctx, decl.ty); - let local_ptr_type_def = - ctx.spirv() - .type_pointer(None, StorageClass::Function, local_type.def()); + let local_ptr_type_def = ctx.type_pointer(StorageClass::Function, local_type.def()); let local_ptr_type = SpirvType::Pointer { def: local_ptr_type_def, pointee: Box::new(local_type), @@ -136,7 +134,7 @@ fn trans_type<'tcx>(ctx: &mut FnCtx<'_, 'tcx>, ty: Ty<'tcx>) -> SpirvType { let pointee_type = trans_type(ctx, type_and_mut.ty); // note: use custom cache SpirvType::Pointer { - def: ctx.type_pointer(pointee_type.def()), + def: ctx.type_pointer(StorageClass::Generic, pointee_type.def()), pointee: Box::new(pointee_type), } } @@ -144,7 +142,7 @@ fn trans_type<'tcx>(ctx: &mut FnCtx<'_, 'tcx>, ty: Ty<'tcx>) -> SpirvType { let pointee_type = trans_type(ctx, pointee_ty); // note: use custom cache SpirvType::Pointer { - def: ctx.type_pointer(pointee_type.def()), + def: ctx.type_pointer(StorageClass::Generic, pointee_type.def()), pointee: Box::new(pointee_type), } } @@ -534,10 +532,9 @@ fn trans_place<'tcx>(ctx: &mut FnCtx<'_, 'tcx>, place: &Place<'tcx>) -> (Word, S let pointer = if access_chain.is_empty() { local.def } else { - let result_ptr_type = - ctx.ctx - .spirv - .type_pointer(None, StorageClass::Function, result_type.def()); + let result_ptr_type = ctx + .ctx + .type_pointer(StorageClass::Function, result_type.def()); ctx.ctx .spirv .access_chain(result_ptr_type, None, local.def, access_chain) diff --git a/rustc_codegen_spirv/test/lib.rs b/rustc_codegen_spirv/test/lib.rs index 33f3079ca8..cbf8e0aeb6 100644 --- a/rustc_codegen_spirv/test/lib.rs +++ b/rustc_codegen_spirv/test/lib.rs @@ -78,7 +78,7 @@ fn go(code: &str, expected: &str) { let cmd = Command::new("rustc") .args(&[ - #[cfg(target_os = "unix")] + #[cfg(unix)] "-Zcodegen-backend=target/debug/librustc_codegen_spirv.so", #[cfg(target_os = "windows")] "-Zcodegen-backend=target/debug/rustc_codegen_spirv.dll",