diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 48ab5c745c..d96ed8d991 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -129,102 +129,23 @@ impl<'tcx> CodegenCx<'tcx> { }; let mut op_entry_point_interface_operands = vec![]; - let mut decoration_locations = HashMap::new(); - let interface_globals = arg_abis - .iter() - .zip(hir_params) - .map(|(entry_fn_arg, hir_param)| { - self.declare_shader_interface_for_param( - entry_fn_arg.layout, - hir_param, - &mut op_entry_point_interface_operands, - &mut decoration_locations, - ) - }) - .collect::>(); + let mut bx = Builder::new_block(self, stub_fn, ""); - // Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s, - // or accessing the sole field of an "interface block" `OpTypeStruct`), - // to match the argument type we have to pass to the Rust entry `fn`. - let arguments: Vec<_> = interface_globals - .iter() - .zip(arg_abis) - .zip(hir_params) - .flat_map( - |((&(global_var, storage_class), entry_fn_arg), hir_param)| { - bx.set_span(hir_param.span); - - let var_value_spirv_type = match self.lookup_type(global_var.ty) { - SpirvType::Pointer { pointee } => pointee, - _ => unreachable!(), - }; - - let (first, second) = match entry_fn_arg.layout.ty.kind() { - TyKind::Ref(_, pointee_ty, _) => { - let arg_pointee_spirv_type = self - .layout_of(pointee_ty) - .spirv_type(hir_param.ty_span, self); - - if let SpirvType::InterfaceBlock { inner_type } = - self.lookup_type(var_value_spirv_type) - { - assert_ty_eq!(self, arg_pointee_spirv_type, inner_type); - - let inner = bx.struct_gep(global_var, 0); - - match entry_fn_arg.mode { - PassMode::Direct(_) => (inner, None), - - // Unsized pointee with length (i.e. `&[T]`). - PassMode::Pair(..) => { - // FIXME(eddyb) shouldn't this be `usize`? - let len_spirv_type = self.type_isize(); - - let len = bx - .emit() - .array_length( - len_spirv_type, - None, - global_var.def(&bx), - 0, - ) - .unwrap() - .with_type(len_spirv_type); - - (inner, Some(len)) - } - - _ => unreachable!(), - } - } else { - assert_ty_eq!(self, arg_pointee_spirv_type, var_value_spirv_type); - assert_matches!(entry_fn_arg.mode, PassMode::Direct(_)); - (global_var, None) - } - } - _ => { - assert_eq!(storage_class, StorageClass::Input); - - let arg_spirv_type = - entry_fn_arg.layout.spirv_type(hir_param.ty_span, self); - - assert_ty_eq!(self, arg_spirv_type, var_value_spirv_type); - - match entry_fn_arg.mode { - PassMode::Indirect { .. } => (global_var, None), - PassMode::Direct(_) => { - (bx.load(global_var, entry_fn_arg.layout.align.abi), None) - } - _ => unreachable!(), - } - } - }; - std::iter::once(first).chain(second) - }, + let mut call_args = vec![]; + let mut decoration_locations = HashMap::new(); + for (entry_arg_abi, hir_param) in arg_abis.iter().zip(hir_params) { + bx.set_span(hir_param.span); + self.declare_shader_interface_for_param( + entry_arg_abi, + hir_param, + &mut op_entry_point_interface_operands, + &mut bx, + &mut call_args, + &mut decoration_locations, ) - .collect(); + } bx.set_span(span); - bx.call(entry_func, &arguments, None); + bx.call(entry_func, &call_args, None); bx.ret_void(); let stub_fn_id = stub_fn.def_cx(self); @@ -336,26 +257,38 @@ impl<'tcx> CodegenCx<'tcx> { fn declare_shader_interface_for_param( &self, - layout: TyAndLayout<'tcx>, + entry_arg_abi: &ArgAbi<'tcx, Ty<'tcx>>, hir_param: &hir::Param<'tcx>, op_entry_point_interface_operands: &mut Vec, + bx: &mut Builder<'_, 'tcx>, + call_args: &mut Vec, decoration_locations: &mut HashMap, - ) -> (SpirvValue, StorageClass) { + ) { let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.hir().attrs(hir_param.hir_id)); // Pre-allocate the module-scoped `OpVariable`'s *Result* ID. let var = self.emit_global().id(); - let (mut value_spirv_type, storage_class) = - self.infer_param_ty_and_storage_class(layout, hir_param, &attrs); + let (value_spirv_type, storage_class) = + self.infer_param_ty_and_storage_class(entry_arg_abi.layout, hir_param, &attrs); // Certain storage classes require an `OpTypeStruct` decorated with `Block`, // which we represent with `SpirvType::InterfaceBlock` (see its doc comment). // This "interface block" construct is also required for "runtime arrays". let is_unsized = self.lookup_type(value_spirv_type).sizeof(self).is_none(); - match storage_class { + let var_ptr_spirv_type; + let (value_ptr, value_len) = match storage_class { StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer => { - if is_unsized { + var_ptr_spirv_type = self.type_ptr_to( + SpirvType::InterfaceBlock { + inner_type: value_spirv_type, + } + .def(hir_param.span, self), + ); + + let value_ptr = bx.struct_gep(var.with_type(var_ptr_spirv_type), 0); + + let value_len = if is_unsized { match self.lookup_type(value_spirv_type) { SpirvType::RuntimeArray { .. } => {} _ => self.tcx.sess.span_err( @@ -363,14 +296,24 @@ impl<'tcx> CodegenCx<'tcx> { "only plain slices are supported as unsized types", ), } - } - value_spirv_type = SpirvType::InterfaceBlock { - inner_type: value_spirv_type, - } - .def(hir_param.span, self); + // FIXME(eddyb) shouldn't this be `usize`? + let len_spirv_type = self.type_isize(); + let len = bx + .emit() + .array_length(len_spirv_type, None, var, 0) + .unwrap(); + + Some(len.with_type(len_spirv_type)) + } else { + None + }; + + (value_ptr, value_len) } _ => { + var_ptr_spirv_type = self.type_ptr_to(value_spirv_type); + if is_unsized { self.tcx.sess.span_fatal( hir_param.ty_span, @@ -380,7 +323,30 @@ impl<'tcx> CodegenCx<'tcx> { ), ); } + + (var.with_type(var_ptr_spirv_type), None) } + }; + + // Compute call argument(s) to match what the Rust entry `fn` expects, + // starting from the `value_ptr` pointing to a `value_spirv_type` + // (e.g. `Input` doesn't use indirection, so we have to load from it). + if let TyKind::Ref(..) = entry_arg_abi.layout.ty.kind() { + call_args.push(value_ptr); + match entry_arg_abi.mode { + PassMode::Direct(_) => assert_eq!(value_len, None), + PassMode::Pair(..) => call_args.push(value_len.unwrap()), + _ => unreachable!(), + } + } else { + assert_eq!(storage_class, StorageClass::Input); + + call_args.push(match entry_arg_abi.mode { + PassMode::Indirect { .. } => value_ptr, + PassMode::Direct(_) => bx.load(value_ptr, entry_arg_abi.layout.align.abi), + _ => unreachable!(), + }); + assert_eq!(value_len, None); } // FIXME(eddyb) check whether the storage class is compatible with the @@ -457,12 +423,8 @@ impl<'tcx> CodegenCx<'tcx> { } // Emit the `OpVariable` with its *Result* ID set to `var`. - let var_spirv_type = SpirvType::Pointer { - pointee: value_spirv_type, - } - .def(hir_param.span, self); self.emit_global() - .variable(var_spirv_type, Some(var), storage_class, None); + .variable(var_ptr_spirv_type, Some(var), storage_class, None); // Record this `OpVariable` as needing to be added (if applicable), // to the *Interface* operands of the `OpEntryPoint` instruction. @@ -475,8 +437,6 @@ impl<'tcx> CodegenCx<'tcx> { op_entry_point_interface_operands.push(var); } } - - (var.with_type(var_spirv_type), storage_class) } // Kernel mode takes its interface as function parameters(??)