ADD metal namespace for uint4 (#6417)

This commit is contained in:
Xiaopeng Li 2024-10-17 17:45:29 +08:00 committed by GitHub
parent 59f56e0263
commit 74ef445bca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 7 deletions

View File

@ -3323,7 +3323,10 @@ impl<W: Write> Writer<W> {
let name = self.namer.call("");
self.start_baking_expression(result, &context.expression, &name)?;
self.named_expressions.insert(result, name);
write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?;
write!(
self.out,
"{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
)?;
if let Some(predicate) = predicate {
self.put_expression(predicate, &context.expression, true)?;
} else {
@ -4487,7 +4490,7 @@ template <typename A>
let name = self.namer.call("unpackUint32x4");
writeln!(
self.out,
"uint4 {name}(uint b0, \
"{NAMESPACE}::uint4 {name}(uint b0, \
uint b1, \
uint b2, \
uint b3, \
@ -4506,7 +4509,7 @@ template <typename A>
)?;
writeln!(
self.out,
"{}return uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
"{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
(b15 << 24 | b14 << 16 | b13 << 8 | b12));",

View File

@ -11,8 +11,8 @@ void main_1(
) {
uint _e5 = subgroup_size_1;
uint _e6 = subgroup_invocation_id_1;
metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((_e6 & 1u) == 1u), 0, 0, 0);
metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
metal::uint4 unnamed = metal::uint4((uint64_t)metal::simd_ballot((_e6 & 1u) == 1u), 0, 0, 0);
metal::uint4 unnamed_1 = metal::uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
bool unnamed_2 = metal::simd_all(_e6 != 0u);
bool unnamed_3 = metal::simd_any(_e6 == 0u);
uint unnamed_4 = metal::simd_sum(_e6);

View File

@ -19,8 +19,8 @@ kernel void main_(
) {
const Structure sizes = { num_subgroups, subgroup_size };
metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup);
metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0);
metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
metal::uint4 unnamed = metal::uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0);
metal::uint4 unnamed_1 = metal::uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
bool unnamed_2 = metal::simd_all(subgroup_invocation_id != 0u);
bool unnamed_3 = metal::simd_any(subgroup_invocation_id == 0u);
uint unnamed_4 = metal::simd_sum(subgroup_invocation_id);