diff --git a/CHANGELOG.md b/CHANGELOG.md index 0946e7099..cdca81a71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,6 +79,10 @@ By @teoxoy [#6134](https://github.com/gfx-rs/wgpu/pull/6134). - Fix incorrect hlsl image output type conversion. By @atlv24 in [#6123](https://github.com/gfx-rs/wgpu/pull/6123) +#### Naga + +- Accept only `vec3` (not `vecN`) for the `cross` built-in. By @ErichDonGubler in [#6171](https://github.com/gfx-rs/wgpu/pull/6171). + #### General - If GL context creation fails retry with GLES. By @Rapdorian in [#5996](https://github.com/gfx-rs/wgpu/pull/5996) diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 1d1420aef..39c1ab749 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1172,8 +1172,8 @@ impl super::Validator { Sc { kind: Sk::Float, .. }, - .. - } => {} + size: vector_size, + } if fun != Mf::Cross || vector_size == crate::VectorSize::Tri => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } if arg1_ty != arg_ty { diff --git a/naga/tests/validation.rs b/naga/tests/validation.rs index f64b40884..296697986 100644 --- a/naga/tests/validation.rs +++ b/naga/tests/validation.rs @@ -260,3 +260,54 @@ fn emit_workgroup_uniform_load_result() { variant(true).expect("module should validate"); assert!(variant(false).is_err()); } + +#[cfg(feature = "wgsl-in")] +#[test] +fn bad_cross_builtin_args() { + let cases = [ + ( + "vec2(0., 1.)", + "\ +error: Entry point main at Compute is invalid + ┌─ wgsl:3:13 + │ +3 │ let a = cross(vec2(0., 1.), vec2(0., 1.)); + │ ^^^^^ naga::Expression [6] + │ + = Expression [6] is invalid + = Argument [0] to Cross as expression [2] has an invalid type. + +", + ), + ( + "vec4(0., 1., 2., 3.)", + "\ +error: Entry point main at Compute is invalid + ┌─ wgsl:3:13 + │ +3 │ let a = cross(vec4(0., 1., 2., 3.), vec4(0., 1., 2., 3.)); + │ ^^^^^ naga::Expression [10] + │ + = Expression [10] is invalid + = Argument [0] to Cross as expression [4] has an invalid type. + +", + ), + ]; + + for (invalid_arg, expected_err) in cases { + let source = format!( + "\ +@compute @workgroup_size(1) +fn main() {{ + let a = cross({invalid_arg}, {invalid_arg}); +}} +" + ); + let module = naga::front::wgsl::parse_str(&source).unwrap(); + let err = valid::Validator::new(Default::default(), valid::Capabilities::all()) + .validate_no_overrides(&module) + .expect_err("module should be invalid"); + assert_eq!(err.emit_to_string(&source), expected_err); + } +}