fix(wgsl): narrow accepted args. of cross to vec3<$float>

This commit is contained in:
Erich Gubler 2024-08-27 14:09:00 -04:00
parent 327b92e92b
commit 7164f3eb4e
3 changed files with 57 additions and 2 deletions

View File

@ -79,6 +79,10 @@ By @teoxoy [#6134](https://github.com/gfx-rs/wgpu/pull/6134).
- Fix incorrect hlsl image output type conversion. By @atlv24 in [#6123](https://github.com/gfx-rs/wgpu/pull/6123)
#### Naga
- Accept only `vec3` (not `vecN`) for the `cross` built-in. By @ErichDonGubler in [#6171](https://github.com/gfx-rs/wgpu/pull/6171).
#### General
- If GL context creation fails retry with GLES. By @Rapdorian in [#5996](https://github.com/gfx-rs/wgpu/pull/5996)

View File

@ -1172,8 +1172,8 @@ impl super::Validator {
Sc {
kind: Sk::Float, ..
},
..
} => {}
size: vector_size,
} if fun != Mf::Cross || vector_size == crate::VectorSize::Tri => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {

View File

@ -260,3 +260,54 @@ fn emit_workgroup_uniform_load_result() {
variant(true).expect("module should validate");
assert!(variant(false).is_err());
}
#[cfg(feature = "wgsl-in")]
#[test]
fn bad_cross_builtin_args() {
let cases = [
(
"vec2(0., 1.)",
"\
error: Entry point main at Compute is invalid
wgsl:3:13
3 let a = cross(vec2(0., 1.), vec2(0., 1.));
^^^^^ naga::Expression [6]
= Expression [6] is invalid
= Argument [0] to Cross as expression [2] has an invalid type.
",
),
(
"vec4(0., 1., 2., 3.)",
"\
error: Entry point main at Compute is invalid
wgsl:3:13
3 let a = cross(vec4(0., 1., 2., 3.), vec4(0., 1., 2., 3.));
^^^^^ naga::Expression [10]
= Expression [10] is invalid
= Argument [0] to Cross as expression [4] has an invalid type.
",
),
];
for (invalid_arg, expected_err) in cases {
let source = format!(
"\
@compute @workgroup_size(1)
fn main() {{
let a = cross({invalid_arg}, {invalid_arg});
}}
"
);
let module = naga::front::wgsl::parse_str(&source).unwrap();
let err = valid::Validator::new(Default::default(), valid::Capabilities::all())
.validate_no_overrides(&module)
.expect_err("module should be invalid");
assert_eq!(err.emit_to_string(&source), expected_err);
}
}