diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index e5a8a3ecf8..4378965e7d 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -137,7 +137,8 @@ impl<'tcx> CodegenCx<'tcx> { }) .collect::>(); let mut bx = Builder::new_block(self, stub_fn, ""); - // Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s). + // Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s), + // to match the argument type we have to pass to the Rust entry `fn`. let arguments: Vec<_> = interface_globals .iter() .zip(arg_abis) @@ -146,24 +147,42 @@ impl<'tcx> CodegenCx<'tcx> { |((&(global_var, storage_class), entry_fn_arg), hir_param)| { bx.set_span(hir_param.span); + let var_value_spirv_type = match self.lookup_type(global_var.ty) { + SpirvType::Pointer { pointee } => pointee, + _ => unreachable!(), + }; + let mut dst_len_arg = None; let arg = match entry_fn_arg.layout.ty.kind() { - TyKind::Ref(_, ty, _) => { - if !ty.is_sized(self.tcx.at(span), self.param_env()) { - dst_len_arg.replace( - self.dst_length_argument(&mut bx, ty, hir_param, global_var), - ); + TyKind::Ref(_, pointee_ty, _) => { + let arg_pointee_spirv_type = + self.layout_of(pointee_ty).spirv_type(hir_param.span, self); + + assert_ty_eq!(self, arg_pointee_spirv_type, var_value_spirv_type); + + if !pointee_ty.is_sized(self.tcx.at(span), self.param_env()) { + dst_len_arg.replace(self.dst_length_argument( + &mut bx, pointee_ty, hir_param, global_var, + )); } global_var } - _ => match entry_fn_arg.mode { - PassMode::Indirect { .. } => global_var, - PassMode::Direct(_) => { - assert_eq!(storage_class, StorageClass::Input); - bx.load(global_var, entry_fn_arg.layout.align.abi) + _ => { + assert_eq!(storage_class, StorageClass::Input); + + let arg_spirv_type = + entry_fn_arg.layout.spirv_type(hir_param.span, self); + + assert_ty_eq!(self, arg_spirv_type, var_value_spirv_type); + + match entry_fn_arg.mode { + PassMode::Indirect { .. } => global_var, + PassMode::Direct(_) => { + bx.load(global_var, entry_fn_arg.layout.align.abi) + } + _ => unreachable!(), } - _ => unreachable!(), - }, + } }; std::iter::once(arg).chain(dst_len_arg) },