diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 38774a6d0..55c0d5482 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -47,20 +47,20 @@ impl<'a> Display for TypeContext<'a> { // work around Metal toolchain bug with `uint` typedef crate::ScalarKind::Uint => write!(out, "{}::uint", NAMESPACE), _ => { - let kind_str = scalar_kind_string(kind); + let kind_str = kind.to_msl_name(); write!(out, "{}", kind_str) } } } crate::TypeInner::Atomic { kind, .. } => { - write!(out, "{}::atomic_{}", NAMESPACE, scalar_kind_string(kind)) + write!(out, "{}::atomic_{}", NAMESPACE, kind.to_msl_name()) } crate::TypeInner::Vector { size, kind, .. } => { write!( out, "{}::{}{}", NAMESPACE, - scalar_kind_string(kind), + kind.to_msl_name(), back::vector_size_str(size), ) } @@ -69,7 +69,7 @@ impl<'a> Display for TypeContext<'a> { out, "{}::{}{}x{}", NAMESPACE, - scalar_kind_string(crate::ScalarKind::Float), + crate::ScalarKind::Float.to_msl_name(), back::vector_size_str(columns), back::vector_size_str(rows), ) @@ -96,7 +96,7 @@ impl<'a> Display for TypeContext<'a> { Some(name) => name, None => return Ok(()), }; - write!(out, "{} {}&", class_name, scalar_kind_string(kind),) + write!(out, "{} {}&", class_name, kind.to_msl_name(),) } crate::TypeInner::ValuePointer { size: Some(size), @@ -113,7 +113,7 @@ impl<'a> Display for TypeContext<'a> { "{} {}::{}{}&", class_name, NAMESPACE, - scalar_kind_string(kind), + kind.to_msl_name(), back::vector_size_str(size), ) } @@ -178,7 +178,7 @@ impl<'a> Display for TypeContext<'a> { ("texture", "", format.into(), access) } }; - let base_name = scalar_kind_string(kind); + let base_name = kind.to_msl_name(); let array_str = if arrayed { "_array" } else { "" }; write!( out, @@ -316,12 +316,14 @@ pub struct Writer { struct_member_pads: FastHashSet<(Handle, u32)>, } -fn scalar_kind_string(kind: crate::ScalarKind) -> &'static str { - match kind { - crate::ScalarKind::Float => "float", - crate::ScalarKind::Sint => "int", - crate::ScalarKind::Uint => "uint", - crate::ScalarKind::Bool => "bool", +impl crate::ScalarKind { + fn to_msl_name(self) -> &'static str { + match self { + Self::Float => "float", + Self::Sint => "int", + Self::Uint => "uint", + Self::Bool => "bool", + } } } @@ -481,6 +483,29 @@ impl<'a> ExpressionContext<'a> { ) -> Option { index::access_needs_check(base, index, self.module, self.function, self.info) } + + // Because packed vectors such as `packed_float3` cannot be directly loaded, + // we convert them to unpacked vectors like `float3` on load. + fn get_packed_vec_kind( + &self, + expr_handle: Handle, + ) -> Option { + match self.function.expressions[expr_handle] { + crate::Expression::AccessIndex { base, index } => { + let ty = match *self.resolve_type(base) { + crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner, + ref ty => ty, + }; + match *ty { + crate::TypeInner::Struct { + ref members, span, .. + } => should_pack_struct_member(members, span, index as usize, self.module), + _ => None, + } + } + _ => None, + } + } } struct StatementContext<'a> { @@ -652,7 +677,7 @@ impl Writer { ) -> BackendResult { match context.module.types[ty].inner { crate::TypeInner::Scalar { width: 4, kind } if components.len() == 1 => { - write!(self.out, "{}", scalar_kind_string(kind))?; + write!(self.out, "{}", kind.to_msl_name())?; self.put_call_parameters(components.iter().cloned(), context)?; } crate::TypeInner::Vector { size, kind, .. } => { @@ -660,7 +685,7 @@ impl Writer { self.out, "{}::{}{}", NAMESPACE, - scalar_kind_string(kind), + kind.to_msl_name(), back::vector_size_str(size) )?; self.put_call_parameters(components.iter().cloned(), context)?; @@ -671,7 +696,7 @@ impl Writer { self.out, "{}::{}{}x{}", NAMESPACE, - scalar_kind_string(kind), + kind.to_msl_name(), back::vector_size_str(columns), back::vector_size_str(rows) )?; @@ -845,7 +870,7 @@ impl Writer { crate::TypeInner::Scalar { kind, .. } => kind, _ => return Err(Error::Validation), }; - let scalar = scalar_kind_string(scalar_kind); + let scalar = scalar_kind.to_msl_name(); let size = back::vector_size_str(size); write!(self.out, "{}::{}{}(", NAMESPACE, scalar, size)?; @@ -1246,7 +1271,7 @@ impl Writer { kind, convert, } => { - let scalar = scalar_kind_string(kind); + let scalar = kind.to_msl_name(); let (src_kind, src_width) = match *context.resolve_type(expr) { crate::TypeInner::Scalar { kind, width } | crate::TypeInner::Vector { kind, width, .. } => (kind, width), @@ -1487,7 +1512,15 @@ impl Writer { write!(self.out, ".{}", name)?; } crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => { - self.put_access_chain(base, policy, context)?; + let wrap_packed_vec_scalar_kind = context.get_packed_vec_kind(base); + //Note: this doesn't work for left-hand side + if let Some(scalar_kind) = wrap_packed_vec_scalar_kind { + write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?; + self.put_access_chain(base, policy, context)?; + write!(self.out, ")")?; + } else { + self.put_access_chain(base, policy, context)?; + } write!(self.out, ".{}", back::COMPONENTS[index as usize])?; } _ => { @@ -1614,23 +1647,7 @@ impl Writer { policy: index::BoundsCheckPolicy, context: &ExpressionContext, ) -> BackendResult { - // Because packed vectors such as `packed_float3` cannot be directly multipied by - // matrices, we convert them to unpacked vectors like `float3` on load. - let wrap_packed_vec_scalar_kind = match context.function.expressions[pointer] { - crate::Expression::AccessIndex { base, index } => { - let ty = match *context.resolve_type(base) { - crate::TypeInner::Pointer { base, .. } => &context.module.types[base].inner, - ref ty => ty, - }; - match *ty { - crate::TypeInner::Struct { - ref members, span, .. - } => should_pack_struct_member(members, span, index as usize, context.module), - _ => None, - } - } - _ => None, - }; + let wrap_packed_vec_scalar_kind = context.get_packed_vec_kind(pointer); let is_atomic = match *context.resolve_type(pointer) { crate::TypeInner::Pointer { base, .. } => match context.module.types[base].inner { crate::TypeInner::Atomic { .. } => true, @@ -1640,12 +1657,7 @@ impl Writer { }; if let Some(scalar_kind) = wrap_packed_vec_scalar_kind { - write!( - self.out, - "{}::{}3(", - NAMESPACE, - scalar_kind_string(scalar_kind) - )?; + write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?; self.put_access_chain(pointer, policy, context)?; write!(self.out, ")")?; } else if is_atomic { @@ -1761,15 +1773,22 @@ impl Writer { }; write!(self.out, "{}", ty_name)?; } + TypeResolution::Value(crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + }) => { + // work around Metal toolchain bug with `uint` typedef + write!(self.out, "{}::uint", NAMESPACE)?; + } TypeResolution::Value(crate::TypeInner::Scalar { kind, .. }) => { - write!(self.out, "{}", scalar_kind_string(kind))?; + write!(self.out, "{}", kind.to_msl_name())?; } TypeResolution::Value(crate::TypeInner::Vector { size, kind, .. }) => { write!( self.out, "{}::{}{}", NAMESPACE, - scalar_kind_string(kind), + kind.to_msl_name(), back::vector_size_str(size) )?; } @@ -1778,7 +1797,7 @@ impl Writer { self.out, "{}::{}{}x{}", NAMESPACE, - scalar_kind_string(crate::ScalarKind::Float), + crate::ScalarKind::Float.to_msl_name(), back::vector_size_str(columns), back::vector_size_str(rows), )?; @@ -2360,7 +2379,7 @@ impl Writer { "{}{}::packed_{}3 {};", back::INDENT, NAMESPACE, - scalar_kind_string(kind), + kind.to_msl_name(), member_name )?; } diff --git a/tests/in/globals.wgsl b/tests/in/globals.wgsl index e215f039e..03f87ddfb 100644 --- a/tests/in/globals.wgsl +++ b/tests/in/globals.wgsl @@ -16,6 +16,7 @@ var alignment: Foo; [[stage(compute), workgroup_size(1)]] fn main() { wg[3] = alignment.v1; + wg[2] = alignment.v3.x; atomicStore(&at, 2u); // Valid, Foo and at is in function scope diff --git a/tests/out/glsl/globals.main.Compute.glsl b/tests/out/glsl/globals.main.Compute.glsl index d6892df82..cdc5e7260 100644 --- a/tests/out/glsl/globals.main.Compute.glsl +++ b/tests/out/glsl/globals.main.Compute.glsl @@ -21,6 +21,8 @@ void main() { bool at = true; float _e7 = _group_0_binding_1_cs.v1_; wg[3] = _e7; + float _e12 = _group_0_binding_1_cs.v3_.x; + wg[2] = _e12; at_1 = 2u; return; } diff --git a/tests/out/hlsl/globals.hlsl b/tests/out/hlsl/globals.hlsl index 0c9f1cae8..1e05d9fa7 100644 --- a/tests/out/hlsl/globals.hlsl +++ b/tests/out/hlsl/globals.hlsl @@ -17,6 +17,8 @@ void main() float _expr7 = asfloat(alignment.Load(12)); wg[3] = _expr7; + float _expr12 = asfloat(alignment.Load(0+0)); + wg[2] = _expr12; at_1 = 2u; return; } diff --git a/tests/out/msl/boids.msl b/tests/out/msl/boids.msl index 461602f88..b717466f9 100644 --- a/tests/out/msl/boids.msl +++ b/tests/out/msl/boids.msl @@ -45,7 +45,7 @@ kernel void main_( metal::float2 pos; metal::float2 vel; metal::uint i = 0u; - uint index = global_invocation_id.x; + metal::uint index = global_invocation_id.x; if (index >= NUM_PARTICLES) { return; } diff --git a/tests/out/msl/globals.msl b/tests/out/msl/globals.msl index 7aeb00250..636aa69c8 100644 --- a/tests/out/msl/globals.msl +++ b/tests/out/msl/globals.msl @@ -20,6 +20,8 @@ kernel void main_( bool at = true; float _e7 = alignment.v1_; wg.inner[3] = _e7; + float _e12 = metal::float3(alignment.v3_).x; + wg.inner[2] = _e12; metal::atomic_store_explicit(&at_1, 2u, metal::memory_order_relaxed); return; } diff --git a/tests/out/msl/shadow.msl b/tests/out/msl/shadow.msl index 33475b841..54e918de4 100644 --- a/tests/out/msl/shadow.msl +++ b/tests/out/msl/shadow.msl @@ -63,7 +63,7 @@ fragment fs_mainOutput fs_main( } loop_init = false; metal::uint _e12 = i; - uint _e15 = u_globals.num_lights.x; + metal::uint _e15 = u_globals.num_lights.x; if (_e12 >= metal::min(_e15, c_max_lights)) { break; } diff --git a/tests/out/spv/globals.spvasm b/tests/out/spv/globals.spvasm index 45aec8c2f..c9ec4c220 100644 --- a/tests/out/spv/globals.spvasm +++ b/tests/out/spv/globals.spvasm @@ -1,21 +1,21 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 43 +; Bound: 48 OpCapability Shader OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %28 "main" -OpExecutionMode %28 LocalSize 1 1 1 -OpDecorate %13 ArrayStride 4 -OpMemberDecorate %15 0 Offset 0 -OpMemberDecorate %15 1 Offset 12 -OpDecorate %20 NonWritable -OpDecorate %20 DescriptorSet 0 -OpDecorate %20 Binding 1 -OpDecorate %21 Block -OpMemberDecorate %21 0 Offset 0 +OpEntryPoint GLCompute %29 "main" +OpExecutionMode %29 LocalSize 1 1 1 +OpDecorate %14 ArrayStride 4 +OpMemberDecorate %16 0 Offset 0 +OpMemberDecorate %16 1 Offset 12 +OpDecorate %21 NonWritable +OpDecorate %21 DescriptorSet 0 +OpDecorate %21 Binding 1 +OpDecorate %22 Block +OpMemberDecorate %22 0 Offset 0 %2 = OpTypeVoid %4 = OpTypeBool %3 = OpConstantTrue %4 @@ -23,42 +23,48 @@ OpMemberDecorate %21 0 Offset 0 %5 = OpConstant %6 10 %8 = OpTypeInt 32 1 %7 = OpConstant %8 3 -%9 = OpConstant %6 2 -%11 = OpTypeFloat 32 -%10 = OpConstant %11 1.0 -%12 = OpConstantTrue %4 -%13 = OpTypeArray %11 %5 -%14 = OpTypeVector %11 3 -%15 = OpTypeStruct %14 %11 -%17 = OpTypePointer Workgroup %13 -%16 = OpVariable %17 Workgroup -%19 = OpTypePointer Workgroup %6 -%18 = OpVariable %19 Workgroup -%21 = OpTypeStruct %15 -%22 = OpTypePointer StorageBuffer %21 -%20 = OpVariable %22 StorageBuffer -%24 = OpTypePointer Function %11 -%26 = OpTypePointer Function %4 -%29 = OpTypeFunction %2 -%30 = OpTypePointer StorageBuffer %15 -%31 = OpConstant %6 0 -%34 = OpTypePointer Workgroup %11 -%35 = OpTypePointer StorageBuffer %11 -%36 = OpConstant %6 1 -%39 = OpConstant %6 3 -%41 = OpConstant %8 2 -%42 = OpConstant %6 256 -%28 = OpFunction %2 None %29 -%27 = OpLabel -%23 = OpVariable %24 Function %10 -%25 = OpVariable %26 Function %12 -%32 = OpAccessChain %30 %20 %31 -OpBranch %33 -%33 = OpLabel -%37 = OpAccessChain %35 %32 %36 -%38 = OpLoad %11 %37 -%40 = OpAccessChain %34 %16 %39 -OpStore %40 %38 -OpAtomicStore %18 %41 %42 %9 +%9 = OpConstant %8 2 +%10 = OpConstant %6 2 +%12 = OpTypeFloat 32 +%11 = OpConstant %12 1.0 +%13 = OpConstantTrue %4 +%14 = OpTypeArray %12 %5 +%15 = OpTypeVector %12 3 +%16 = OpTypeStruct %15 %12 +%18 = OpTypePointer Workgroup %14 +%17 = OpVariable %18 Workgroup +%20 = OpTypePointer Workgroup %6 +%19 = OpVariable %20 Workgroup +%22 = OpTypeStruct %16 +%23 = OpTypePointer StorageBuffer %22 +%21 = OpVariable %23 StorageBuffer +%25 = OpTypePointer Function %12 +%27 = OpTypePointer Function %4 +%30 = OpTypeFunction %2 +%31 = OpTypePointer StorageBuffer %16 +%32 = OpConstant %6 0 +%35 = OpTypePointer Workgroup %12 +%36 = OpTypePointer StorageBuffer %12 +%37 = OpConstant %6 1 +%40 = OpConstant %6 3 +%42 = OpTypePointer StorageBuffer %15 +%43 = OpTypePointer StorageBuffer %12 +%47 = OpConstant %6 256 +%29 = OpFunction %2 None %30 +%28 = OpLabel +%24 = OpVariable %25 Function %11 +%26 = OpVariable %27 Function %13 +%33 = OpAccessChain %31 %21 %32 +OpBranch %34 +%34 = OpLabel +%38 = OpAccessChain %36 %33 %37 +%39 = OpLoad %12 %38 +%41 = OpAccessChain %35 %17 %40 +OpStore %41 %39 +%44 = OpAccessChain %43 %33 %32 %32 +%45 = OpLoad %12 %44 +%46 = OpAccessChain %35 %17 %10 +OpStore %46 %45 +OpAtomicStore %19 %9 %47 %10 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/globals.wgsl b/tests/out/wgsl/globals.wgsl index cc514c1e4..a79623904 100644 --- a/tests/out/wgsl/globals.wgsl +++ b/tests/out/wgsl/globals.wgsl @@ -17,6 +17,8 @@ fn main() { let _e7 = alignment.v1_; wg[3] = _e7; + let _e12 = alignment.v3_.x; + wg[2] = _e12; atomicStore((&at_1), 2u); return; }