mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-22 06:44:14 +00:00
[naga msl-out] Implement atomicCompareExchangeWeak for MSL backend (#6265)
This commit is contained in:
parent
d9178a1876
commit
bf33e481f3
@ -93,6 +93,10 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
|
||||
|
||||
- Allow using [VK_GOOGLE_display_timing](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_GOOGLE_display_timing.html) unsafely with the `VULKAN_GOOGLE_DISPLAY_TIMING` feature. By @DJMcNab in [#6149](https://github.com/gfx-rs/wgpu/pull/6149)
|
||||
|
||||
#### Metal
|
||||
|
||||
- Implement `atomicCompareExchangeWeak`. By @AsherJingkongChen in [#6265](https://github.com/gfx-rs/wgpu/pull/6265)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Fix incorrect hlsl image output type conversion. By @atlv24 in [#6123](https://github.com/gfx-rs/wgpu/pull/6123)
|
||||
|
@ -33,6 +33,7 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
|
||||
const RAY_QUERY_FIELD_READY: &str = "ready";
|
||||
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
|
||||
|
||||
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
|
||||
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
|
||||
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
|
||||
|
||||
@ -1279,42 +1280,6 @@ impl<W: Write> Writer<W> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn put_atomic_operation(
|
||||
&mut self,
|
||||
pointer: Handle<crate::Expression>,
|
||||
key: &str,
|
||||
value: Handle<crate::Expression>,
|
||||
context: &ExpressionContext,
|
||||
) -> BackendResult {
|
||||
// If the pointer we're passing to the atomic operation needs to be conditional
|
||||
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
|
||||
// the pointer operand should be unchecked.
|
||||
let policy = context.choose_bounds_check_policy(pointer);
|
||||
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
|
||||
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;
|
||||
|
||||
// If requested and successfully put bounds checks, continue the ternary expression.
|
||||
if checked {
|
||||
write!(self.out, " ? ")?;
|
||||
}
|
||||
|
||||
write!(
|
||||
self.out,
|
||||
"{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
|
||||
)?;
|
||||
self.put_access_chain(pointer, policy, context)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.put_expression(value, context, true)?;
|
||||
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
|
||||
|
||||
// Finish the ternary expression.
|
||||
if checked {
|
||||
write!(self.out, " : DefaultConstructible()")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Emit code for the arithmetic expression of the dot product.
|
||||
///
|
||||
fn put_dot_product(
|
||||
@ -3182,24 +3147,65 @@ impl<W: Write> Writer<W> {
|
||||
value,
|
||||
result,
|
||||
} => {
|
||||
let context = &context.expression;
|
||||
|
||||
// This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
|
||||
// `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
|
||||
// `Some`, we are not operating on a 64-bit value, and that if we are
|
||||
// operating on a 64-bit value, `result` is `None`.
|
||||
write!(self.out, "{level}")?;
|
||||
let fun_str = if let Some(result) = result {
|
||||
let fun_key = if let Some(result) = result {
|
||||
let res_name = Baked(result).to_string();
|
||||
self.start_baking_expression(result, &context.expression, &res_name)?;
|
||||
self.start_baking_expression(result, context, &res_name)?;
|
||||
self.named_expressions.insert(result, res_name);
|
||||
fun.to_msl()?
|
||||
} else if context.expression.resolve_type(value).scalar_width() == Some(8) {
|
||||
fun.to_msl()
|
||||
} else if context.resolve_type(value).scalar_width() == Some(8) {
|
||||
fun.to_msl_64_bit()?
|
||||
} else {
|
||||
fun.to_msl()?
|
||||
fun.to_msl()
|
||||
};
|
||||
|
||||
self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
|
||||
// done
|
||||
// If the pointer we're passing to the atomic operation needs to be conditional
|
||||
// for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
|
||||
// the pointer operand should be unchecked.
|
||||
let policy = context.choose_bounds_check_policy(pointer);
|
||||
let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
|
||||
&& self.put_bounds_checks(pointer, context, back::Level(0), "")?;
|
||||
|
||||
// If requested and successfully put bounds checks, continue the ternary expression.
|
||||
if checked {
|
||||
write!(self.out, " ? ")?;
|
||||
}
|
||||
|
||||
// Put the atomic function invocation.
|
||||
match *fun {
|
||||
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
|
||||
write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
|
||||
self.put_access_chain(pointer, policy, context)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.put_expression(cmp, context, true)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.put_expression(value, context, true)?;
|
||||
write!(self.out, ")")?;
|
||||
}
|
||||
_ => {
|
||||
write!(
|
||||
self.out,
|
||||
"{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
|
||||
)?;
|
||||
self.put_access_chain(pointer, policy, context)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.put_expression(value, context, true)?;
|
||||
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
|
||||
}
|
||||
}
|
||||
|
||||
// Finish the ternary expression.
|
||||
if checked {
|
||||
write!(self.out, " : DefaultConstructible()")?;
|
||||
}
|
||||
|
||||
// Done
|
||||
writeln!(self.out, ";")?;
|
||||
}
|
||||
crate::Statement::WorkGroupUniformLoad { pointer, result } => {
|
||||
@ -3827,7 +3833,33 @@ impl<W: Write> Writer<W> {
|
||||
}}"
|
||||
)?;
|
||||
}
|
||||
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
|
||||
&crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
|
||||
let arg_type_name = scalar.to_msl_name();
|
||||
let called_func_name = "atomic_compare_exchange_weak_explicit";
|
||||
let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
|
||||
let struct_name = &self.names[&NameKey::Type(*struct_ty)];
|
||||
|
||||
writeln!(self.out)?;
|
||||
|
||||
for address_space_name in ["device", "threadgroup"] {
|
||||
writeln!(
|
||||
self.out,
|
||||
"\
|
||||
template <typename A>
|
||||
{struct_name} {defined_func_name}(
|
||||
{address_space_name} A *atomic_ptr,
|
||||
{arg_type_name} cmp,
|
||||
{arg_type_name} v
|
||||
) {{
|
||||
bool swapped = {NAMESPACE}::{called_func_name}(
|
||||
atomic_ptr, &cmp, v,
|
||||
metal::memory_order_relaxed, metal::memory_order_relaxed
|
||||
);
|
||||
return {struct_name}{{cmp, swapped}};
|
||||
}}"
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -6065,8 +6097,8 @@ fn test_stack_size() {
|
||||
}
|
||||
|
||||
impl crate::AtomicFunction {
|
||||
fn to_msl(self) -> Result<&'static str, Error> {
|
||||
Ok(match self {
|
||||
const fn to_msl(self) -> &'static str {
|
||||
match self {
|
||||
Self::Add => "fetch_add",
|
||||
Self::Subtract => "fetch_sub",
|
||||
Self::And => "fetch_and",
|
||||
@ -6075,10 +6107,8 @@ impl crate::AtomicFunction {
|
||||
Self::Min => "fetch_min",
|
||||
Self::Max => "fetch_max",
|
||||
Self::Exchange { compare: None } => "exchange",
|
||||
Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
|
||||
"atomic CompareExchange".to_string(),
|
||||
))?,
|
||||
})
|
||||
Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_msl_64_bit(self) -> Result<&'static str, Error> {
|
||||
|
161
naga/tests/out/msl/atomicCompareExchange.msl
Normal file
161
naga/tests/out/msl/atomicCompareExchange.msl
Normal file
@ -0,0 +1,161 @@
|
||||
// language: metal1.0
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using metal::uint;
|
||||
|
||||
struct type_2 {
|
||||
metal::atomic_int inner[128];
|
||||
};
|
||||
struct type_4 {
|
||||
metal::atomic_uint inner[128];
|
||||
};
|
||||
struct _atomic_compare_exchange_resultSint4_ {
|
||||
int old_value;
|
||||
bool exchanged;
|
||||
};
|
||||
struct _atomic_compare_exchange_resultUint4_ {
|
||||
uint old_value;
|
||||
bool exchanged;
|
||||
};
|
||||
|
||||
template <typename A>
|
||||
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
|
||||
device A *atomic_ptr,
|
||||
int cmp,
|
||||
int v
|
||||
) {
|
||||
bool swapped = metal::atomic_compare_exchange_weak_explicit(
|
||||
atomic_ptr, &cmp, v,
|
||||
metal::memory_order_relaxed, metal::memory_order_relaxed
|
||||
);
|
||||
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
|
||||
}
|
||||
template <typename A>
|
||||
_atomic_compare_exchange_resultSint4_ naga_atomic_compare_exchange_weak_explicit(
|
||||
threadgroup A *atomic_ptr,
|
||||
int cmp,
|
||||
int v
|
||||
) {
|
||||
bool swapped = metal::atomic_compare_exchange_weak_explicit(
|
||||
atomic_ptr, &cmp, v,
|
||||
metal::memory_order_relaxed, metal::memory_order_relaxed
|
||||
);
|
||||
return _atomic_compare_exchange_resultSint4_{cmp, swapped};
|
||||
}
|
||||
|
||||
template <typename A>
|
||||
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
|
||||
device A *atomic_ptr,
|
||||
uint cmp,
|
||||
uint v
|
||||
) {
|
||||
bool swapped = metal::atomic_compare_exchange_weak_explicit(
|
||||
atomic_ptr, &cmp, v,
|
||||
metal::memory_order_relaxed, metal::memory_order_relaxed
|
||||
);
|
||||
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
|
||||
}
|
||||
template <typename A>
|
||||
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
|
||||
threadgroup A *atomic_ptr,
|
||||
uint cmp,
|
||||
uint v
|
||||
) {
|
||||
bool swapped = metal::atomic_compare_exchange_weak_explicit(
|
||||
atomic_ptr, &cmp, v,
|
||||
metal::memory_order_relaxed, metal::memory_order_relaxed
|
||||
);
|
||||
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
|
||||
}
|
||||
constant uint SIZE = 128u;
|
||||
|
||||
kernel void test_atomic_compare_exchange_i32_(
|
||||
device type_2& arr_i32_ [[user(fake0)]]
|
||||
) {
|
||||
uint i = 0u;
|
||||
int old = {};
|
||||
bool exchanged = {};
|
||||
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
|
||||
bool loop_init = true;
|
||||
LOOP_IS_REACHABLE while(true) {
|
||||
if (!loop_init) {
|
||||
uint _e27 = i;
|
||||
i = _e27 + 1u;
|
||||
}
|
||||
loop_init = false;
|
||||
uint _e2 = i;
|
||||
if (_e2 < SIZE) {
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
{
|
||||
uint _e6 = i;
|
||||
int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed);
|
||||
old = _e8;
|
||||
exchanged = false;
|
||||
LOOP_IS_REACHABLE while(true) {
|
||||
bool _e12 = exchanged;
|
||||
if (!(_e12)) {
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
{
|
||||
int _e14 = old;
|
||||
int new_ = as_type<int>(as_type<float>(_e14) + 1.0);
|
||||
uint _e20 = i;
|
||||
int _e22 = old;
|
||||
_atomic_compare_exchange_resultSint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_i32_.inner[_e20], _e22, new_);
|
||||
old = _e23.old_value;
|
||||
exchanged = _e23.exchanged;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
kernel void test_atomic_compare_exchange_u32_(
|
||||
device type_4& arr_u32_ [[user(fake0)]]
|
||||
) {
|
||||
uint i_1 = 0u;
|
||||
uint old_1 = {};
|
||||
bool exchanged_1 = {};
|
||||
bool loop_init_1 = true;
|
||||
LOOP_IS_REACHABLE while(true) {
|
||||
if (!loop_init_1) {
|
||||
uint _e27 = i_1;
|
||||
i_1 = _e27 + 1u;
|
||||
}
|
||||
loop_init_1 = false;
|
||||
uint _e2 = i_1;
|
||||
if (_e2 < SIZE) {
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
{
|
||||
uint _e6 = i_1;
|
||||
uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed);
|
||||
old_1 = _e8;
|
||||
exchanged_1 = false;
|
||||
LOOP_IS_REACHABLE while(true) {
|
||||
bool _e12 = exchanged_1;
|
||||
if (!(_e12)) {
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
{
|
||||
uint _e14 = old_1;
|
||||
uint new_1 = as_type<uint>(as_type<float>(_e14) + 1.0);
|
||||
uint _e20 = i_1;
|
||||
uint _e22 = old_1;
|
||||
_atomic_compare_exchange_resultUint4_ _e23 = naga_atomic_compare_exchange_weak_explicit(&arr_u32_.inner[_e20], _e22, new_1);
|
||||
old_1 = _e23.old_value;
|
||||
exchanged_1 = _e23.exchanged;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
48
naga/tests/out/msl/overrides-atomicCompareExchangeWeak.msl
Normal file
48
naga/tests/out/msl/overrides-atomicCompareExchangeWeak.msl
Normal file
@ -0,0 +1,48 @@
|
||||
// language: metal1.0
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using metal::uint;
|
||||
|
||||
struct _atomic_compare_exchange_resultUint4_ {
|
||||
uint old_value;
|
||||
bool exchanged;
|
||||
};
|
||||
|
||||
template <typename A>
|
||||
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
|
||||
device A *atomic_ptr,
|
||||
uint cmp,
|
||||
uint v
|
||||
) {
|
||||
bool swapped = metal::atomic_compare_exchange_weak_explicit(
|
||||
atomic_ptr, &cmp, v,
|
||||
metal::memory_order_relaxed, metal::memory_order_relaxed
|
||||
);
|
||||
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
|
||||
}
|
||||
template <typename A>
|
||||
_atomic_compare_exchange_resultUint4_ naga_atomic_compare_exchange_weak_explicit(
|
||||
threadgroup A *atomic_ptr,
|
||||
uint cmp,
|
||||
uint v
|
||||
) {
|
||||
bool swapped = metal::atomic_compare_exchange_weak_explicit(
|
||||
atomic_ptr, &cmp, v,
|
||||
metal::memory_order_relaxed, metal::memory_order_relaxed
|
||||
);
|
||||
return _atomic_compare_exchange_resultUint4_{cmp, swapped};
|
||||
}
|
||||
constant int o = 2;
|
||||
|
||||
kernel void f(
|
||||
metal::uint3 __local_invocation_id [[thread_position_in_threadgroup]]
|
||||
, threadgroup metal::atomic_uint& a
|
||||
) {
|
||||
if (metal::all(__local_invocation_id == metal::uint3(0u))) {
|
||||
metal::atomic_store_explicit(&a, 0, metal::memory_order_relaxed);
|
||||
}
|
||||
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
|
||||
_atomic_compare_exchange_resultUint4_ _e5 = naga_atomic_compare_exchange_weak_explicit(&a, 2u, 1u);
|
||||
return;
|
||||
}
|
@ -773,7 +773,10 @@ fn convert_wgsl() {
|
||||
"atomicOps",
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
),
|
||||
("atomicCompareExchange", Targets::SPIRV | Targets::WGSL),
|
||||
(
|
||||
"atomicCompareExchange",
|
||||
Targets::SPIRV | Targets::METAL | Targets::WGSL,
|
||||
),
|
||||
(
|
||||
"padding",
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
@ -917,7 +920,7 @@ fn convert_wgsl() {
|
||||
),
|
||||
(
|
||||
"overrides-atomicCompareExchangeWeak",
|
||||
Targets::IR | Targets::SPIRV,
|
||||
Targets::IR | Targets::SPIRV | Targets::METAL,
|
||||
),
|
||||
(
|
||||
"overrides-ray-query",
|
||||
|
Loading…
Reference in New Issue
Block a user