diff --git a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs index 7d09d6cf6f..5d19a615f5 100644 --- a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs +++ b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs @@ -16,6 +16,7 @@ use rustc_hir::LlvmInlineAsmInner; use rustc_middle::bug; use rustc_span::{Span, DUMMY_SP}; use rustc_target::asm::{InlineAsmRegClass, InlineAsmRegOrRegClass, SpirVInlineAsmRegClass}; +use std::convert::TryFrom; pub struct InstructionTable { table: FxHashMap<&'static str, &'static rspirv::grammar::Instruction<'static>>, @@ -698,13 +699,23 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { TyPat::IndexComposite(pat) => { let mut ty = subst_ty_pat(cx, pat, ty_vars, leftover_operands)?; - for _index in leftover_operands { - // FIXME(eddyb) support more than just arrays, by looking - // up the indices (of struct fields) as constant integers. + for index in leftover_operands { + let index_to_usize = || match *index { + // FIXME(eddyb) support more than just literals, + // by looking up `IdRef`s as constant integers. + dr::Operand::LiteralInt32(i) => usize::try_from(i).ok(), + + _ => None, + }; ty = match cx.lookup_type(ty) { SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => element, + SpirvType::Adt { field_types, .. } => *index_to_usize() + .and_then(|i| field_types.get(i)) + .ok_or(Ambiguous)?, + + // FIXME(eddyb) support more than just arrays and structs. _ => return Err(Ambiguous), }; } diff --git a/tests/ui/lang/asm/infer-access-chain-slice.rs b/tests/ui/lang/asm/infer-access-chain-slice.rs new file mode 100644 index 0000000000..de4e821941 --- /dev/null +++ b/tests/ui/lang/asm/infer-access-chain-slice.rs @@ -0,0 +1,32 @@ +// Tests that `asm!` can infer the result type of `OpAccessChain`, +// when used to index slices. + +// build-pass + +use spirv_std as _; + +use glam::Vec4; + +#[spirv(fragment)] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slice_in: &[Vec4], + i: u32, + out: &mut Vec4, +) { + unsafe { + asm!( + // HACK(eddyb) we can't pass in the `&[T]` to `asm!` directly, + // and `as *const T` casts would require some special-casing + // to avoid actually going through an `OpTypePointer T`, so + // instead we extract the data pointer in the `asm!` itself. + "%slice_ptr = OpLoad _ {slice_ptr_ptr}", + "%data_ptr = OpCompositeExtract _ %slice_ptr 0", + "%val_ptr = OpAccessChain _ %data_ptr {index}", + "%val = OpLoad _ %val_ptr", + "OpStore {out_ptr} %val", + slice_ptr_ptr = in(reg) &slice_in, + index = in(reg) i, + out_ptr = in(reg) out, + ); + } +}