mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-25 08:14:12 +00:00
entry: move call arg ajustments to declare_shader_interface_for_param
.
This commit is contained in:
parent
630e5a61d8
commit
3f641638c8
@ -129,102 +129,23 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
};
|
||||
|
||||
let mut op_entry_point_interface_operands = vec![];
|
||||
|
||||
let mut bx = Builder::new_block(self, stub_fn, "");
|
||||
let mut call_args = vec![];
|
||||
let mut decoration_locations = HashMap::new();
|
||||
let interface_globals = arg_abis
|
||||
.iter()
|
||||
.zip(hir_params)
|
||||
.map(|(entry_fn_arg, hir_param)| {
|
||||
for (entry_arg_abi, hir_param) in arg_abis.iter().zip(hir_params) {
|
||||
bx.set_span(hir_param.span);
|
||||
self.declare_shader_interface_for_param(
|
||||
entry_fn_arg.layout,
|
||||
entry_arg_abi,
|
||||
hir_param,
|
||||
&mut op_entry_point_interface_operands,
|
||||
&mut bx,
|
||||
&mut call_args,
|
||||
&mut decoration_locations,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut bx = Builder::new_block(self, stub_fn, "");
|
||||
// Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s,
|
||||
// or accessing the sole field of an "interface block" `OpTypeStruct`),
|
||||
// to match the argument type we have to pass to the Rust entry `fn`.
|
||||
let arguments: Vec<_> = interface_globals
|
||||
.iter()
|
||||
.zip(arg_abis)
|
||||
.zip(hir_params)
|
||||
.flat_map(
|
||||
|((&(global_var, storage_class), entry_fn_arg), hir_param)| {
|
||||
bx.set_span(hir_param.span);
|
||||
|
||||
let var_value_spirv_type = match self.lookup_type(global_var.ty) {
|
||||
SpirvType::Pointer { pointee } => pointee,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let (first, second) = match entry_fn_arg.layout.ty.kind() {
|
||||
TyKind::Ref(_, pointee_ty, _) => {
|
||||
let arg_pointee_spirv_type = self
|
||||
.layout_of(pointee_ty)
|
||||
.spirv_type(hir_param.ty_span, self);
|
||||
|
||||
if let SpirvType::InterfaceBlock { inner_type } =
|
||||
self.lookup_type(var_value_spirv_type)
|
||||
{
|
||||
assert_ty_eq!(self, arg_pointee_spirv_type, inner_type);
|
||||
|
||||
let inner = bx.struct_gep(global_var, 0);
|
||||
|
||||
match entry_fn_arg.mode {
|
||||
PassMode::Direct(_) => (inner, None),
|
||||
|
||||
// Unsized pointee with length (i.e. `&[T]`).
|
||||
PassMode::Pair(..) => {
|
||||
// FIXME(eddyb) shouldn't this be `usize`?
|
||||
let len_spirv_type = self.type_isize();
|
||||
|
||||
let len = bx
|
||||
.emit()
|
||||
.array_length(
|
||||
len_spirv_type,
|
||||
None,
|
||||
global_var.def(&bx),
|
||||
0,
|
||||
)
|
||||
.unwrap()
|
||||
.with_type(len_spirv_type);
|
||||
|
||||
(inner, Some(len))
|
||||
}
|
||||
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
assert_ty_eq!(self, arg_pointee_spirv_type, var_value_spirv_type);
|
||||
assert_matches!(entry_fn_arg.mode, PassMode::Direct(_));
|
||||
(global_var, None)
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
assert_eq!(storage_class, StorageClass::Input);
|
||||
|
||||
let arg_spirv_type =
|
||||
entry_fn_arg.layout.spirv_type(hir_param.ty_span, self);
|
||||
|
||||
assert_ty_eq!(self, arg_spirv_type, var_value_spirv_type);
|
||||
|
||||
match entry_fn_arg.mode {
|
||||
PassMode::Indirect { .. } => (global_var, None),
|
||||
PassMode::Direct(_) => {
|
||||
(bx.load(global_var, entry_fn_arg.layout.align.abi), None)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
};
|
||||
std::iter::once(first).chain(second)
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
bx.set_span(span);
|
||||
bx.call(entry_func, &arguments, None);
|
||||
bx.call(entry_func, &call_args, None);
|
||||
bx.ret_void();
|
||||
|
||||
let stub_fn_id = stub_fn.def_cx(self);
|
||||
@ -336,26 +257,38 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
|
||||
fn declare_shader_interface_for_param(
|
||||
&self,
|
||||
layout: TyAndLayout<'tcx>,
|
||||
entry_arg_abi: &ArgAbi<'tcx, Ty<'tcx>>,
|
||||
hir_param: &hir::Param<'tcx>,
|
||||
op_entry_point_interface_operands: &mut Vec<Word>,
|
||||
bx: &mut Builder<'_, 'tcx>,
|
||||
call_args: &mut Vec<SpirvValue>,
|
||||
decoration_locations: &mut HashMap<StorageClass, u32>,
|
||||
) -> (SpirvValue, StorageClass) {
|
||||
) {
|
||||
let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.hir().attrs(hir_param.hir_id));
|
||||
|
||||
// Pre-allocate the module-scoped `OpVariable`'s *Result* ID.
|
||||
let var = self.emit_global().id();
|
||||
|
||||
let (mut value_spirv_type, storage_class) =
|
||||
self.infer_param_ty_and_storage_class(layout, hir_param, &attrs);
|
||||
let (value_spirv_type, storage_class) =
|
||||
self.infer_param_ty_and_storage_class(entry_arg_abi.layout, hir_param, &attrs);
|
||||
|
||||
// Certain storage classes require an `OpTypeStruct` decorated with `Block`,
|
||||
// which we represent with `SpirvType::InterfaceBlock` (see its doc comment).
|
||||
// This "interface block" construct is also required for "runtime arrays".
|
||||
let is_unsized = self.lookup_type(value_spirv_type).sizeof(self).is_none();
|
||||
match storage_class {
|
||||
let var_ptr_spirv_type;
|
||||
let (value_ptr, value_len) = match storage_class {
|
||||
StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer => {
|
||||
if is_unsized {
|
||||
var_ptr_spirv_type = self.type_ptr_to(
|
||||
SpirvType::InterfaceBlock {
|
||||
inner_type: value_spirv_type,
|
||||
}
|
||||
.def(hir_param.span, self),
|
||||
);
|
||||
|
||||
let value_ptr = bx.struct_gep(var.with_type(var_ptr_spirv_type), 0);
|
||||
|
||||
let value_len = if is_unsized {
|
||||
match self.lookup_type(value_spirv_type) {
|
||||
SpirvType::RuntimeArray { .. } => {}
|
||||
_ => self.tcx.sess.span_err(
|
||||
@ -363,14 +296,24 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
"only plain slices are supported as unsized types",
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
value_spirv_type = SpirvType::InterfaceBlock {
|
||||
inner_type: value_spirv_type,
|
||||
}
|
||||
.def(hir_param.span, self);
|
||||
// FIXME(eddyb) shouldn't this be `usize`?
|
||||
let len_spirv_type = self.type_isize();
|
||||
let len = bx
|
||||
.emit()
|
||||
.array_length(len_spirv_type, None, var, 0)
|
||||
.unwrap();
|
||||
|
||||
Some(len.with_type(len_spirv_type))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(value_ptr, value_len)
|
||||
}
|
||||
_ => {
|
||||
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
|
||||
|
||||
if is_unsized {
|
||||
self.tcx.sess.span_fatal(
|
||||
hir_param.ty_span,
|
||||
@ -380,7 +323,30 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
(var.with_type(var_ptr_spirv_type), None)
|
||||
}
|
||||
};
|
||||
|
||||
// Compute call argument(s) to match what the Rust entry `fn` expects,
|
||||
// starting from the `value_ptr` pointing to a `value_spirv_type`
|
||||
// (e.g. `Input` doesn't use indirection, so we have to load from it).
|
||||
if let TyKind::Ref(..) = entry_arg_abi.layout.ty.kind() {
|
||||
call_args.push(value_ptr);
|
||||
match entry_arg_abi.mode {
|
||||
PassMode::Direct(_) => assert_eq!(value_len, None),
|
||||
PassMode::Pair(..) => call_args.push(value_len.unwrap()),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
assert_eq!(storage_class, StorageClass::Input);
|
||||
|
||||
call_args.push(match entry_arg_abi.mode {
|
||||
PassMode::Indirect { .. } => value_ptr,
|
||||
PassMode::Direct(_) => bx.load(value_ptr, entry_arg_abi.layout.align.abi),
|
||||
_ => unreachable!(),
|
||||
});
|
||||
assert_eq!(value_len, None);
|
||||
}
|
||||
|
||||
// FIXME(eddyb) check whether the storage class is compatible with the
|
||||
@ -457,12 +423,8 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
}
|
||||
|
||||
// Emit the `OpVariable` with its *Result* ID set to `var`.
|
||||
let var_spirv_type = SpirvType::Pointer {
|
||||
pointee: value_spirv_type,
|
||||
}
|
||||
.def(hir_param.span, self);
|
||||
self.emit_global()
|
||||
.variable(var_spirv_type, Some(var), storage_class, None);
|
||||
.variable(var_ptr_spirv_type, Some(var), storage_class, None);
|
||||
|
||||
// Record this `OpVariable` as needing to be added (if applicable),
|
||||
// to the *Interface* operands of the `OpEntryPoint` instruction.
|
||||
@ -475,8 +437,6 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
op_entry_point_interface_operands.push(var);
|
||||
}
|
||||
}
|
||||
|
||||
(var.with_type(var_spirv_type), storage_class)
|
||||
}
|
||||
|
||||
// Kernel mode takes its interface as function parameters(??)
|
||||
|
Loading…
Reference in New Issue
Block a user