[wgsl-in] Avoid splatting all binary operator expressions (#2440)

* [wgsl-in] Avoid splatting all binary operator expressions

Fixes #2439.

* [wgsl-in] Expand binary_op_splat function comment
This commit is contained in:
Fredrik Fornwall 2023-08-18 13:07:24 +02:00 committed by GitHub
parent f6e99a4603
commit 3da9355125
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 1 deletions

View File

@ -504,13 +504,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
} }
/// Insert splats, if needed by the non-'*' operations. /// Insert splats, if needed by the non-'*' operations.
///
/// See the "Binary arithmetic expressions with mixed scalar and vector operands"
/// table in the WebGPU Shading Language specification for relevant operators.
///
/// Multiply is not handled here as backends are expected to handle vec*scalar
/// operations, so inserting splats into the IR increases size needlessly.
fn binary_op_splat( fn binary_op_splat(
&mut self, &mut self,
op: crate::BinaryOperator, op: crate::BinaryOperator,
left: &mut Handle<crate::Expression>, left: &mut Handle<crate::Expression>,
right: &mut Handle<crate::Expression>, right: &mut Handle<crate::Expression>,
) -> Result<(), Error<'source>> { ) -> Result<(), Error<'source>> {
if op != crate::BinaryOperator::Multiply { if matches!(
op,
crate::BinaryOperator::Add
| crate::BinaryOperator::Subtract
| crate::BinaryOperator::Divide
| crate::BinaryOperator::Modulo
) {
self.grow_types(*left)?.grow_types(*right)?; self.grow_types(*left)?.grow_types(*right)?;
let left_size = match *self.resolved_inner(*left) { let left_size = match *self.resolved_inner(*left) {

View File

@ -387,6 +387,54 @@ fn parse_expressions() {
}").unwrap(); }").unwrap();
} }
#[test]
fn binary_expression_mixed_scalar_and_vector_operands() {
for (operand, expect_splat) in [
('<', false),
('>', false),
('&', false),
('|', false),
('+', true),
('-', true),
('*', false),
('/', true),
('%', true),
] {
let module = parse_str(&format!(
"
const some_vec = vec3<f32>(1.0, 1.0, 1.0);
@fragment
fn main() -> @location(0) vec4<f32> {{
if (all(1.0 {operand} some_vec)) {{
return vec4(0.0);
}}
return vec4(1.0);
}}
"
))
.unwrap();
let expressions = &&module.entry_points[0].function.expressions;
let found_expressions = expressions
.iter()
.filter(|&(_, e)| {
if let crate::Expression::Binary { left, .. } = *e {
matches!(
(expect_splat, &expressions[left]),
(false, &crate::Expression::Literal(crate::Literal::F32(..)))
| (true, &crate::Expression::Splat { .. })
)
} else {
false
}
})
.count();
assert_eq!(found_expressions, 1);
}
}
#[test] #[test]
fn parse_pointers() { fn parse_pointers() {
parse_str( parse_str(