spv-in parsing Op::AtomicIIncrement (#5702)

Parse spirv::Op::AtomicIIncrement, add atomic_i_increment test.
This commit is contained in:
Schell Carl Scivally 2024-05-30 16:39:32 +12:00 committed by GitHub
parent 60a14c67fb
commit 480d4dbd73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 150 additions and 11 deletions

View File

@ -89,6 +89,7 @@ By @stefnotch in [#5410](https://github.com/gfx-rs/wgpu/pull/5410)
#### Naga #### Naga
- Implement `WGSL`'s `unpack4xI8`,`unpack4xU8`,`pack4xI8` and `pack4xU8`. By @VlaDexa in [#5424](https://github.com/gfx-rs/wgpu/pull/5424) - Implement `WGSL`'s `unpack4xI8`,`unpack4xU8`,`pack4xI8` and `pack4xU8`. By @VlaDexa in [#5424](https://github.com/gfx-rs/wgpu/pull/5424)
- Began work adding support for atomics to the SPIR-V frontend. Tracking issue is [here](https://github.com/gfx-rs/wgpu/issues/4489). By @schell in [#5702](https://github.com/gfx-rs/wgpu/pull/5702).
### Changes ### Changes

View File

@ -564,6 +564,20 @@ enum SignAnchor {
Operand, Operand,
} }
/// The SPIR-V atomic instructions the frontend currently recognizes.
enum AtomicOpInst {
    /// `OpAtomicIIncrement`: atomically increment the integer behind a pointer.
    AtomicIIncrement,
}

/// A recorded atomic instruction, stored (keyed by pointer id) so a later
/// pass can "upgrade" the pointed-to type to an atomic type.
// NOTE(review): fields other than `pointer_id` are not read yet — presumably
// the upcoming upgrade pass will consume them; confirm when that pass lands.
#[allow(dead_code)]
struct AtomicOp {
    // Which atomic instruction produced this record.
    instruction: AtomicOpInst,
    // Result type <id> of the instruction.
    result_type_id: spirv::Word,
    // Result <id> of the instruction.
    result_id: spirv::Word,
    // <id> of the pointer operand the atomic operates through.
    pointer_id: spirv::Word,
    // <id> of the memory scope operand.
    scope_id: spirv::Word,
    // <id> of the memory semantics operand.
    memory_semantics_id: spirv::Word,
}
pub struct Frontend<I> { pub struct Frontend<I> {
data: I, data: I,
data_offset: usize, data_offset: usize,
@ -575,6 +589,8 @@ pub struct Frontend<I> {
future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>, future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>,
lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>, lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>,
handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>, handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>,
// Used to upgrade types used in atomic ops to atomic types, keyed by pointer id
lookup_atomic: FastHashMap<spirv::Word, AtomicOp>,
lookup_type: FastHashMap<spirv::Word, LookupType>, lookup_type: FastHashMap<spirv::Word, LookupType>,
lookup_void_type: Option<spirv::Word>, lookup_void_type: Option<spirv::Word>,
lookup_storage_buffer_types: FastHashMap<Handle<crate::Type>, crate::StorageAccess>, lookup_storage_buffer_types: FastHashMap<Handle<crate::Type>, crate::StorageAccess>,
@ -630,6 +646,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
future_member_decor: FastHashMap::default(), future_member_decor: FastHashMap::default(),
handle_sampling: FastHashMap::default(), handle_sampling: FastHashMap::default(),
lookup_member: FastHashMap::default(), lookup_member: FastHashMap::default(),
lookup_atomic: FastHashMap::default(),
lookup_type: FastHashMap::default(), lookup_type: FastHashMap::default(),
lookup_void_type: None, lookup_void_type: None,
lookup_storage_buffer_types: FastHashMap::default(), lookup_storage_buffer_types: FastHashMap::default(),
@ -3943,7 +3960,81 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
); );
emitter.start(ctx.expressions); emitter.start(ctx.expressions);
} }
_ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), Op::AtomicIIncrement => {
inst.expect(6)?;
let start = self.data_offset;
let span = self.span_from_with_op(start);
let result_type_id = self.next()?;
let result_id = self.next()?;
let pointer_id = self.next()?;
let scope_id = self.next()?;
let memory_semantics_id = self.next()?;
// Store the op for a later pass where we "upgrade" the pointer type
let atomic = AtomicOp {
instruction: AtomicOpInst::AtomicIIncrement,
result_type_id,
result_id,
pointer_id,
scope_id,
memory_semantics_id,
};
self.lookup_atomic.insert(pointer_id, atomic);
log::trace!("\t\t\tlooking up expr {:?}", pointer_id);
let (p_lexp_handle, p_lexp_ty_id) = {
let lexp = self.lookup_expression.lookup(pointer_id)?;
let handle = get_expr_handle!(pointer_id, &lexp);
(handle, lexp.type_id)
};
log::trace!("\t\t\tlooking up type {pointer_id:?}");
let p_ty = self.lookup_type.lookup(p_lexp_ty_id)?;
let p_ty_base_id =
p_ty.base_id.ok_or(Error::InvalidAccessType(p_lexp_ty_id))?;
log::trace!("\t\t\tlooking up base type {p_ty_base_id:?} of {p_ty:?}");
let p_base_ty = self.lookup_type.lookup(p_ty_base_id)?;
// Create an expression for our result
let r_lexp_handle = {
let expr = crate::Expression::AtomicResult {
ty: p_base_ty.handle,
comparison: false,
};
let handle = ctx.expressions.append(expr, span);
self.lookup_expression.insert(
result_id,
LookupExpression {
handle,
type_id: result_type_id,
block_id,
},
);
handle
};
// Create a literal "1" since WGSL lacks an increment operation
let one_lexp_handle = make_index_literal(
ctx,
1,
&mut block,
&mut emitter,
p_base_ty.handle,
p_lexp_ty_id,
span,
)?;
// Create a statement for the op itself
let stmt = crate::Statement::Atomic {
pointer: p_lexp_handle,
fun: crate::AtomicFunction::Add,
value: one_lexp_handle,
result: r_lexp_handle,
};
block.push(stmt, span);
}
_ => {
return Err(Error::UnsupportedInstruction(self.state, inst.op));
}
} }
}; };
@ -5593,4 +5684,38 @@ mod test {
]; ];
let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap(); let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap();
} }
#[test]
fn atomic_i_inc() {
    // Enable trace logging in test output; ignore the error if a logger
    // is already installed.
    let _ = env_logger::builder()
        .is_test(true)
        .filter_level(log::LevelFilter::Trace)
        .try_init();

    // Parse the pre-compiled SPIR-V fixture into a Naga module.
    let spv = include_bytes!("../../../tests/in/spv/atomic_i_increment.spv");
    let module = super::parse_u8_slice(spv, &Default::default()).unwrap();

    // Analyze with validation disabled so we can still produce WGSL from
    // a module that is not yet fully valid.
    let info = crate::valid::Validator::new(
        crate::valid::ValidationFlags::empty(),
        Default::default(),
    )
    .validate(&module)
    .unwrap();

    // Round-trip the module through the WGSL backend.
    let wgsl =
        crate::back::wgsl::write_string(&module, &info, crate::back::wgsl::WriterFlags::empty())
            .unwrap();
    log::info!("atomic_i_increment:\n{wgsl}");

    // Re-parse the emitted WGSL.
    let reparsed = match crate::front::wgsl::parse_str(&wgsl) {
        Ok(module) => module,
        Err(e) => {
            log::error!("{}", e.emit_to_string(&wgsl));
            // at this point we know atomics create invalid modules
            // so simply bail
            return;
        }
    };

    // Fully validate the round-tripped module, logging (not failing on)
    // any validation errors.
    let mut full_validator =
        crate::valid::Validator::new(crate::valid::ValidationFlags::all(), Default::default());
    if let Err(e) = full_validator.validate(&reparsed) {
        log::error!("{}", e.emit_to_string(&wgsl));
    }
}
} }

Binary file not shown.

View File

@ -14,14 +14,15 @@ const BASE_DIR_OUT: &str = "tests/out";
bitflags::bitflags! { bitflags::bitflags! {
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
struct Targets: u32 { struct Targets: u32 {
const IR = 0x1; const IR = 1;
const ANALYSIS = 0x2; const ANALYSIS = 1 << 1;
const SPIRV = 0x4; const SPIRV = 1 << 2;
const METAL = 0x8; const METAL = 1 << 3;
const GLSL = 0x10; const GLSL = 1 << 4;
const DOT = 0x20; const DOT = 1 << 5;
const HLSL = 0x40; const HLSL = 1 << 6;
const WGSL = 0x80; const WGSL = 1 << 7;
const NO_VALIDATION = 1 << 8;
} }
} }
@ -292,7 +293,13 @@ fn check_targets(
} }
} }
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) let validation_flags = if targets.contains(Targets::NO_VALIDATION) {
naga::valid::ValidationFlags::empty()
} else {
naga::valid::ValidationFlags::all()
};
let info = naga::valid::Validator::new(validation_flags, capabilities)
.subgroup_stages(subgroup_stages) .subgroup_stages(subgroup_stages)
.subgroup_operations(subgroup_operations) .subgroup_operations(subgroup_operations)
.validate(module) .validate(module)
@ -317,7 +324,7 @@ fn check_targets(
} }
} }
naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) naga::valid::Validator::new(validation_flags, capabilities)
.subgroup_stages(subgroup_stages) .subgroup_stages(subgroup_stages)
.subgroup_operations(subgroup_operations) .subgroup_operations(subgroup_operations)
.validate(module) .validate(module)
@ -979,6 +986,12 @@ fn convert_spv_all() {
false, false,
Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
); );
convert_spv(
"atomic_i_increment",
false,
// TODO(@schell): remove Targets::NO_VALIDATION when OpAtomicIIncrement lands
Targets::IR | Targets::NO_VALIDATION,
);
} }
#[cfg(feature = "glsl-in")] #[cfg(feature = "glsl-in")]