entry: move call arg ajustments to declare_shader_interface_for_param.

This commit is contained in:
Eduard-Mihai Burtescu 2021-04-07 13:58:24 +03:00 committed by Eduard-Mihai Burtescu
parent 630e5a61d8
commit 3f641638c8

View File

@ -129,102 +129,23 @@ impl<'tcx> CodegenCx<'tcx> {
};
let mut op_entry_point_interface_operands = vec![];
let mut decoration_locations = HashMap::new();
let interface_globals = arg_abis
.iter()
.zip(hir_params)
.map(|(entry_fn_arg, hir_param)| {
self.declare_shader_interface_for_param(
entry_fn_arg.layout,
hir_param,
&mut op_entry_point_interface_operands,
&mut decoration_locations,
)
})
.collect::<Vec<_>>();
let mut bx = Builder::new_block(self, stub_fn, "");
// Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s,
// or accessing the sole field of an "interface block" `OpTypeStruct`),
// to match the argument type we have to pass to the Rust entry `fn`.
let arguments: Vec<_> = interface_globals
.iter()
.zip(arg_abis)
.zip(hir_params)
.flat_map(
|((&(global_var, storage_class), entry_fn_arg), hir_param)| {
bx.set_span(hir_param.span);
let var_value_spirv_type = match self.lookup_type(global_var.ty) {
SpirvType::Pointer { pointee } => pointee,
_ => unreachable!(),
};
let (first, second) = match entry_fn_arg.layout.ty.kind() {
TyKind::Ref(_, pointee_ty, _) => {
let arg_pointee_spirv_type = self
.layout_of(pointee_ty)
.spirv_type(hir_param.ty_span, self);
if let SpirvType::InterfaceBlock { inner_type } =
self.lookup_type(var_value_spirv_type)
{
assert_ty_eq!(self, arg_pointee_spirv_type, inner_type);
let inner = bx.struct_gep(global_var, 0);
match entry_fn_arg.mode {
PassMode::Direct(_) => (inner, None),
// Unsized pointee with length (i.e. `&[T]`).
PassMode::Pair(..) => {
// FIXME(eddyb) shouldn't this be `usize`?
let len_spirv_type = self.type_isize();
let len = bx
.emit()
.array_length(
len_spirv_type,
None,
global_var.def(&bx),
0,
)
.unwrap()
.with_type(len_spirv_type);
(inner, Some(len))
}
_ => unreachable!(),
}
} else {
assert_ty_eq!(self, arg_pointee_spirv_type, var_value_spirv_type);
assert_matches!(entry_fn_arg.mode, PassMode::Direct(_));
(global_var, None)
}
}
_ => {
assert_eq!(storage_class, StorageClass::Input);
let arg_spirv_type =
entry_fn_arg.layout.spirv_type(hir_param.ty_span, self);
assert_ty_eq!(self, arg_spirv_type, var_value_spirv_type);
match entry_fn_arg.mode {
PassMode::Indirect { .. } => (global_var, None),
PassMode::Direct(_) => {
(bx.load(global_var, entry_fn_arg.layout.align.abi), None)
}
_ => unreachable!(),
}
}
};
std::iter::once(first).chain(second)
},
let mut call_args = vec![];
let mut decoration_locations = HashMap::new();
for (entry_arg_abi, hir_param) in arg_abis.iter().zip(hir_params) {
bx.set_span(hir_param.span);
self.declare_shader_interface_for_param(
entry_arg_abi,
hir_param,
&mut op_entry_point_interface_operands,
&mut bx,
&mut call_args,
&mut decoration_locations,
)
.collect();
}
bx.set_span(span);
bx.call(entry_func, &arguments, None);
bx.call(entry_func, &call_args, None);
bx.ret_void();
let stub_fn_id = stub_fn.def_cx(self);
@ -336,26 +257,38 @@ impl<'tcx> CodegenCx<'tcx> {
fn declare_shader_interface_for_param(
&self,
layout: TyAndLayout<'tcx>,
entry_arg_abi: &ArgAbi<'tcx, Ty<'tcx>>,
hir_param: &hir::Param<'tcx>,
op_entry_point_interface_operands: &mut Vec<Word>,
bx: &mut Builder<'_, 'tcx>,
call_args: &mut Vec<SpirvValue>,
decoration_locations: &mut HashMap<StorageClass, u32>,
) -> (SpirvValue, StorageClass) {
) {
let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.hir().attrs(hir_param.hir_id));
// Pre-allocate the module-scoped `OpVariable`'s *Result* ID.
let var = self.emit_global().id();
let (mut value_spirv_type, storage_class) =
self.infer_param_ty_and_storage_class(layout, hir_param, &attrs);
let (value_spirv_type, storage_class) =
self.infer_param_ty_and_storage_class(entry_arg_abi.layout, hir_param, &attrs);
// Certain storage classes require an `OpTypeStruct` decorated with `Block`,
// which we represent with `SpirvType::InterfaceBlock` (see its doc comment).
// This "interface block" construct is also required for "runtime arrays".
let is_unsized = self.lookup_type(value_spirv_type).sizeof(self).is_none();
match storage_class {
let var_ptr_spirv_type;
let (value_ptr, value_len) = match storage_class {
StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer => {
if is_unsized {
var_ptr_spirv_type = self.type_ptr_to(
SpirvType::InterfaceBlock {
inner_type: value_spirv_type,
}
.def(hir_param.span, self),
);
let value_ptr = bx.struct_gep(var.with_type(var_ptr_spirv_type), 0);
let value_len = if is_unsized {
match self.lookup_type(value_spirv_type) {
SpirvType::RuntimeArray { .. } => {}
_ => self.tcx.sess.span_err(
@ -363,14 +296,24 @@ impl<'tcx> CodegenCx<'tcx> {
"only plain slices are supported as unsized types",
),
}
}
value_spirv_type = SpirvType::InterfaceBlock {
inner_type: value_spirv_type,
}
.def(hir_param.span, self);
// FIXME(eddyb) shouldn't this be `usize`?
let len_spirv_type = self.type_isize();
let len = bx
.emit()
.array_length(len_spirv_type, None, var, 0)
.unwrap();
Some(len.with_type(len_spirv_type))
} else {
None
};
(value_ptr, value_len)
}
_ => {
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
if is_unsized {
self.tcx.sess.span_fatal(
hir_param.ty_span,
@ -380,7 +323,30 @@ impl<'tcx> CodegenCx<'tcx> {
),
);
}
(var.with_type(var_ptr_spirv_type), None)
}
};
// Compute call argument(s) to match what the Rust entry `fn` expects,
// starting from the `value_ptr` pointing to a `value_spirv_type`
// (e.g. `Input` doesn't use indirection, so we have to load from it).
if let TyKind::Ref(..) = entry_arg_abi.layout.ty.kind() {
call_args.push(value_ptr);
match entry_arg_abi.mode {
PassMode::Direct(_) => assert_eq!(value_len, None),
PassMode::Pair(..) => call_args.push(value_len.unwrap()),
_ => unreachable!(),
}
} else {
assert_eq!(storage_class, StorageClass::Input);
call_args.push(match entry_arg_abi.mode {
PassMode::Indirect { .. } => value_ptr,
PassMode::Direct(_) => bx.load(value_ptr, entry_arg_abi.layout.align.abi),
_ => unreachable!(),
});
assert_eq!(value_len, None);
}
// FIXME(eddyb) check whether the storage class is compatible with the
@ -457,12 +423,8 @@ impl<'tcx> CodegenCx<'tcx> {
}
// Emit the `OpVariable` with its *Result* ID set to `var`.
let var_spirv_type = SpirvType::Pointer {
pointee: value_spirv_type,
}
.def(hir_param.span, self);
self.emit_global()
.variable(var_spirv_type, Some(var), storage_class, None);
.variable(var_ptr_spirv_type, Some(var), storage_class, None);
// Record this `OpVariable` as needing to be added (if applicable),
// to the *Interface* operands of the `OpEntryPoint` instruction.
@ -475,8 +437,6 @@ impl<'tcx> CodegenCx<'tcx> {
op_entry_point_interface_operands.push(var);
}
}
(var.with_type(var_spirv_type), storage_class)
}
// Kernel mode takes its interface as function parameters(??)