diff --git a/CHANGELOG.md b/CHANGELOG.md index a73b79330..fb8b0d6d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -89,6 +89,7 @@ By @stefnotch in [#5410](https://github.com/gfx-rs/wgpu/pull/5410) #### Naga - Implement `WGSL`'s `unpack4xI8`,`unpack4xU8`,`pack4xI8` and `pack4xU8`. By @VlaDexa in [#5424](https://github.com/gfx-rs/wgpu/pull/5424) +- Began work adding support for atomics to the SPIR-V frontend. Tracking issue is [here](https://github.com/gfx-rs/wgpu/issues/4489). By @schell in [#5702](https://github.com/gfx-rs/wgpu/pull/5702). ### Changes diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 12610083a..480f77134 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -564,6 +564,20 @@ enum SignAnchor { Operand, } +enum AtomicOpInst { + AtomicIIncrement, +} + +#[allow(dead_code)] +struct AtomicOp { + instruction: AtomicOpInst, + result_type_id: spirv::Word, + result_id: spirv::Word, + pointer_id: spirv::Word, + scope_id: spirv::Word, + memory_semantics_id: spirv::Word, +} + pub struct Frontend { data: I, data_offset: usize, @@ -575,6 +589,8 @@ pub struct Frontend { future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>, lookup_member: FastHashMap<(Handle, MemberIndex), LookupMember>, handle_sampling: FastHashMap, image::SamplingFlags>, + // Used to upgrade types used in atomic ops to atomic types, keyed by pointer id + lookup_atomic: FastHashMap, lookup_type: FastHashMap, lookup_void_type: Option, lookup_storage_buffer_types: FastHashMap, crate::StorageAccess>, @@ -630,6 +646,7 @@ impl> Frontend { future_member_decor: FastHashMap::default(), handle_sampling: FastHashMap::default(), lookup_member: FastHashMap::default(), + lookup_atomic: FastHashMap::default(), lookup_type: FastHashMap::default(), lookup_void_type: None, lookup_storage_buffer_types: FastHashMap::default(), @@ -3943,7 +3960,81 @@ impl> Frontend { ); emitter.start(ctx.expressions); } - _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), + Op::AtomicIIncrement => { + inst.expect(6)?; + let start = self.data_offset; + let span = self.span_from_with_op(start); + let result_type_id = self.next()?; + let result_id = self.next()?; + let pointer_id = self.next()?; + let scope_id = self.next()?; + let memory_semantics_id = self.next()?; + // Store the op for a later pass where we "upgrade" the pointer type + let atomic = AtomicOp { + instruction: AtomicOpInst::AtomicIIncrement, + result_type_id, + result_id, + pointer_id, + scope_id, + memory_semantics_id, + }; + self.lookup_atomic.insert(pointer_id, atomic); + + log::trace!("\t\t\tlooking up expr {:?}", pointer_id); + + let (p_lexp_handle, p_lexp_ty_id) = { + let lexp = self.lookup_expression.lookup(pointer_id)?; + let handle = get_expr_handle!(pointer_id, &lexp); + (handle, lexp.type_id) + }; + log::trace!("\t\t\tlooking up type {pointer_id:?}"); + let p_ty = self.lookup_type.lookup(p_lexp_ty_id)?; + let p_ty_base_id = + p_ty.base_id.ok_or(Error::InvalidAccessType(p_lexp_ty_id))?; + log::trace!("\t\t\tlooking up base type {p_ty_base_id:?} of {p_ty:?}"); + let p_base_ty = self.lookup_type.lookup(p_ty_base_id)?; + + // Create an expression for our result + let r_lexp_handle = { + let expr = crate::Expression::AtomicResult { + ty: p_base_ty.handle, + comparison: false, + }; + let handle = ctx.expressions.append(expr, span); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + handle + }; + + // Create a literal "1" since WGSL lacks an increment operation + let one_lexp_handle = make_index_literal( + ctx, + 1, + &mut block, + &mut emitter, + p_base_ty.handle, + p_lexp_ty_id, + span, + )?; + + // Create a statement for the op itself + let stmt = crate::Statement::Atomic { + pointer: p_lexp_handle, + fun: crate::AtomicFunction::Add, + value: one_lexp_handle, + result: r_lexp_handle, + }; + block.push(stmt, span); + } + _ => { + return Err(Error::UnsupportedInstruction(self.state, inst.op)); + } } }; @@ -5593,4 +5684,38 @@ mod test { ]; let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap(); } + + #[test] + fn atomic_i_inc() { + let _ = env_logger::builder() + .is_test(true) + .filter_level(log::LevelFilter::Trace) + .try_init(); + let bytes = include_bytes!("../../../tests/in/spv/atomic_i_increment.spv"); + let m = super::parse_u8_slice(bytes, &Default::default()).unwrap(); + let mut validator = crate::valid::Validator::new( + crate::valid::ValidationFlags::empty(), + Default::default(), + ); + let info = validator.validate(&m).unwrap(); + let wgsl = + crate::back::wgsl::write_string(&m, &info, crate::back::wgsl::WriterFlags::empty()) + .unwrap(); + log::info!("atomic_i_increment:\n{wgsl}"); + + let m = match crate::front::wgsl::parse_str(&wgsl) { + Ok(m) => m, + Err(e) => { + log::error!("{}", e.emit_to_string(&wgsl)); + // at this point we know atomics create invalid modules + // so simply bail + return; + } + }; + let mut validator = + crate::valid::Validator::new(crate::valid::ValidationFlags::all(), Default::default()); + if let Err(e) = validator.validate(&m) { + log::error!("{}", e.emit_to_string(&wgsl)); + } + } } diff --git a/naga/tests/in/spv/atomic_i_increment.spv b/naga/tests/in/spv/atomic_i_increment.spv new file mode 100644 index 000000000..1b17c5306 Binary files /dev/null and b/naga/tests/in/spv/atomic_i_increment.spv differ diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index ee775a3e6..5e2441e0d 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -14,14 +14,15 @@ const BASE_DIR_OUT: &str = "tests/out"; bitflags::bitflags! { #[derive(Clone, Copy)] struct Targets: u32 { - const IR = 0x1; - const ANALYSIS = 0x2; - const SPIRV = 0x4; - const METAL = 0x8; - const GLSL = 0x10; - const DOT = 0x20; - const HLSL = 0x40; - const WGSL = 0x80; + const IR = 1; + const ANALYSIS = 1 << 1; + const SPIRV = 1 << 2; + const METAL = 1 << 3; + const GLSL = 1 << 4; + const DOT = 1 << 5; + const HLSL = 1 << 6; + const WGSL = 1 << 7; + const NO_VALIDATION = 1 << 8; } } @@ -292,7 +293,13 @@ fn check_targets( } } - let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) + let validation_flags = if targets.contains(Targets::NO_VALIDATION) { + naga::valid::ValidationFlags::empty() + } else { + naga::valid::ValidationFlags::all() + }; + + let info = naga::valid::Validator::new(validation_flags, capabilities) .subgroup_stages(subgroup_stages) .subgroup_operations(subgroup_operations) .validate(module) @@ -317,7 +324,7 @@ fn check_targets( } } - naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) + naga::valid::Validator::new(validation_flags, capabilities) .subgroup_stages(subgroup_stages) .subgroup_operations(subgroup_operations) .validate(module) @@ -979,6 +986,12 @@ fn convert_spv_all() { false, Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ); + convert_spv( + "atomic_i_increment", + false, + // TODO(@schell): remove Targets::NO_VALIDATION when OpAtomicIIncrement lands + Targets::IR | Targets::NO_VALIDATION, + ); } #[cfg(feature = "glsl-in")]