From 2a2c851c40aa81390c93a304a3f79841ec886449 Mon Sep 17 00:00:00 2001
From: Andy Leiserson <aleiserson@mozilla.com>
Date: Thu, 27 Mar 2025 15:05:31 -0700
Subject: [PATCH] [naga] Two structs with the same members are not equivalent

Fixes #5796
---
 naga/src/arena/unique_arena.rs          |   3 +-
 naga/src/front/wgsl/lower/conversion.rs |   2 +-
 naga/src/front/wgsl/lower/mod.rs        |   7 +-
 naga/src/ir/mod.rs                      |  28 +++-
 naga/src/proc/mod.rs                    |  10 +-
 naga/src/proc/type_methods.rs           |  30 ++++-
 naga/src/proc/typifier.rs               |  45 ++++++-
 naga/src/valid/compose.rs               |  12 +-
 naga/src/valid/function.rs              |  61 +++++----
 naga/src/valid/interface.rs             |   7 +-
 naga/src/valid/mod.rs                   |  12 +-
 naga/tests/naga/wgsl_errors.rs          | 164 ++++++++++++++++++++++--
 12 files changed, 302 insertions(+), 79 deletions(-)

diff --git a/naga/src/arena/unique_arena.rs b/naga/src/arena/unique_arena.rs
index c93fb6bb6..a5db7d339 100644
--- a/naga/src/arena/unique_arena.rs
+++ b/naga/src/arena/unique_arena.rs
@@ -16,7 +16,8 @@ use crate::{FastIndexSet, Span};
 /// The element type must implement `Eq` and `Hash`. Insertions of equivalent
 /// elements, according to `Eq`, all return the same `Handle`.
 ///
-/// Once inserted, elements may not be mutated.
+/// Once inserted, elements generally may not be mutated, although a `replace`
+/// method exists to support rare cases.
 ///
 /// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like,
 /// `UniqueArena` is `HashSet`-like.
diff --git a/naga/src/front/wgsl/lower/conversion.rs b/naga/src/front/wgsl/lower/conversion.rs
index 56095b0a3..bfd0fee91 100644
--- a/naga/src/front/wgsl/lower/conversion.rs
+++ b/naga/src/front/wgsl/lower/conversion.rs
@@ -46,7 +46,7 @@ impl<'source> super::ExpressionContext<'source, '_, '_> {
         }
 
         // If `expr` already has the requested type, we're done.
-        if expr_inner.non_struct_equivalent(goal_inner, types) {
+        if self.module.compare_types(expr_resolution, goal_ty) {
             return Ok(expr);
         }
 
diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs
index 997bd035b..53f7d0dc8 100644
--- a/naga/src/front/wgsl/lower/mod.rs
+++ b/naga/src/front/wgsl/lower/mod.rs
@@ -1311,9 +1311,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
                     })?;
 
                 let init_ty = ectx.register_type(init)?;
-                let explicit_inner = &ectx.module.types[explicit_ty].inner;
-                let init_inner = &ectx.module.types[init_ty].inner;
-                if !explicit_inner.non_struct_equivalent(init_inner, &ectx.module.types) {
+                if !ectx.module.compare_types(
+                    &crate::proc::TypeResolution::Handle(explicit_ty),
+                    &crate::proc::TypeResolution::Handle(init_ty),
+                ) {
                     return Err(Box::new(Error::InitializationTypeMismatch {
                         name: name.span,
                         expected: ectx.type_to_string(explicit_ty),
diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs
index fa7bf6379..e42cb86c7 100644
--- a/naga/src/ir/mod.rs
+++ b/naga/src/ir/mod.rs
@@ -636,6 +636,15 @@ pub struct Type {
 }
 
 /// Enum with additional information, depending on the kind of type.
+///
+/// Comparison using `==` is not reliable in the case of [`Pointer`],
+/// [`ValuePointer`], or [`Struct`] variants. For these variants,
+/// use [`TypeInner::non_struct_equivalent`] or [`compare_types`].
+///
+/// [`compare_types`]: crate::proc::compare_types
+/// [`ValuePointer`]: TypeInner::ValuePointer
+/// [`Pointer`]: TypeInner::Pointer
+/// [`Struct`]: TypeInner::Struct
 #[derive(Clone, Debug, Eq, Hash, PartialEq)]
 #[cfg_attr(feature = "serialize", derive(Serialize))]
 #[cfg_attr(feature = "deserialize", derive(Deserialize))]
@@ -656,8 +665,9 @@ pub enum TypeInner {
     /// Pointer to another type.
     ///
     /// Pointers to scalars and vectors should be treated as equivalent to
-    /// [`ValuePointer`] types. Use the [`TypeInner::equivalent`] method to
-    /// compare types in a way that treats pointers correctly.
+    /// [`ValuePointer`] types. Use either [`TypeInner::non_struct_equivalent`]
+    /// or [`compare_types`] to compare types in a way that treats pointers
+    /// correctly.
     ///
     /// ## Pointers to non-`SIZED` types
     ///
@@ -679,6 +689,7 @@ pub enum TypeInner {
     /// [`ValuePointer`]: TypeInner::ValuePointer
     /// [`GlobalVariable`]: Expression::GlobalVariable
     /// [`AccessIndex`]: Expression::AccessIndex
+    /// [`compare_types`]: crate::proc::compare_types
     Pointer {
         base: Handle<Type>,
         space: AddressSpace,
@@ -690,12 +701,13 @@ pub enum TypeInner {
     /// `Scalar` or `Vector` type. This is for use in [`TypeResolution::Value`]
     /// variants; see the documentation for [`TypeResolution`] for details.
     ///
-    /// Use the [`TypeInner::equivalent`] method to compare types that could be
-    /// pointers, to ensure that `Pointer` and `ValuePointer` types are
-    /// recognized as equivalent.
+    /// Use [`TypeInner::non_struct_equivalent`] or [`compare_types`] to compare
+    /// types that could be pointers, to ensure that `Pointer` and
+    /// `ValuePointer` types are recognized as equivalent.
     ///
     /// [`TypeResolution`]: crate::proc::TypeResolution
     /// [`TypeResolution::Value`]: crate::proc::TypeResolution::Value
+    /// [`compare_types`]: crate::proc::compare_types
     ValuePointer {
         size: Option<VectorSize>,
         scalar: Scalar,
@@ -744,9 +756,15 @@ pub enum TypeInner {
     /// struct, which may be a dynamically sized [`Array`]. The
     /// `Struct` type itself is `SIZED` when all its members are `SIZED`.
     ///
+    /// Two structure types with different names are not equivalent. Because
+    /// this variant does not contain the name, it is not possible to use it
+    /// to compare struct types. Use [`compare_types`] to compare two types
+    /// that may be structs.
+    ///
     /// [`DATA`]: crate::valid::TypeFlags::DATA
     /// [`SIZED`]: crate::∅TypeFlags::SIZED
     /// [`Array`]: TypeInner::Array
+    /// [`compare_types`]: crate::proc::compare_types
     Struct {
         members: Vec<StructMember>,
         //TODO: should this be unaligned?
diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs
index 4ca4da554..23743b1f4 100644
--- a/naga/src/proc/mod.rs
+++ b/naga/src/proc/mod.rs
@@ -23,7 +23,7 @@ pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
 pub use terminator::ensure_block_returns;
 use thiserror::Error;
 pub use type_methods::min_max_float_representable_by;
-pub use typifier::{ResolveContext, ResolveError, TypeResolution};
+pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
 
 impl From<super::StorageFormat> for super::Scalar {
     fn from(format: super::StorageFormat) -> Self {
@@ -403,6 +403,10 @@ impl crate::Module {
             global_expressions: &self.global_expressions,
         }
     }
+
+    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
+        compare_types(lhs, rhs, &self.types)
+    }
 }
 
 #[derive(Debug)]
@@ -491,6 +495,10 @@ impl GlobalCtx<'_> {
             _ => get(*self, handle, arena),
         }
     }
+
+    pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
+        compare_types(lhs, rhs, self.types)
+    }
 }
 
 #[derive(Error, Debug, Clone, Copy, PartialEq)]
diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs
index ea779e7a9..289756a67 100644
--- a/naga/src/proc/type_methods.rs
+++ b/naga/src/proc/type_methods.rs
@@ -4,6 +4,8 @@
 //! [`Scalar`]: crate::Scalar
 //! [`ScalarKind`]: crate::ScalarKind
 
+use crate::ir;
+
 use super::TypeResolution;
 
 impl crate::ScalarKind {
@@ -255,24 +257,38 @@ impl crate::TypeInner {
         }
     }
 
-    /// Compare `self` and `rhs` as types.
+    /// Compare value type `self` and `rhs` as types.
     ///
     /// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
-    /// `ValuePointer` and `Pointer` types as equivalent. This method
+    /// [`ValuePointer`] and [`Pointer`] types as equivalent. This method
     /// cannot be used for structs, because it cannot distinguish two
     /// structs with different names but the same members. For structs,
-    /// use `Module::compare_types`.
+    /// use [`compare_types`].
     ///
-    /// When you know that one side of the comparison is never a pointer, it's
-    /// fine to not bother with canonicalization, and just compare `TypeInner`
-    /// values with `==`.
+    /// When you know that one side of the comparison is never a pointer or
+    /// struct, it's fine to not bother with canonicalization, and just
+    /// compare `TypeInner` values with `==`.
+    ///
+    /// # Panics
+    ///
+    /// If both `self` and `rhs` are structs.
+    ///
+    /// [`compare_types`]: crate::proc::compare_types
+    /// [`ValuePointer`]: ir::TypeInner::ValuePointer
+    /// [`Pointer`]: ir::TypeInner::Pointer
     pub fn non_struct_equivalent(
         &self,
-        rhs: &crate::TypeInner,
+        rhs: &ir::TypeInner,
         types: &crate::UniqueArena<crate::Type>,
     ) -> bool {
         let left = self.canonical_form(types);
         let right = rhs.canonical_form(types);
+
+        let left_struct = matches!(*self, ir::TypeInner::Struct { .. });
+        let right_struct = matches!(*rhs, ir::TypeInner::Struct { .. });
+
+        assert!(!left_struct || !right_struct);
+
         left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
     }
 
diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs
index baa8203ad..f7d1ea483 100644
--- a/naga/src/proc/typifier.rs
+++ b/naga/src/proc/typifier.rs
@@ -2,8 +2,11 @@ use alloc::{format, string::String};
 
 use thiserror::Error;
 
-use crate::arena::{Arena, Handle, UniqueArena};
-use crate::common::ForDebugWithTypes;
+use crate::{
+    arena::{Arena, Handle, UniqueArena},
+    common::ForDebugWithTypes,
+    ir,
+};
 
 /// The result of computing an expression's type.
 ///
@@ -773,6 +776,44 @@ impl<'a> ResolveContext<'a> {
     }
 }
 
+/// Compare two types.
+///
+/// This is the most general way of comparing two types, as it can distinguish
+/// two structs with different names but the same members. For other ways, see
+/// [`TypeInner::non_struct_equivalent`] and [`TypeInner::eq`].
+///
+/// In Naga code, this is usually called via the like-named methods on [`Module`],
+/// [`GlobalCtx`], and `BlockContext`.
+///
+/// [`TypeInner::non_struct_equivalent`]: crate::ir::TypeInner::non_struct_equivalent
+/// [`TypeInner::eq`]: crate::ir::TypeInner
+/// [`Module`]: crate::ir::Module
+/// [`GlobalCtx`]: crate::proc::GlobalCtx
+pub fn compare_types(
+    lhs: &TypeResolution,
+    rhs: &TypeResolution,
+    types: &UniqueArena<crate::Type>,
+) -> bool {
+    match lhs {
+        &TypeResolution::Handle(lhs_handle)
+            if matches!(
+                types[lhs_handle],
+                ir::Type {
+                    inner: ir::TypeInner::Struct { .. },
+                    ..
+                }
+            ) =>
+        {
+            // Structs can only be in the arena, not in a TypeResolution::Value
+            rhs.handle()
+                .is_some_and(|rhs_handle| lhs_handle == rhs_handle)
+        }
+        _ => lhs
+            .inner_with(types)
+            .non_struct_equivalent(rhs.inner_with(types), types),
+    }
+}
+
 #[test]
 fn test_error_size() {
     assert_eq!(size_of::<ResolveError>(), 32);
diff --git a/naga/src/valid/compose.rs b/naga/src/valid/compose.rs
index bba13952e..20df99cbf 100644
--- a/naga/src/valid/compose.rs
+++ b/naga/src/valid/compose.rs
@@ -84,11 +84,7 @@ pub fn validate_compose(
                 });
             }
             for (index, comp_res) in component_resolutions.enumerate() {
-                let base_inner = &gctx.types[base].inner;
-                let comp_res_inner = comp_res.inner_with(gctx.types);
-                // We don't support arrays of pointers, but it seems best not to
-                // embed that assumption here, so use `TypeInner::equivalent`.
-                if !base_inner.non_struct_equivalent(comp_res_inner, gctx.types) {
+                if !gctx.compare_types(&TypeResolution::Handle(base), &comp_res) {
                     log::error!("Array component[{}] type {:?}", index, comp_res);
                     return Err(ComposeError::ComponentType {
                         index: index as u32,
@@ -105,11 +101,7 @@ pub fn validate_compose(
             }
             for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate()
             {
-                let member_inner = &gctx.types[member.ty].inner;
-                let comp_res_inner = comp_res.inner_with(gctx.types);
-                // We don't support pointers in structs, but it seems best not to embed
-                // that assumption here, so use `TypeInner::equivalent`.
-                if !comp_res_inner.non_struct_equivalent(member_inner, gctx.types) {
+                if !gctx.compare_types(&TypeResolution::Handle(member.ty), &comp_res) {
                     log::error!("Struct component[{}] type {:?}", index, comp_res);
                     return Err(ComposeError::ComponentType {
                         index: index as u32,
diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs
index 7a2517597..7865f1fc4 100644
--- a/naga/src/valid/function.rs
+++ b/naga/src/valid/function.rs
@@ -120,8 +120,11 @@ pub enum FunctionError {
     ContinueOutsideOfLoop,
     #[error("The `return` is called within a `continuing` block")]
     InvalidReturnSpot,
-    #[error("The `return` value {0:?} does not match the function return value")]
-    InvalidReturnType(Option<Handle<crate::Expression>>),
+    #[error("The `return` expression {expression:?} does not match the declared return type {expected_ty:?}")]
+    InvalidReturnType {
+        expression: Option<Handle<crate::Expression>>,
+        expected_ty: Option<Handle<crate::Type>>,
+    },
     #[error("The `if` condition {0:?} is not a boolean scalar")]
     InvalidIfType(Handle<crate::Expression>),
     #[error("The `switch` value {0:?} is not an integer scalar")]
@@ -310,8 +313,8 @@ impl<'a> BlockContext<'a> {
         self.info[handle].ty.inner_with(self.types)
     }
 
-    fn inner_type<'t>(&'t self, ty: &'t TypeResolution) -> &'t crate::TypeInner {
-        ty.inner_with(self.types)
+    fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
+        crate::proc::compare_types(lhs, rhs, self.types)
     }
 }
 
@@ -338,8 +341,7 @@ impl super::Validator {
                     CallError::Argument { index, source }
                         .with_span_handle(expr, context.expressions)
                 })?;
-            let arg_inner = &context.types[arg.ty].inner;
-            if !ty.inner_with(context.types).non_struct_equivalent(arg_inner, context.types) {
+            if !context.compare_types(&TypeResolution::Handle(arg.ty), ty) {
                 return Err(CallError::ArgumentType {
                     index,
                     required: arg.ty,
@@ -964,13 +966,12 @@ impl super::Validator {
                     let value_ty = value
                         .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
                         .transpose()?;
-                    let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
                     // We can't return pointers, but it seems best not to embed that
                     // assumption here, so use `TypeInner::equivalent` for comparison.
-                    let okay = match (value_ty, expected_ty) {
+                    let okay = match (value_ty, context.return_type) {
                         (None, None) => true,
-                        (Some(value_inner), Some(expected_inner)) => {
-                            value_inner.inner_with(context.types).non_struct_equivalent(expected_inner, context.types)
+                        (Some(value_inner), Some(expected_ty)) => {
+                            context.compare_types(value_inner, &TypeResolution::Handle(expected_ty))
                         }
                         (_, _) => false,
                     };
@@ -979,14 +980,20 @@ impl super::Validator {
                         log::error!(
                             "Returning {:?} where {:?} is expected",
                             value_ty,
-                            expected_ty
+                            context.return_type,
                         );
                         if let Some(handle) = value {
-                            return Err(FunctionError::InvalidReturnType(value)
-                                .with_span_handle(handle, context.expressions));
+                            return Err(FunctionError::InvalidReturnType {
+                                expression: value,
+                                expected_ty: context.return_type,
+                            }
+                            .with_span_handle(handle, context.expressions));
                         } else {
-                            return Err(FunctionError::InvalidReturnType(value)
-                                .with_span_static(span, "invalid return"));
+                            return Err(FunctionError::InvalidReturnType {
+                                expression: value,
+                                expected_ty: context.return_type,
+                            }
+                            .with_span_static(span, "invalid return"));
                         }
                     }
                     finished = true;
@@ -1036,7 +1043,8 @@ impl super::Validator {
                         }
                     }
 
-                    let value_ty = context.resolve_type_inner(value, &self.valid_expression_set)?;
+                    let value_tr = context.resolve_type(value, &self.valid_expression_set)?;
+                    let value_ty = value_tr.inner_with(context.types);
                     match *value_ty {
                         Ti::Image { .. } | Ti::Sampler { .. } => {
                             return Err(FunctionError::InvalidStoreTexture {
@@ -1053,16 +1061,19 @@ impl super::Validator {
                     }
 
                     let pointer_ty = context.resolve_pointer_type(pointer);
-                    let good = match pointer_ty
-                        .pointer_base_type()
+                    let pointer_base_tr = pointer_ty.pointer_base_type();
+                    let pointer_base_ty = pointer_base_tr
                         .as_ref()
-                        .map(|ty| context.inner_type(ty))
-                    {
+                        .map(|ty| ty.inner_with(context.types));
+                    let good = if let Some(&Ti::Atomic(ref scalar)) = pointer_base_ty {
                         // The Naga IR allows storing a scalar to an atomic.
-                        Some(&Ti::Atomic(ref scalar)) => *value_ty == Ti::Scalar(*scalar),
-                        Some(other) => *value_ty == *other,
-                        None => false,
+                        *value_ty == Ti::Scalar(*scalar)
+                    } else if let Some(tr) = pointer_base_tr {
+                        context.compare_types(value_tr, &tr)
+                    } else {
+                        false
                     };
+
                     if !good {
                         return Err(FunctionError::InvalidStoreTypes { pointer, value }
                             .with_span()
@@ -1640,9 +1651,7 @@ impl super::Validator {
         }
 
         if let Some(init) = var.init {
-            let decl_ty = &gctx.types[var.ty].inner;
-            let init_ty = fun_info[init].ty.inner_with(gctx.types);
-            if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
+            if !gctx.compare_types(&TypeResolution::Handle(var.ty), &fun_info[init].ty) {
                 return Err(LocalVariableError::InitializerType);
             }
 
diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs
index 97027936b..3792c71ab 100644
--- a/naga/src/valid/interface.rs
+++ b/naga/src/valid/interface.rs
@@ -636,9 +636,10 @@ impl super::Validator {
                 return Err(GlobalVariableError::InitializerExprType);
             }
 
-            let decl_ty = &gctx.types[var.ty].inner;
-            let init_ty = mod_info[init].inner_with(gctx.types);
-            if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
+            if !gctx.compare_types(
+                &crate::proc::TypeResolution::Handle(var.ty),
+                &mod_info[init],
+            ) {
                 return Err(GlobalVariableError::InitializerType);
             }
         }
diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs
index 8a1f89e92..c8a02db1a 100644
--- a/naga/src/valid/mod.rs
+++ b/naga/src/valid/mod.rs
@@ -532,9 +532,7 @@ impl Validator {
             return Err(ConstantError::InitializerExprType);
         }
 
-        let decl_ty = &gctx.types[con.ty].inner;
-        let init_ty = mod_info[con.init].inner_with(gctx.types);
-        if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
+        if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
             return Err(ConstantError::InvalidType);
         }
 
@@ -560,9 +558,8 @@ impl Validator {
             return Err(OverrideError::NonConstructibleType);
         }
 
-        let decl_ty = &gctx.types[o.ty].inner;
-        match decl_ty {
-            &crate::TypeInner::Scalar(
+        match gctx.types[o.ty].inner {
+            crate::TypeInner::Scalar(
                 crate::Scalar::BOOL
                 | crate::Scalar::I32
                 | crate::Scalar::U32
@@ -574,8 +571,7 @@ impl Validator {
         }
 
         if let Some(init) = o.init {
-            let init_ty = mod_info[init].inner_with(gctx.types);
-            if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
+            if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
                 return Err(OverrideError::InvalidType);
             }
         } else if self.overrides_resolved {
diff --git a/naga/tests/naga/wgsl_errors.rs b/naga/tests/naga/wgsl_errors.rs
index db7534e65..800cba25f 100644
--- a/naga/tests/naga/wgsl_errors.rs
+++ b/naga/tests/naga/wgsl_errors.rs
@@ -291,15 +291,15 @@ fn constructor_parameter_type_mismatch() {
                 _ = mat2x2<f32>(array(0, 1), vec2(2, 3));
             }
         "#,
-        r#"error: automatic conversions cannot convert `array<{AbstractInt}, 2>` to `vec2<f32>`
+        "error: automatic conversions cannot convert `array<{AbstractInt}, 2>` to `vec2<f32>`
   ┌─ wgsl:3:21
   │
 3 │                 _ = mat2x2<f32>(array(0, 1), vec2(2, 3));
   │                     ^^^^^^^^^^^ ^^^^^^^^^^^ this expression has type array<{AbstractInt}, 2>
-  │                     │            
+  │                     │\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20
   │                     a value of type vec2<f32> is required here
 
-"#,
+",
     );
 }
 
@@ -1289,6 +1289,127 @@ fn invalid_structs() {
     }
 }
 
+#[test]
+fn struct_type_mismatch_in_assignment() {
+    check_validation!(
+        "
+        struct Foo { a: u32 };
+        struct Bar { a: u32 };
+        fn main() {
+            var x: Bar = Bar(1);
+            x = Foo(1);
+        }
+        ":
+        Err(naga::valid::ValidationError::Function {
+            handle: _,
+            name: function_name,
+            source: naga::valid::FunctionError::InvalidStoreTypes { .. },
+        })
+        // The validation error is reported at the call, i.e., in `main`
+        if function_name == "main"
+    );
+}
+
+#[test]
+fn struct_type_mismatch_in_let_decl() {
+    check(
+        "
+        struct Foo { a: u32 };
+        struct Bar { a: u32 };
+        fn main() {
+            let x: Bar = Foo(1);
+        }
+        ",
+        "error: the type of `x` is expected to be `Bar`, but got `Foo`
+  ┌─ wgsl:5:17
+  │
+5 │             let x: Bar = Foo(1);
+  │                 ^ definition of `x`
+
+",
+    );
+}
+
+#[test]
+fn struct_type_mismatch_in_return_value() {
+    check_validation!(
+        "
+        struct Foo { a: u32 };
+        struct Bar { a: u32 };
+        fn bar() -> Bar {
+            return Foo(1);
+        }
+        ":
+        Err(naga::valid::ValidationError::Function {
+            handle: _,
+            name: function_name,
+            source: naga::valid::FunctionError::InvalidReturnType { .. }
+        }) if function_name == "bar"
+    );
+}
+
+#[test]
+fn struct_type_mismatch_in_argument() {
+    check_validation!(
+        "
+        struct Foo { a: u32 };
+        struct Bar { a: u32 };
+        fn bar(a: Bar) {}
+        fn main() {
+            bar(Foo(1));
+        }
+        ":
+        Err(naga::valid::ValidationError::Function {
+            name: function_name,
+            source: naga::valid::FunctionError::InvalidCall {
+                function: _,
+                error: naga::valid::CallError::ArgumentType { index, .. },
+            },
+            ..
+        })
+        // The validation error is reported at the call, i.e., in `main`
+        if function_name == "main" && *index == 0
+    );
+}
+
+#[test]
+fn struct_type_mismatch_in_global_var() {
+    check(
+        "
+        struct Foo { a: u32 };
+        struct Bar { a: u32 };
+
+        var<uniform> foo: Foo = Bar(1);
+        ",
+        "error: the type of `foo` is expected to be `Foo`, but got `Bar`
+  ┌─ wgsl:5:22
+  │
+5 │         var<uniform> foo: Foo = Bar(1);
+  │                      ^^^ definition of `foo`
+
+",
+    );
+}
+
+#[test]
+fn struct_type_mismatch_in_global_const() {
+    check(
+        "
+        struct Foo { a: u32 };
+        struct Bar { a: u32 };
+
+        const foo: Foo = Bar(1);
+        ",
+        "error: the type of `foo` is expected to be `Foo`, but got `Bar`
+  ┌─ wgsl:5:15
+  │
+5 │         const foo: Foo = Bar(1);
+  │               ^^^ definition of `foo`
+
+",
+    );
+}
+
 #[test]
 fn invalid_functions() {
     check_validation! {
@@ -1408,7 +1529,7 @@ fn invalid_return_type() {
     check_validation! {
         "fn invalid_return_type() -> i32 { return 0u; }":
         Err(naga::valid::ValidationError::Function {
-            source: naga::valid::FunctionError::InvalidReturnType(Some(_)),
+            source: naga::valid::FunctionError::InvalidReturnType { .. },
             ..
         })
     };
@@ -2576,15 +2697,15 @@ fn function_param_redefinition_as_param() {
         "
         fn x(a: f32, a: vec2<f32>) {}
     ",
-        r###"error: redefinition of `a`
+        "error: redefinition of `a`
   ┌─ wgsl:2:14
   │
 2 │         fn x(a: f32, a: vec2<f32>) {}
   │              ^       ^ redefinition of `a`
-  │              │        
+  │              │\x20\x20\x20\x20\x20\x20\x20\x20
   │              previous definition of `a`
 
-"###,
+",
     )
 }
 
@@ -2608,6 +2729,25 @@ fn function_param_redefinition_as_local() {
     )
 }
 
+#[test]
+fn struct_redefinition() {
+    check(
+        "
+        struct Foo { a: u32 };
+        struct Foo { a: u32 };
+    ",
+        "error: redefinition of `Foo`
+  ┌─ wgsl:2:16
+  │
+2 │         struct Foo { a: u32 };
+  │                ^^^ previous definition of `Foo`
+3 │         struct Foo { a: u32 };
+  │                ^^^ redefinition of `Foo`
+
+",
+    );
+}
+
 #[test]
 fn struct_member_redefinition() {
     check(
@@ -2635,7 +2775,7 @@ fn function_must_return_value() {
         "fn func() -> i32 {
         }":
         Err(naga::valid::ValidationError::Function {
-            source: naga::valid::FunctionError::InvalidReturnType(_),
+            source: naga::valid::FunctionError::InvalidReturnType { .. },
             ..
         })
     );
@@ -2644,7 +2784,7 @@ fn function_must_return_value() {
             let y = x + 10;
         }":
         Err(naga::valid::ValidationError::Function {
-            source: naga::valid::FunctionError::InvalidReturnType(_),
+            source: naga::valid::FunctionError::InvalidReturnType { .. },
             ..
         })
     );
@@ -2658,15 +2798,15 @@ fn constructor_type_error_span() {
             var a: array<i32, 1> = array<i32, 1>(1.0);
         }
     ",
-        r###"error: automatic conversions cannot convert `{AbstractFloat}` to `i32`
+        "error: automatic conversions cannot convert `{AbstractFloat}` to `i32`
   ┌─ wgsl:3:36
   │
 3 │             var a: array<i32, 1> = array<i32, 1>(1.0);
   │                                    ^^^^^^^^^^^^^ ^^^ this expression has type {AbstractFloat}
-  │                                    │              
+  │                                    │\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20
   │                                    a value of type i32 is required here
 
-"###,
+",
     )
 }