msl: fix packed vec access (#1634)

This commit is contained in:
Dzmitry Malyshau 2021-12-28 00:13:13 -05:00 committed by GitHub
parent 5a26606a09
commit 2738ad80b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 130 additions and 96 deletions

View File

@ -47,20 +47,20 @@ impl<'a> Display for TypeContext<'a> {
// work around Metal toolchain bug with `uint` typedef
crate::ScalarKind::Uint => write!(out, "{}::uint", NAMESPACE),
_ => {
let kind_str = scalar_kind_string(kind);
let kind_str = kind.to_msl_name();
write!(out, "{}", kind_str)
}
}
}
crate::TypeInner::Atomic { kind, .. } => {
write!(out, "{}::atomic_{}", NAMESPACE, scalar_kind_string(kind))
write!(out, "{}::atomic_{}", NAMESPACE, kind.to_msl_name())
}
crate::TypeInner::Vector { size, kind, .. } => {
write!(
out,
"{}::{}{}",
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
back::vector_size_str(size),
)
}
@ -69,7 +69,7 @@ impl<'a> Display for TypeContext<'a> {
out,
"{}::{}{}x{}",
NAMESPACE,
scalar_kind_string(crate::ScalarKind::Float),
crate::ScalarKind::Float.to_msl_name(),
back::vector_size_str(columns),
back::vector_size_str(rows),
)
@ -96,7 +96,7 @@ impl<'a> Display for TypeContext<'a> {
Some(name) => name,
None => return Ok(()),
};
write!(out, "{} {}&", class_name, scalar_kind_string(kind),)
write!(out, "{} {}&", class_name, kind.to_msl_name(),)
}
crate::TypeInner::ValuePointer {
size: Some(size),
@ -113,7 +113,7 @@ impl<'a> Display for TypeContext<'a> {
"{} {}::{}{}&",
class_name,
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
back::vector_size_str(size),
)
}
@ -178,7 +178,7 @@ impl<'a> Display for TypeContext<'a> {
("texture", "", format.into(), access)
}
};
let base_name = scalar_kind_string(kind);
let base_name = kind.to_msl_name();
let array_str = if arrayed { "_array" } else { "" };
write!(
out,
@ -316,12 +316,14 @@ pub struct Writer<W> {
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
fn scalar_kind_string(kind: crate::ScalarKind) -> &'static str {
match kind {
crate::ScalarKind::Float => "float",
crate::ScalarKind::Sint => "int",
crate::ScalarKind::Uint => "uint",
crate::ScalarKind::Bool => "bool",
impl crate::ScalarKind {
fn to_msl_name(self) -> &'static str {
match self {
Self::Float => "float",
Self::Sint => "int",
Self::Uint => "uint",
Self::Bool => "bool",
}
}
}
@ -481,6 +483,29 @@ impl<'a> ExpressionContext<'a> {
) -> Option<index::IndexableLength> {
index::access_needs_check(base, index, self.module, self.function, self.info)
}
/// Reports whether `expr_handle` reads a struct member that is stored as a
/// packed vector (e.g. `packed_float3`), yielding its scalar kind if so.
///
/// Packed vectors cannot be used directly once loaded, so callers use the
/// returned kind to wrap the access in the matching unpacked vector
/// constructor (such as `float3`).
///
/// Only `AccessIndex` expressions whose base resolves to a struct, either
/// directly or behind a pointer, are examined; the per-member decision is
/// delegated to `should_pack_struct_member`. Everything else yields `None`.
fn get_packed_vec_kind(
    &self,
    expr_handle: Handle<crate::Expression>,
) -> Option<crate::ScalarKind> {
    let (base, index) = match self.function.expressions[expr_handle] {
        crate::Expression::AccessIndex { base, index } => (base, index),
        _ => return None,
    };
    // See through one level of pointer so that both `struct.member` and
    // `(*ptr).member` shapes are handled the same way.
    let inner = match *self.resolve_type(base) {
        crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
        ref other => other,
    };
    if let crate::TypeInner::Struct {
        ref members, span, ..
    } = *inner
    {
        should_pack_struct_member(members, span, index as usize, self.module)
    } else {
        None
    }
}
}
struct StatementContext<'a> {
@ -652,7 +677,7 @@ impl<W: Write> Writer<W> {
) -> BackendResult {
match context.module.types[ty].inner {
crate::TypeInner::Scalar { width: 4, kind } if components.len() == 1 => {
write!(self.out, "{}", scalar_kind_string(kind))?;
write!(self.out, "{}", kind.to_msl_name())?;
self.put_call_parameters(components.iter().cloned(), context)?;
}
crate::TypeInner::Vector { size, kind, .. } => {
@ -660,7 +685,7 @@ impl<W: Write> Writer<W> {
self.out,
"{}::{}{}",
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
back::vector_size_str(size)
)?;
self.put_call_parameters(components.iter().cloned(), context)?;
@ -671,7 +696,7 @@ impl<W: Write> Writer<W> {
self.out,
"{}::{}{}x{}",
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
back::vector_size_str(columns),
back::vector_size_str(rows)
)?;
@ -845,7 +870,7 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Scalar { kind, .. } => kind,
_ => return Err(Error::Validation),
};
let scalar = scalar_kind_string(scalar_kind);
let scalar = scalar_kind.to_msl_name();
let size = back::vector_size_str(size);
write!(self.out, "{}::{}{}(", NAMESPACE, scalar, size)?;
@ -1246,7 +1271,7 @@ impl<W: Write> Writer<W> {
kind,
convert,
} => {
let scalar = scalar_kind_string(kind);
let scalar = kind.to_msl_name();
let (src_kind, src_width) = match *context.resolve_type(expr) {
crate::TypeInner::Scalar { kind, width }
| crate::TypeInner::Vector { kind, width, .. } => (kind, width),
@ -1487,7 +1512,15 @@ impl<W: Write> Writer<W> {
write!(self.out, ".{}", name)?;
}
crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
self.put_access_chain(base, policy, context)?;
let wrap_packed_vec_scalar_kind = context.get_packed_vec_kind(base);
// Note: this doesn't work for the left-hand side of an assignment
if let Some(scalar_kind) = wrap_packed_vec_scalar_kind {
write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?;
self.put_access_chain(base, policy, context)?;
write!(self.out, ")")?;
} else {
self.put_access_chain(base, policy, context)?;
}
write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
}
_ => {
@ -1614,23 +1647,7 @@ impl<W: Write> Writer<W> {
policy: index::BoundsCheckPolicy,
context: &ExpressionContext,
) -> BackendResult {
// Because packed vectors such as `packed_float3` cannot be directly multiplied by
// matrices, we convert them to unpacked vectors like `float3` on load.
let wrap_packed_vec_scalar_kind = match context.function.expressions[pointer] {
crate::Expression::AccessIndex { base, index } => {
let ty = match *context.resolve_type(base) {
crate::TypeInner::Pointer { base, .. } => &context.module.types[base].inner,
ref ty => ty,
};
match *ty {
crate::TypeInner::Struct {
ref members, span, ..
} => should_pack_struct_member(members, span, index as usize, context.module),
_ => None,
}
}
_ => None,
};
let wrap_packed_vec_scalar_kind = context.get_packed_vec_kind(pointer);
let is_atomic = match *context.resolve_type(pointer) {
crate::TypeInner::Pointer { base, .. } => match context.module.types[base].inner {
crate::TypeInner::Atomic { .. } => true,
@ -1640,12 +1657,7 @@ impl<W: Write> Writer<W> {
};
if let Some(scalar_kind) = wrap_packed_vec_scalar_kind {
write!(
self.out,
"{}::{}3(",
NAMESPACE,
scalar_kind_string(scalar_kind)
)?;
write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ")")?;
} else if is_atomic {
@ -1761,15 +1773,22 @@ impl<W: Write> Writer<W> {
};
write!(self.out, "{}", ty_name)?;
}
TypeResolution::Value(crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
}) => {
// work around Metal toolchain bug with `uint` typedef
write!(self.out, "{}::uint", NAMESPACE)?;
}
TypeResolution::Value(crate::TypeInner::Scalar { kind, .. }) => {
write!(self.out, "{}", scalar_kind_string(kind))?;
write!(self.out, "{}", kind.to_msl_name())?;
}
TypeResolution::Value(crate::TypeInner::Vector { size, kind, .. }) => {
write!(
self.out,
"{}::{}{}",
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
back::vector_size_str(size)
)?;
}
@ -1778,7 +1797,7 @@ impl<W: Write> Writer<W> {
self.out,
"{}::{}{}x{}",
NAMESPACE,
scalar_kind_string(crate::ScalarKind::Float),
crate::ScalarKind::Float.to_msl_name(),
back::vector_size_str(columns),
back::vector_size_str(rows),
)?;
@ -2360,7 +2379,7 @@ impl<W: Write> Writer<W> {
"{}{}::packed_{}3 {};",
back::INDENT,
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
member_name
)?;
}

View File

@ -16,6 +16,7 @@ var<storage> alignment: Foo;
[[stage(compute), workgroup_size(1)]]
fn main() {
wg[3] = alignment.v1;
wg[2] = alignment.v3.x;
atomicStore(&at, 2u);
// Valid, Foo and at is in function scope

View File

@ -21,6 +21,8 @@ void main() {
bool at = true;
float _e7 = _group_0_binding_1_cs.v1_;
wg[3] = _e7;
float _e12 = _group_0_binding_1_cs.v3_.x;
wg[2] = _e12;
at_1 = 2u;
return;
}

View File

@ -17,6 +17,8 @@ void main()
float _expr7 = asfloat(alignment.Load(12));
wg[3] = _expr7;
float _expr12 = asfloat(alignment.Load(0+0));
wg[2] = _expr12;
at_1 = 2u;
return;
}

View File

@ -45,7 +45,7 @@ kernel void main_(
metal::float2 pos;
metal::float2 vel;
metal::uint i = 0u;
uint index = global_invocation_id.x;
metal::uint index = global_invocation_id.x;
if (index >= NUM_PARTICLES) {
return;
}

View File

@ -20,6 +20,8 @@ kernel void main_(
bool at = true;
float _e7 = alignment.v1_;
wg.inner[3] = _e7;
float _e12 = metal::float3(alignment.v3_).x;
wg.inner[2] = _e12;
metal::atomic_store_explicit(&at_1, 2u, metal::memory_order_relaxed);
return;
}

View File

@ -63,7 +63,7 @@ fragment fs_mainOutput fs_main(
}
loop_init = false;
metal::uint _e12 = i;
uint _e15 = u_globals.num_lights.x;
metal::uint _e15 = u_globals.num_lights.x;
if (_e12 >= metal::min(_e15, c_max_lights)) {
break;
}

View File

@ -1,21 +1,21 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 43
; Bound: 48
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %28 "main"
OpExecutionMode %28 LocalSize 1 1 1
OpDecorate %13 ArrayStride 4
OpMemberDecorate %15 0 Offset 0
OpMemberDecorate %15 1 Offset 12
OpDecorate %20 NonWritable
OpDecorate %20 DescriptorSet 0
OpDecorate %20 Binding 1
OpDecorate %21 Block
OpMemberDecorate %21 0 Offset 0
OpEntryPoint GLCompute %29 "main"
OpExecutionMode %29 LocalSize 1 1 1
OpDecorate %14 ArrayStride 4
OpMemberDecorate %16 0 Offset 0
OpMemberDecorate %16 1 Offset 12
OpDecorate %21 NonWritable
OpDecorate %21 DescriptorSet 0
OpDecorate %21 Binding 1
OpDecorate %22 Block
OpMemberDecorate %22 0 Offset 0
%2 = OpTypeVoid
%4 = OpTypeBool
%3 = OpConstantTrue %4
@ -23,42 +23,48 @@ OpMemberDecorate %21 0 Offset 0
%5 = OpConstant %6 10
%8 = OpTypeInt 32 1
%7 = OpConstant %8 3
%9 = OpConstant %6 2
%11 = OpTypeFloat 32
%10 = OpConstant %11 1.0
%12 = OpConstantTrue %4
%13 = OpTypeArray %11 %5
%14 = OpTypeVector %11 3
%15 = OpTypeStruct %14 %11
%17 = OpTypePointer Workgroup %13
%16 = OpVariable %17 Workgroup
%19 = OpTypePointer Workgroup %6
%18 = OpVariable %19 Workgroup
%21 = OpTypeStruct %15
%22 = OpTypePointer StorageBuffer %21
%20 = OpVariable %22 StorageBuffer
%24 = OpTypePointer Function %11
%26 = OpTypePointer Function %4
%29 = OpTypeFunction %2
%30 = OpTypePointer StorageBuffer %15
%31 = OpConstant %6 0
%34 = OpTypePointer Workgroup %11
%35 = OpTypePointer StorageBuffer %11
%36 = OpConstant %6 1
%39 = OpConstant %6 3
%41 = OpConstant %8 2
%42 = OpConstant %6 256
%28 = OpFunction %2 None %29
%27 = OpLabel
%23 = OpVariable %24 Function %10
%25 = OpVariable %26 Function %12
%32 = OpAccessChain %30 %20 %31
OpBranch %33
%33 = OpLabel
%37 = OpAccessChain %35 %32 %36
%38 = OpLoad %11 %37
%40 = OpAccessChain %34 %16 %39
OpStore %40 %38
OpAtomicStore %18 %41 %42 %9
%9 = OpConstant %8 2
%10 = OpConstant %6 2
%12 = OpTypeFloat 32
%11 = OpConstant %12 1.0
%13 = OpConstantTrue %4
%14 = OpTypeArray %12 %5
%15 = OpTypeVector %12 3
%16 = OpTypeStruct %15 %12
%18 = OpTypePointer Workgroup %14
%17 = OpVariable %18 Workgroup
%20 = OpTypePointer Workgroup %6
%19 = OpVariable %20 Workgroup
%22 = OpTypeStruct %16
%23 = OpTypePointer StorageBuffer %22
%21 = OpVariable %23 StorageBuffer
%25 = OpTypePointer Function %12
%27 = OpTypePointer Function %4
%30 = OpTypeFunction %2
%31 = OpTypePointer StorageBuffer %16
%32 = OpConstant %6 0
%35 = OpTypePointer Workgroup %12
%36 = OpTypePointer StorageBuffer %12
%37 = OpConstant %6 1
%40 = OpConstant %6 3
%42 = OpTypePointer StorageBuffer %15
%43 = OpTypePointer StorageBuffer %12
%47 = OpConstant %6 256
%29 = OpFunction %2 None %30
%28 = OpLabel
%24 = OpVariable %25 Function %11
%26 = OpVariable %27 Function %13
%33 = OpAccessChain %31 %21 %32
OpBranch %34
%34 = OpLabel
%38 = OpAccessChain %36 %33 %37
%39 = OpLoad %12 %38
%41 = OpAccessChain %35 %17 %40
OpStore %41 %39
%44 = OpAccessChain %43 %33 %32 %32
%45 = OpLoad %12 %44
%46 = OpAccessChain %35 %17 %10
OpStore %46 %45
OpAtomicStore %19 %9 %47 %10
OpReturn
OpFunctionEnd

View File

@ -17,6 +17,8 @@ fn main() {
let _e7 = alignment.v1_;
wg[3] = _e7;
let _e12 = alignment.v3_.x;
wg[2] = _e12;
atomicStore((&at_1), 2u);
return;
}