asm: infer the result type of OpAccessChain (when indexing into arrays).

This commit is contained in:
Eduard-Mihai Burtescu 2021-04-15 17:45:54 +03:00 committed by Eduard-Mihai Burtescu
parent fedac0f4a5
commit bb7adc912f
2 changed files with 56 additions and 6 deletions

View File

@ -591,12 +591,14 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
struct Ambiguous;
/// Construct a type from `pat`, replacing `TyPat::Var(i)` with `ty_vars[i]`.
/// `leftover_operands` is used for `IndexComposite` patterns, if any exist.
/// If the pattern isn't constraining enough to determine an unique type,
/// `Err(Ambiguous)` is returned instead.
fn subst_ty_pat(
cx: &CodegenCx<'_>,
pat: &TyPat<'_>,
ty_vars: &[Option<Word>],
leftover_operands: &[dr::Operand],
) -> Result<Word, Ambiguous> {
Ok(match pat {
&TyPat::Var(i) => match ty_vars.get(i) {
@ -604,17 +606,37 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
_ => return Err(Ambiguous),
},
TyPat::Pointer(_, pat) => SpirvType::Pointer {
pointee: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
}
.def(DUMMY_SP, cx),
TyPat::Vector4(pat) => SpirvType::Vector {
element: subst_ty_pat(cx, pat, ty_vars)?,
element: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
count: 4,
}
.def(DUMMY_SP, cx),
TyPat::SampledImage(pat) => SpirvType::SampledImage {
image_type: subst_ty_pat(cx, pat, ty_vars)?,
image_type: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
}
.def(DUMMY_SP, cx),
TyPat::IndexComposite(pat) => {
let mut ty = subst_ty_pat(cx, pat, ty_vars, leftover_operands)?;
for _index in leftover_operands {
// FIXME(eddyb) support more than just arrays, by looking
// up the indices (of struct fields) as constant integers.
ty = match cx.lookup_type(ty) {
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element } => element,
_ => return Err(Ambiguous),
};
}
ty
}
_ => return Err(Ambiguous),
})
}
@ -631,11 +653,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
let mut combined_ty_vars = [None];
let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any());
let mut operands = instruction.operands.iter();
let mut next_id_operand = || operands.find_map(|o| o.id_ref_any());
while let TyListPat::Cons { first: pat, suffix } = *sig.input_types {
sig.input_types = suffix;
let match_result = match id_to_type_map.get(&ids.next()?) {
let match_result = match id_to_type_map.get(&next_id_operand()?) {
Some(&ty) => match_ty_pat(self, pat, ty),
// Non-value ID operand (or value operand of unknown type),
@ -673,14 +696,19 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
TyListPat::Any => {}
TyListPat::Nil => {
if ids.next().is_some() {
if next_id_operand().is_some() {
return None;
}
}
_ => return None,
}
match subst_ty_pat(self, sig.output_type.unwrap(), &combined_ty_vars) {
match subst_ty_pat(
self,
sig.output_type.unwrap(),
&combined_ty_vars,
operands.as_slice(),
) {
Ok(ty) => Some(ty),
Err(Ambiguous) => None,
}

View File

@ -0,0 +1,22 @@
// Tests that `asm!` can infer the result type of `OpAccessChain`,
// when used to index arrays.
// build-pass
use spirv_std as _;
use glam::Vec4;
#[spirv(fragment)]
pub fn main(#[spirv(push_constant)] array_in: &[Vec4; 16], i: u32, out: &mut Vec4) {
unsafe {
asm!(
"%val_ptr = OpAccessChain _ {array_ptr} {index}",
"%val = OpLoad _ %val_ptr",
"OpStore {out_ptr} %val",
array_ptr = in(reg) array_in,
index = in(reg) i,
out_ptr = in(reg) out,
);
}
}