From 66d7387f0d7ad6b264a1e29539de3c2c0bd38b93 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Mon, 15 Apr 2024 21:19:15 +0200 Subject: [PATCH] [msl] refactor chain of if blocks to `match` --- naga/src/back/msl/writer.rs | 254 +++++++++++++++++++----------------- 1 file changed, 133 insertions(+), 121 deletions(-) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index d7a5413e3..8c3216a05 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1853,133 +1853,145 @@ impl Writer { _ => {} } - if fun == Mf::Distance && scalar_argument { - write!(self.out, "{NAMESPACE}::abs(")?; - self.put_expression(arg, context, false)?; - write!(self.out, " - ")?; - self.put_expression(arg1.unwrap(), context, false)?; - write!(self.out, ")")?; - } else if fun == Mf::FindLsb { - let scalar = context.resolve_type(arg).scalar().unwrap(); - let constant = scalar.width * 8 + 1; + match fun { + Mf::Distance if scalar_argument => { + write!(self.out, "{NAMESPACE}::abs(")?; + self.put_expression(arg, context, false)?; + write!(self.out, " - ")?; + self.put_expression(arg1.unwrap(), context, false)?; + write!(self.out, ")")?; + } + Mf::FindLsb => { + let scalar = context.resolve_type(arg).scalar().unwrap(); + let constant = scalar.width * 8 + 1; - write!(self.out, "((({NAMESPACE}::ctz(")?; - self.put_expression(arg, context, true)?; - write!(self.out, ") + 1) % {constant}) - 1)")?; - } else if fun == Mf::FindMsb { - let inner = context.resolve_type(arg); - let scalar = inner.scalar().unwrap(); - let constant = scalar.width * 8 - 1; - - write!( - self.out, - "{NAMESPACE}::select({constant} - {NAMESPACE}::clz(" - )?; - - if scalar.kind == crate::ScalarKind::Sint { - write!(self.out, "{NAMESPACE}::select(")?; + write!(self.out, "((({NAMESPACE}::ctz(")?; self.put_expression(arg, context, true)?; - write!(self.out, ", ~")?; + write!(self.out, ") + 1) % {constant}) - 1)")?; + } + Mf::FindMsb => { + let inner = context.resolve_type(arg); + let scalar = inner.scalar().unwrap(); + let constant = scalar.width * 8 - 1; + + write!( + self.out, + "{NAMESPACE}::select({constant} - {NAMESPACE}::clz(" + )?; + + if scalar.kind == crate::ScalarKind::Sint { + write!(self.out, "{NAMESPACE}::select(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ~")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " < 0)")?; + } else { + self.put_expression(arg, context, true)?; + } + + write!(self.out, "), ")?; + + // or metal will complain that select is ambiguous + match *inner { + crate::TypeInner::Vector { size, scalar } => { + let size = back::vector_size_str(size); + let name = scalar.to_msl_name(); + write!(self.out, "{name}{size}")?; + } + crate::TypeInner::Scalar(scalar) => { + let name = scalar.to_msl_name(); + write!(self.out, "{name}")?; + } + _ => (), + } + + write!(self.out, "(-1), ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " == 0 || ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " == -1)")?; + } + Mf::Unpack2x16float => { + write!(self.out, "float2(as_type(")?; + self.put_expression(arg, context, false)?; + write!(self.out, "))")?; + } + Mf::Pack2x16float => { + write!(self.out, "as_type(half2(")?; + self.put_expression(arg, context, false)?; + write!(self.out, "))")?; + } + Mf::ExtractBits => { + // The behavior of ExtractBits is undefined when offset + count > bit_width. We need + // to first sanitize the offset and count first. If we don't do this, Apple chips + // will return out-of-spec values if the extracted range is not within the bit width. + // + // This encodes the exact formula specified by the wgsl spec, without temporary values: + // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin + // + // w = sizeof(x) * 8 + // o = min(offset, w) + // tmp = w - o + // c = min(count, tmp) + // + // bitfieldExtract(x, o, c) + // + // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))) + + let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8; + + write!(self.out, "{NAMESPACE}::extract_bits(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", {NAMESPACE}::min(")?; + self.put_expression(arg1.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?; + self.put_expression(arg2.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?; + self.put_expression(arg1.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u)))")?; + } + Mf::InsertBits => { + // The behavior of InsertBits has the same issue as ExtractBits. + // + // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w)))) + + let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8; + + write!(self.out, "{NAMESPACE}::insert_bits(")?; self.put_expression(arg, context, true)?; write!(self.out, ", ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " < 0)")?; - } else { - self.put_expression(arg, context, true)?; + self.put_expression(arg1.unwrap(), context, true)?; + write!(self.out, ", {NAMESPACE}::min(")?; + self.put_expression(arg2.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?; + self.put_expression(arg3.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?; + self.put_expression(arg2.unwrap(), context, true)?; + write!(self.out, ", {scalar_bits}u)))")?; } - - write!(self.out, "), ")?; - - // or metal will complain that select is ambiguous - match *inner { - crate::TypeInner::Vector { size, scalar } => { - let size = back::vector_size_str(size); - let name = scalar.to_msl_name(); - write!(self.out, "{name}{size}")?; - } - crate::TypeInner::Scalar(scalar) => { - let name = scalar.to_msl_name(); - write!(self.out, "{name}")?; - } - _ => (), + Mf::Radians => { + write!(self.out, "((")?; + self.put_expression(arg, context, false)?; + write!(self.out, ") * 0.017453292519943295474)")?; + } + Mf::Degrees => { + write!(self.out, "((")?; + self.put_expression(arg, context, false)?; + write!(self.out, ") * 57.295779513082322865)")?; + } + Mf::Modf | Mf::Frexp => { + write!(self.out, "{fun_name}")?; + self.put_call_parameters(iter::once(arg), context)?; + } + _ => { + write!(self.out, "{NAMESPACE}::{fun_name}")?; + self.put_call_parameters( + iter::once(arg).chain(arg1).chain(arg2).chain(arg3), + context, + )?; } - - write!(self.out, "(-1), ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " == 0 || ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " == -1)")?; - } else if fun == Mf::Unpack2x16float { - write!(self.out, "float2(as_type(")?; - self.put_expression(arg, context, false)?; - write!(self.out, "))")?; - } else if fun == Mf::Pack2x16float { - write!(self.out, "as_type(half2(")?; - self.put_expression(arg, context, false)?; - write!(self.out, "))")?; - } else if fun == Mf::ExtractBits { - // The behavior of ExtractBits is undefined when offset + count > bit_width. We need - // to first sanitize the offset and count first. If we don't do this, Apple chips - // will return out-of-spec values if the extracted range is not within the bit width. - // - // This encodes the exact formula specified by the wgsl spec, without temporary values: - // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin - // - // w = sizeof(x) * 8 - // o = min(offset, w) - // tmp = w - o - // c = min(count, tmp) - // - // bitfieldExtract(x, o, c) - // - // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))) - - let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8; - - write!(self.out, "{NAMESPACE}::extract_bits(")?; - self.put_expression(arg, context, true)?; - write!(self.out, ", {NAMESPACE}::min(")?; - self.put_expression(arg1.unwrap(), context, true)?; - write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?; - self.put_expression(arg2.unwrap(), context, true)?; - write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?; - self.put_expression(arg1.unwrap(), context, true)?; - write!(self.out, ", {scalar_bits}u)))")?; - } else if fun == Mf::InsertBits { - // The behavior of InsertBits has the same issue as ExtractBits. - // - // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w)))) - - let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8; - - write!(self.out, "{NAMESPACE}::insert_bits(")?; - self.put_expression(arg, context, true)?; - write!(self.out, ", ")?; - self.put_expression(arg1.unwrap(), context, true)?; - write!(self.out, ", {NAMESPACE}::min(")?; - self.put_expression(arg2.unwrap(), context, true)?; - write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?; - self.put_expression(arg3.unwrap(), context, true)?; - write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?; - self.put_expression(arg2.unwrap(), context, true)?; - write!(self.out, ", {scalar_bits}u)))")?; - } else if fun == Mf::Radians { - write!(self.out, "((")?; - self.put_expression(arg, context, false)?; - write!(self.out, ") * 0.017453292519943295474)")?; - } else if fun == Mf::Degrees { - write!(self.out, "((")?; - self.put_expression(arg, context, false)?; - write!(self.out, ") * 57.295779513082322865)")?; - } else if fun == Mf::Modf || fun == Mf::Frexp { - write!(self.out, "{fun_name}")?; - self.put_call_parameters(iter::once(arg), context)?; - } else { - write!(self.out, "{NAMESPACE}::{fun_name}")?; - self.put_call_parameters( - iter::once(arg).chain(arg1).chain(arg2).chain(arg3), - context, - )?; } } crate::Expression::As {