[msl] refactor chain of if blocks to match

This commit is contained in:
teoxoy 2024-04-15 21:19:15 +02:00 committed by Teodor Tanasoaia
parent 7078b0a061
commit 66d7387f0d

View File

@ -1853,133 +1853,145 @@ impl<W: Write> Writer<W> {
_ => {}
}
if fun == Mf::Distance && scalar_argument {
write!(self.out, "{NAMESPACE}::abs(")?;
self.put_expression(arg, context, false)?;
write!(self.out, " - ")?;
self.put_expression(arg1.unwrap(), context, false)?;
write!(self.out, ")")?;
} else if fun == Mf::FindLsb {
let scalar = context.resolve_type(arg).scalar().unwrap();
let constant = scalar.width * 8 + 1;
match fun {
Mf::Distance if scalar_argument => {
write!(self.out, "{NAMESPACE}::abs(")?;
self.put_expression(arg, context, false)?;
write!(self.out, " - ")?;
self.put_expression(arg1.unwrap(), context, false)?;
write!(self.out, ")")?;
}
Mf::FindLsb => {
let scalar = context.resolve_type(arg).scalar().unwrap();
let constant = scalar.width * 8 + 1;
write!(self.out, "((({NAMESPACE}::ctz(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ") + 1) % {constant}) - 1)")?;
} else if fun == Mf::FindMsb {
let inner = context.resolve_type(arg);
let scalar = inner.scalar().unwrap();
let constant = scalar.width * 8 - 1;
write!(
self.out,
"{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
)?;
if scalar.kind == crate::ScalarKind::Sint {
write!(self.out, "{NAMESPACE}::select(")?;
write!(self.out, "((({NAMESPACE}::ctz(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ~")?;
write!(self.out, ") + 1) % {constant}) - 1)")?;
}
Mf::FindMsb => {
let inner = context.resolve_type(arg);
let scalar = inner.scalar().unwrap();
let constant = scalar.width * 8 - 1;
write!(
self.out,
"{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
)?;
if scalar.kind == crate::ScalarKind::Sint {
write!(self.out, "{NAMESPACE}::select(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ~")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " < 0)")?;
} else {
self.put_expression(arg, context, true)?;
}
write!(self.out, "), ")?;
// or metal will complain that select is ambiguous
match *inner {
crate::TypeInner::Vector { size, scalar } => {
let size = back::vector_size_str(size);
let name = scalar.to_msl_name();
write!(self.out, "{name}{size}")?;
}
crate::TypeInner::Scalar(scalar) => {
let name = scalar.to_msl_name();
write!(self.out, "{name}")?;
}
_ => (),
}
write!(self.out, "(-1), ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " == 0 || ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " == -1)")?;
}
Mf::Unpack2x16float => {
write!(self.out, "float2(as_type<half2>(")?;
self.put_expression(arg, context, false)?;
write!(self.out, "))")?;
}
Mf::Pack2x16float => {
write!(self.out, "as_type<uint>(half2(")?;
self.put_expression(arg, context, false)?;
write!(self.out, "))")?;
}
Mf::ExtractBits => {
// The behavior of ExtractBits is undefined when offset + count > bit_width. We need
// to first sanitize the offset and count first. If we don't do this, Apple chips
// will return out-of-spec values if the extracted range is not within the bit width.
//
// This encodes the exact formula specified by the wgsl spec, without temporary values:
// https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
//
// w = sizeof(x) * 8
// o = min(offset, w)
// tmp = w - o
// c = min(count, tmp)
//
// bitfieldExtract(x, o, c)
//
// extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
write!(self.out, "{NAMESPACE}::extract_bits(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", {NAMESPACE}::min(")?;
self.put_expression(arg1.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
self.put_expression(arg2.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
self.put_expression(arg1.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u)))")?;
}
Mf::InsertBits => {
// The behavior of InsertBits has the same issue as ExtractBits.
//
// insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
write!(self.out, "{NAMESPACE}::insert_bits(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " < 0)")?;
} else {
self.put_expression(arg, context, true)?;
self.put_expression(arg1.unwrap(), context, true)?;
write!(self.out, ", {NAMESPACE}::min(")?;
self.put_expression(arg2.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
self.put_expression(arg3.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
self.put_expression(arg2.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u)))")?;
}
write!(self.out, "), ")?;
// or metal will complain that select is ambiguous
match *inner {
crate::TypeInner::Vector { size, scalar } => {
let size = back::vector_size_str(size);
let name = scalar.to_msl_name();
write!(self.out, "{name}{size}")?;
}
crate::TypeInner::Scalar(scalar) => {
let name = scalar.to_msl_name();
write!(self.out, "{name}")?;
}
_ => (),
Mf::Radians => {
write!(self.out, "((")?;
self.put_expression(arg, context, false)?;
write!(self.out, ") * 0.017453292519943295474)")?;
}
Mf::Degrees => {
write!(self.out, "((")?;
self.put_expression(arg, context, false)?;
write!(self.out, ") * 57.295779513082322865)")?;
}
Mf::Modf | Mf::Frexp => {
write!(self.out, "{fun_name}")?;
self.put_call_parameters(iter::once(arg), context)?;
}
_ => {
write!(self.out, "{NAMESPACE}::{fun_name}")?;
self.put_call_parameters(
iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
context,
)?;
}
write!(self.out, "(-1), ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " == 0 || ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " == -1)")?;
} else if fun == Mf::Unpack2x16float {
write!(self.out, "float2(as_type<half2>(")?;
self.put_expression(arg, context, false)?;
write!(self.out, "))")?;
} else if fun == Mf::Pack2x16float {
write!(self.out, "as_type<uint>(half2(")?;
self.put_expression(arg, context, false)?;
write!(self.out, "))")?;
} else if fun == Mf::ExtractBits {
// The behavior of ExtractBits is undefined when offset + count > bit_width. We need
// to first sanitize the offset and count first. If we don't do this, Apple chips
// will return out-of-spec values if the extracted range is not within the bit width.
//
// This encodes the exact formula specified by the wgsl spec, without temporary values:
// https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
//
// w = sizeof(x) * 8
// o = min(offset, w)
// tmp = w - o
// c = min(count, tmp)
//
// bitfieldExtract(x, o, c)
//
// extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
write!(self.out, "{NAMESPACE}::extract_bits(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", {NAMESPACE}::min(")?;
self.put_expression(arg1.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
self.put_expression(arg2.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
self.put_expression(arg1.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u)))")?;
} else if fun == Mf::InsertBits {
// The behavior of InsertBits has the same issue as ExtractBits.
//
// insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
write!(self.out, "{NAMESPACE}::insert_bits(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ")?;
self.put_expression(arg1.unwrap(), context, true)?;
write!(self.out, ", {NAMESPACE}::min(")?;
self.put_expression(arg2.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
self.put_expression(arg3.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
self.put_expression(arg2.unwrap(), context, true)?;
write!(self.out, ", {scalar_bits}u)))")?;
} else if fun == Mf::Radians {
write!(self.out, "((")?;
self.put_expression(arg, context, false)?;
write!(self.out, ") * 0.017453292519943295474)")?;
} else if fun == Mf::Degrees {
write!(self.out, "((")?;
self.put_expression(arg, context, false)?;
write!(self.out, ") * 57.295779513082322865)")?;
} else if fun == Mf::Modf || fun == Mf::Frexp {
write!(self.out, "{fun_name}")?;
self.put_call_parameters(iter::once(arg), context)?;
} else {
write!(self.out, "{NAMESPACE}::{fun_name}")?;
self.put_call_parameters(
iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
context,
)?;
}
}
crate::Expression::As {