mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-25 00:04:11 +00:00
asm: infer the result type of OpAccessChain (when indexing into arrays).
This commit is contained in:
parent
fedac0f4a5
commit
bb7adc912f
@ -591,12 +591,14 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
|
||||
struct Ambiguous;
|
||||
|
||||
/// Construct a type from `pat`, replacing `TyPat::Var(i)` with `ty_vars[i]`.
|
||||
/// `leftover_operands` is used for `IndexComposite` patterns, if any exist.
|
||||
/// If the pattern isn't constraining enough to determine an unique type,
|
||||
/// `Err(Ambiguous)` is returned instead.
|
||||
fn subst_ty_pat(
|
||||
cx: &CodegenCx<'_>,
|
||||
pat: &TyPat<'_>,
|
||||
ty_vars: &[Option<Word>],
|
||||
leftover_operands: &[dr::Operand],
|
||||
) -> Result<Word, Ambiguous> {
|
||||
Ok(match pat {
|
||||
&TyPat::Var(i) => match ty_vars.get(i) {
|
||||
@ -604,17 +606,37 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
|
||||
_ => return Err(Ambiguous),
|
||||
},
|
||||
|
||||
TyPat::Pointer(_, pat) => SpirvType::Pointer {
|
||||
pointee: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
|
||||
}
|
||||
.def(DUMMY_SP, cx),
|
||||
|
||||
TyPat::Vector4(pat) => SpirvType::Vector {
|
||||
element: subst_ty_pat(cx, pat, ty_vars)?,
|
||||
element: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
|
||||
count: 4,
|
||||
}
|
||||
.def(DUMMY_SP, cx),
|
||||
|
||||
TyPat::SampledImage(pat) => SpirvType::SampledImage {
|
||||
image_type: subst_ty_pat(cx, pat, ty_vars)?,
|
||||
image_type: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
|
||||
}
|
||||
.def(DUMMY_SP, cx),
|
||||
|
||||
TyPat::IndexComposite(pat) => {
|
||||
let mut ty = subst_ty_pat(cx, pat, ty_vars, leftover_operands)?;
|
||||
for _index in leftover_operands {
|
||||
// FIXME(eddyb) support more than just arrays, by looking
|
||||
// up the indices (of struct fields) as constant integers.
|
||||
ty = match cx.lookup_type(ty) {
|
||||
SpirvType::Array { element, .. }
|
||||
| SpirvType::RuntimeArray { element } => element,
|
||||
|
||||
_ => return Err(Ambiguous),
|
||||
};
|
||||
}
|
||||
ty
|
||||
}
|
||||
|
||||
_ => return Err(Ambiguous),
|
||||
})
|
||||
}
|
||||
@ -631,11 +653,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
|
||||
|
||||
let mut combined_ty_vars = [None];
|
||||
|
||||
let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any());
|
||||
let mut operands = instruction.operands.iter();
|
||||
let mut next_id_operand = || operands.find_map(|o| o.id_ref_any());
|
||||
while let TyListPat::Cons { first: pat, suffix } = *sig.input_types {
|
||||
sig.input_types = suffix;
|
||||
|
||||
let match_result = match id_to_type_map.get(&ids.next()?) {
|
||||
let match_result = match id_to_type_map.get(&next_id_operand()?) {
|
||||
Some(&ty) => match_ty_pat(self, pat, ty),
|
||||
|
||||
// Non-value ID operand (or value operand of unknown type),
|
||||
@ -673,14 +696,19 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
|
||||
|
||||
TyListPat::Any => {}
|
||||
TyListPat::Nil => {
|
||||
if ids.next().is_some() {
|
||||
if next_id_operand().is_some() {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
|
||||
match subst_ty_pat(self, sig.output_type.unwrap(), &combined_ty_vars) {
|
||||
match subst_ty_pat(
|
||||
self,
|
||||
sig.output_type.unwrap(),
|
||||
&combined_ty_vars,
|
||||
operands.as_slice(),
|
||||
) {
|
||||
Ok(ty) => Some(ty),
|
||||
Err(Ambiguous) => None,
|
||||
}
|
||||
|
22
tests/ui/lang/asm/infer-access-chain-array.rs
Normal file
22
tests/ui/lang/asm/infer-access-chain-array.rs
Normal file
@ -0,0 +1,22 @@
|
||||
// Tests that `asm!` can infer the result type of `OpAccessChain`,
|
||||
// when used to index arrays.
|
||||
|
||||
// build-pass
|
||||
|
||||
use spirv_std as _;
|
||||
|
||||
use glam::Vec4;
|
||||
|
||||
#[spirv(fragment)]
|
||||
pub fn main(#[spirv(push_constant)] array_in: &[Vec4; 16], i: u32, out: &mut Vec4) {
|
||||
unsafe {
|
||||
asm!(
|
||||
"%val_ptr = OpAccessChain _ {array_ptr} {index}",
|
||||
"%val = OpLoad _ %val_ptr",
|
||||
"OpStore {out_ptr} %val",
|
||||
array_ptr = in(reg) array_in,
|
||||
index = in(reg) i,
|
||||
out_ptr = in(reg) out,
|
||||
);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user