mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-12-01 19:23:38 +00:00
[wgsl-in] Avoid splatting all binary operator expressions (#2440)
* [wgsl-in] Avoid splatting all binary operator expressions Fixes #2439. * [wgsl-in] Expand binary_op_splat function comment
This commit is contained in:
parent
f6e99a4603
commit
3da9355125
@ -504,13 +504,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Insert splats, if needed by the non-'*' operations.
|
/// Insert splats, if needed by the non-'*' operations.
|
||||||
|
///
|
||||||
|
/// See the "Binary arithmetic expressions with mixed scalar and vector operands"
|
||||||
|
/// table in the WebGPU Shading Language specification for relevant operators.
|
||||||
|
///
|
||||||
|
/// Multiply is not handled here as backends are expected to handle vec*scalar
|
||||||
|
/// operations, so inserting splats into the IR increases size needlessly.
|
||||||
fn binary_op_splat(
|
fn binary_op_splat(
|
||||||
&mut self,
|
&mut self,
|
||||||
op: crate::BinaryOperator,
|
op: crate::BinaryOperator,
|
||||||
left: &mut Handle<crate::Expression>,
|
left: &mut Handle<crate::Expression>,
|
||||||
right: &mut Handle<crate::Expression>,
|
right: &mut Handle<crate::Expression>,
|
||||||
) -> Result<(), Error<'source>> {
|
) -> Result<(), Error<'source>> {
|
||||||
if op != crate::BinaryOperator::Multiply {
|
if matches!(
|
||||||
|
op,
|
||||||
|
crate::BinaryOperator::Add
|
||||||
|
| crate::BinaryOperator::Subtract
|
||||||
|
| crate::BinaryOperator::Divide
|
||||||
|
| crate::BinaryOperator::Modulo
|
||||||
|
) {
|
||||||
self.grow_types(*left)?.grow_types(*right)?;
|
self.grow_types(*left)?.grow_types(*right)?;
|
||||||
|
|
||||||
let left_size = match *self.resolved_inner(*left) {
|
let left_size = match *self.resolved_inner(*left) {
|
||||||
|
@ -387,6 +387,54 @@ fn parse_expressions() {
|
|||||||
}").unwrap();
|
}").unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn binary_expression_mixed_scalar_and_vector_operands() {
|
||||||
|
for (operand, expect_splat) in [
|
||||||
|
('<', false),
|
||||||
|
('>', false),
|
||||||
|
('&', false),
|
||||||
|
('|', false),
|
||||||
|
('+', true),
|
||||||
|
('-', true),
|
||||||
|
('*', false),
|
||||||
|
('/', true),
|
||||||
|
('%', true),
|
||||||
|
] {
|
||||||
|
let module = parse_str(&format!(
|
||||||
|
"
|
||||||
|
const some_vec = vec3<f32>(1.0, 1.0, 1.0);
|
||||||
|
@fragment
|
||||||
|
fn main() -> @location(0) vec4<f32> {{
|
||||||
|
if (all(1.0 {operand} some_vec)) {{
|
||||||
|
return vec4(0.0);
|
||||||
|
}}
|
||||||
|
return vec4(1.0);
|
||||||
|
}}
|
||||||
|
"
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let expressions = &&module.entry_points[0].function.expressions;
|
||||||
|
|
||||||
|
let found_expressions = expressions
|
||||||
|
.iter()
|
||||||
|
.filter(|&(_, e)| {
|
||||||
|
if let crate::Expression::Binary { left, .. } = *e {
|
||||||
|
matches!(
|
||||||
|
(expect_splat, &expressions[left]),
|
||||||
|
(false, &crate::Expression::Literal(crate::Literal::F32(..)))
|
||||||
|
| (true, &crate::Expression::Splat { .. })
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.count();
|
||||||
|
|
||||||
|
assert_eq!(found_expressions, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_pointers() {
|
fn parse_pointers() {
|
||||||
parse_str(
|
parse_str(
|
||||||
|
Loading…
Reference in New Issue
Block a user