entry: check the type of the argument we compute from the global OpVariable.

This commit is contained in:
Eduard-Mihai Burtescu 2021-04-05 10:32:30 +03:00 committed by Eduard-Mihai Burtescu
parent 6c3ce3fac2
commit ad859e681e

View File

@ -137,7 +137,8 @@ impl<'tcx> CodegenCx<'tcx> {
})
.collect::<Vec<_>>();
let mut bx = Builder::new_block(self, stub_fn, "");
// Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s).
// Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s),
// to match the argument type we have to pass to the Rust entry `fn`.
let arguments: Vec<_> = interface_globals
.iter()
.zip(arg_abis)
@ -146,24 +147,42 @@ impl<'tcx> CodegenCx<'tcx> {
|((&(global_var, storage_class), entry_fn_arg), hir_param)| {
bx.set_span(hir_param.span);
let var_value_spirv_type = match self.lookup_type(global_var.ty) {
SpirvType::Pointer { pointee } => pointee,
_ => unreachable!(),
};
let mut dst_len_arg = None;
let arg = match entry_fn_arg.layout.ty.kind() {
TyKind::Ref(_, ty, _) => {
if !ty.is_sized(self.tcx.at(span), self.param_env()) {
dst_len_arg.replace(
self.dst_length_argument(&mut bx, ty, hir_param, global_var),
);
TyKind::Ref(_, pointee_ty, _) => {
let arg_pointee_spirv_type =
self.layout_of(pointee_ty).spirv_type(hir_param.span, self);
assert_ty_eq!(self, arg_pointee_spirv_type, var_value_spirv_type);
if !pointee_ty.is_sized(self.tcx.at(span), self.param_env()) {
dst_len_arg.replace(self.dst_length_argument(
&mut bx, pointee_ty, hir_param, global_var,
));
}
global_var
}
_ => match entry_fn_arg.mode {
PassMode::Indirect { .. } => global_var,
PassMode::Direct(_) => {
assert_eq!(storage_class, StorageClass::Input);
bx.load(global_var, entry_fn_arg.layout.align.abi)
_ => {
assert_eq!(storage_class, StorageClass::Input);
let arg_spirv_type =
entry_fn_arg.layout.spirv_type(hir_param.span, self);
assert_ty_eq!(self, arg_spirv_type, var_value_spirv_type);
match entry_fn_arg.mode {
PassMode::Indirect { .. } => global_var,
PassMode::Direct(_) => {
bx.load(global_var, entry_fn_arg.layout.align.abi)
}
_ => unreachable!(),
}
_ => unreachable!(),
},
}
};
std::iter::once(arg).chain(dst_len_arg)
},