[naga] Two structs with the same members are not equivalent

Fixes #5796
This commit is contained in:
Andy Leiserson 2025-03-27 15:05:31 -07:00 committed by Teodor Tanasoaia
parent 19429a1dc9
commit 2a2c851c40
12 changed files with 302 additions and 79 deletions

View File

@ -16,7 +16,8 @@ use crate::{FastIndexSet, Span};
/// The element type must implement `Eq` and `Hash`. Insertions of equivalent
/// elements, according to `Eq`, all return the same `Handle`.
///
/// Once inserted, elements may not be mutated.
/// Once inserted, elements generally may not be mutated, although a `replace`
/// method exists to support rare cases.
///
/// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like,
/// `UniqueArena` is `HashSet`-like.

View File

@ -46,7 +46,7 @@ impl<'source> super::ExpressionContext<'source, '_, '_> {
}
// If `expr` already has the requested type, we're done.
if expr_inner.non_struct_equivalent(goal_inner, types) {
if self.module.compare_types(expr_resolution, goal_ty) {
return Ok(expr);
}

View File

@ -1311,9 +1311,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
})?;
let init_ty = ectx.register_type(init)?;
let explicit_inner = &ectx.module.types[explicit_ty].inner;
let init_inner = &ectx.module.types[init_ty].inner;
if !explicit_inner.non_struct_equivalent(init_inner, &ectx.module.types) {
if !ectx.module.compare_types(
&crate::proc::TypeResolution::Handle(explicit_ty),
&crate::proc::TypeResolution::Handle(init_ty),
) {
return Err(Box::new(Error::InitializationTypeMismatch {
name: name.span,
expected: ectx.type_to_string(explicit_ty),

View File

@ -636,6 +636,15 @@ pub struct Type {
}
/// Enum with additional information, depending on the kind of type.
///
/// Comparison using `==` is not reliable in the case of [`Pointer`],
/// [`ValuePointer`], or [`Struct`] variants. For these variants,
/// use [`TypeInner::non_struct_equivalent`] or [`compare_types`].
///
/// [`compare_types`]: crate::proc::compare_types
/// [`ValuePointer`]: TypeInner::ValuePointer
/// [`Pointer`]: TypeInner::Pointer
/// [`Struct`]: TypeInner::Struct
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
@ -656,8 +665,9 @@ pub enum TypeInner {
/// Pointer to another type.
///
/// Pointers to scalars and vectors should be treated as equivalent to
/// [`ValuePointer`] types. Use the [`TypeInner::equivalent`] method to
/// compare types in a way that treats pointers correctly.
/// [`ValuePointer`] types. Use either [`TypeInner::non_struct_equivalent`]
/// or [`compare_types`] to compare types in a way that treats pointers
/// correctly.
///
/// ## Pointers to non-`SIZED` types
///
@ -679,6 +689,7 @@ pub enum TypeInner {
/// [`ValuePointer`]: TypeInner::ValuePointer
/// [`GlobalVariable`]: Expression::GlobalVariable
/// [`AccessIndex`]: Expression::AccessIndex
/// [`compare_types`]: crate::proc::compare_types
Pointer {
base: Handle<Type>,
space: AddressSpace,
@ -690,12 +701,13 @@ pub enum TypeInner {
/// `Scalar` or `Vector` type. This is for use in [`TypeResolution::Value`]
/// variants; see the documentation for [`TypeResolution`] for details.
///
/// Use the [`TypeInner::equivalent`] method to compare types that could be
/// pointers, to ensure that `Pointer` and `ValuePointer` types are
/// recognized as equivalent.
/// Use [`TypeInner::non_struct_equivalent`] or [`compare_types`] to compare
/// types that could be pointers, to ensure that `Pointer` and
/// `ValuePointer` types are recognized as equivalent.
///
/// [`TypeResolution`]: crate::proc::TypeResolution
/// [`TypeResolution::Value`]: crate::proc::TypeResolution::Value
/// [`compare_types`]: crate::proc::compare_types
ValuePointer {
size: Option<VectorSize>,
scalar: Scalar,
@ -744,9 +756,15 @@ pub enum TypeInner {
/// struct, which may be a dynamically sized [`Array`]. The
/// `Struct` type itself is `SIZED` when all its members are `SIZED`.
///
/// Two structure types with different names are not equivalent. Because
/// this variant does not contain the name, it is not possible to use it
/// to compare struct types. Use [`compare_types`] to compare two types
/// that may be structs.
///
/// [`DATA`]: crate::valid::TypeFlags::DATA
/// [`SIZED`]: crate::∅TypeFlags::SIZED
/// [`Array`]: TypeInner::Array
/// [`compare_types`]: crate::proc::compare_types
Struct {
members: Vec<StructMember>,
//TODO: should this be unaligned?

View File

@ -23,7 +23,7 @@ pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
pub use terminator::ensure_block_returns;
use thiserror::Error;
pub use type_methods::min_max_float_representable_by;
pub use typifier::{ResolveContext, ResolveError, TypeResolution};
pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
impl From<super::StorageFormat> for super::Scalar {
fn from(format: super::StorageFormat) -> Self {
@ -403,6 +403,10 @@ impl crate::Module {
global_expressions: &self.global_expressions,
}
}
pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
compare_types(lhs, rhs, &self.types)
}
}
#[derive(Debug)]
@ -491,6 +495,10 @@ impl GlobalCtx<'_> {
_ => get(*self, handle, arena),
}
}
pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
compare_types(lhs, rhs, self.types)
}
}
#[derive(Error, Debug, Clone, Copy, PartialEq)]

View File

@ -4,6 +4,8 @@
//! [`Scalar`]: crate::Scalar
//! [`ScalarKind`]: crate::ScalarKind
use crate::ir;
use super::TypeResolution;
impl crate::ScalarKind {
@ -255,24 +257,38 @@ impl crate::TypeInner {
}
}
/// Compare `self` and `rhs` as types.
/// Compare value type `self` and `rhs` as types.
///
/// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
/// `ValuePointer` and `Pointer` types as equivalent. This method
/// [`ValuePointer`] and [`Pointer`] types as equivalent. This method
/// cannot be used for structs, because it cannot distinguish two
/// structs with different names but the same members. For structs,
/// use `Module::compare_types`.
/// use [`compare_types`].
///
/// When you know that one side of the comparison is never a pointer, it's
/// fine to not bother with canonicalization, and just compare `TypeInner`
/// values with `==`.
/// When you know that one side of the comparison is never a pointer or
/// struct, it's fine to not bother with canonicalization, and just
/// compare `TypeInner` values with `==`.
///
/// # Panics
///
/// If both `self` and `rhs` are structs.
///
/// [`compare_types`]: crate::proc::compare_types
/// [`ValuePointer`]: ir::TypeInner::ValuePointer
/// [`Pointer`]: ir::TypeInner::Pointer
pub fn non_struct_equivalent(
&self,
rhs: &crate::TypeInner,
rhs: &ir::TypeInner,
types: &crate::UniqueArena<crate::Type>,
) -> bool {
let left = self.canonical_form(types);
let right = rhs.canonical_form(types);
let left_struct = matches!(*self, ir::TypeInner::Struct { .. });
let right_struct = matches!(*rhs, ir::TypeInner::Struct { .. });
assert!(!left_struct || !right_struct);
left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
}

View File

@ -2,8 +2,11 @@ use alloc::{format, string::String};
use thiserror::Error;
use crate::arena::{Arena, Handle, UniqueArena};
use crate::common::ForDebugWithTypes;
use crate::{
arena::{Arena, Handle, UniqueArena},
common::ForDebugWithTypes,
ir,
};
/// The result of computing an expression's type.
///
@ -773,6 +776,44 @@ impl<'a> ResolveContext<'a> {
}
}
/// Compare two types.
///
/// This is the most general way of comparing two types, as it can distinguish
/// two structs with different names but the same members. For other ways, see
/// [`TypeInner::non_struct_equivalent`] and [`TypeInner::eq`].
///
/// In Naga code, this is usually called via the like-named methods on [`Module`],
/// [`GlobalCtx`], and `BlockContext`.
///
/// [`TypeInner::non_struct_equivalent`]: crate::ir::TypeInner::non_struct_equivalent
/// [`TypeInner::eq`]: crate::ir::TypeInner
/// [`Module`]: crate::ir::Module
/// [`GlobalCtx`]: crate::proc::GlobalCtx
pub fn compare_types(
lhs: &TypeResolution,
rhs: &TypeResolution,
types: &UniqueArena<crate::Type>,
) -> bool {
match lhs {
&TypeResolution::Handle(lhs_handle)
if matches!(
types[lhs_handle],
ir::Type {
inner: ir::TypeInner::Struct { .. },
..
}
) =>
{
// Structs can only be in the arena, not in a TypeResolution::Value
rhs.handle()
.is_some_and(|rhs_handle| lhs_handle == rhs_handle)
}
_ => lhs
.inner_with(types)
.non_struct_equivalent(rhs.inner_with(types), types),
}
}
#[test]
fn test_error_size() {
assert_eq!(size_of::<ResolveError>(), 32);

View File

@ -84,11 +84,7 @@ pub fn validate_compose(
});
}
for (index, comp_res) in component_resolutions.enumerate() {
let base_inner = &gctx.types[base].inner;
let comp_res_inner = comp_res.inner_with(gctx.types);
// We don't support arrays of pointers, but it seems best not to
// embed that assumption here, so use `TypeInner::equivalent`.
if !base_inner.non_struct_equivalent(comp_res_inner, gctx.types) {
if !gctx.compare_types(&TypeResolution::Handle(base), &comp_res) {
log::error!("Array component[{}] type {:?}", index, comp_res);
return Err(ComposeError::ComponentType {
index: index as u32,
@ -105,11 +101,7 @@ pub fn validate_compose(
}
for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate()
{
let member_inner = &gctx.types[member.ty].inner;
let comp_res_inner = comp_res.inner_with(gctx.types);
// We don't support pointers in structs, but it seems best not to embed
// that assumption here, so use `TypeInner::equivalent`.
if !comp_res_inner.non_struct_equivalent(member_inner, gctx.types) {
if !gctx.compare_types(&TypeResolution::Handle(member.ty), &comp_res) {
log::error!("Struct component[{}] type {:?}", index, comp_res);
return Err(ComposeError::ComponentType {
index: index as u32,

View File

@ -120,8 +120,11 @@ pub enum FunctionError {
ContinueOutsideOfLoop,
#[error("The `return` is called within a `continuing` block")]
InvalidReturnSpot,
#[error("The `return` value {0:?} does not match the function return value")]
InvalidReturnType(Option<Handle<crate::Expression>>),
#[error("The `return` expression {expression:?} does not match the declared return type {expected_ty:?}")]
InvalidReturnType {
expression: Option<Handle<crate::Expression>>,
expected_ty: Option<Handle<crate::Type>>,
},
#[error("The `if` condition {0:?} is not a boolean scalar")]
InvalidIfType(Handle<crate::Expression>),
#[error("The `switch` value {0:?} is not an integer scalar")]
@ -310,8 +313,8 @@ impl<'a> BlockContext<'a> {
self.info[handle].ty.inner_with(self.types)
}
fn inner_type<'t>(&'t self, ty: &'t TypeResolution) -> &'t crate::TypeInner {
ty.inner_with(self.types)
fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
crate::proc::compare_types(lhs, rhs, self.types)
}
}
@ -338,8 +341,7 @@ impl super::Validator {
CallError::Argument { index, source }
.with_span_handle(expr, context.expressions)
})?;
let arg_inner = &context.types[arg.ty].inner;
if !ty.inner_with(context.types).non_struct_equivalent(arg_inner, context.types) {
if !context.compare_types(&TypeResolution::Handle(arg.ty), ty) {
return Err(CallError::ArgumentType {
index,
required: arg.ty,
@ -964,13 +966,12 @@ impl super::Validator {
let value_ty = value
.map(|expr| context.resolve_type(expr, &self.valid_expression_set))
.transpose()?;
let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
// We can't return pointers, but it seems best not to embed that
// assumption here, so use `TypeInner::equivalent` for comparison.
let okay = match (value_ty, expected_ty) {
let okay = match (value_ty, context.return_type) {
(None, None) => true,
(Some(value_inner), Some(expected_inner)) => {
value_inner.inner_with(context.types).non_struct_equivalent(expected_inner, context.types)
(Some(value_inner), Some(expected_ty)) => {
context.compare_types(value_inner, &TypeResolution::Handle(expected_ty))
}
(_, _) => false,
};
@ -979,14 +980,20 @@ impl super::Validator {
log::error!(
"Returning {:?} where {:?} is expected",
value_ty,
expected_ty
context.return_type,
);
if let Some(handle) = value {
return Err(FunctionError::InvalidReturnType(value)
.with_span_handle(handle, context.expressions));
return Err(FunctionError::InvalidReturnType {
expression: value,
expected_ty: context.return_type,
}
.with_span_handle(handle, context.expressions));
} else {
return Err(FunctionError::InvalidReturnType(value)
.with_span_static(span, "invalid return"));
return Err(FunctionError::InvalidReturnType {
expression: value,
expected_ty: context.return_type,
}
.with_span_static(span, "invalid return"));
}
}
finished = true;
@ -1036,7 +1043,8 @@ impl super::Validator {
}
}
let value_ty = context.resolve_type_inner(value, &self.valid_expression_set)?;
let value_tr = context.resolve_type(value, &self.valid_expression_set)?;
let value_ty = value_tr.inner_with(context.types);
match *value_ty {
Ti::Image { .. } | Ti::Sampler { .. } => {
return Err(FunctionError::InvalidStoreTexture {
@ -1053,16 +1061,19 @@ impl super::Validator {
}
let pointer_ty = context.resolve_pointer_type(pointer);
let good = match pointer_ty
.pointer_base_type()
let pointer_base_tr = pointer_ty.pointer_base_type();
let pointer_base_ty = pointer_base_tr
.as_ref()
.map(|ty| context.inner_type(ty))
{
.map(|ty| ty.inner_with(context.types));
let good = if let Some(&Ti::Atomic(ref scalar)) = pointer_base_ty {
// The Naga IR allows storing a scalar to an atomic.
Some(&Ti::Atomic(ref scalar)) => *value_ty == Ti::Scalar(*scalar),
Some(other) => *value_ty == *other,
None => false,
*value_ty == Ti::Scalar(*scalar)
} else if let Some(tr) = pointer_base_tr {
context.compare_types(value_tr, &tr)
} else {
false
};
if !good {
return Err(FunctionError::InvalidStoreTypes { pointer, value }
.with_span()
@ -1640,9 +1651,7 @@ impl super::Validator {
}
if let Some(init) = var.init {
let decl_ty = &gctx.types[var.ty].inner;
let init_ty = fun_info[init].ty.inner_with(gctx.types);
if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
if !gctx.compare_types(&TypeResolution::Handle(var.ty), &fun_info[init].ty) {
return Err(LocalVariableError::InitializerType);
}

View File

@ -636,9 +636,10 @@ impl super::Validator {
return Err(GlobalVariableError::InitializerExprType);
}
let decl_ty = &gctx.types[var.ty].inner;
let init_ty = mod_info[init].inner_with(gctx.types);
if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
if !gctx.compare_types(
&crate::proc::TypeResolution::Handle(var.ty),
&mod_info[init],
) {
return Err(GlobalVariableError::InitializerType);
}
}

View File

@ -532,9 +532,7 @@ impl Validator {
return Err(ConstantError::InitializerExprType);
}
let decl_ty = &gctx.types[con.ty].inner;
let init_ty = mod_info[con.init].inner_with(gctx.types);
if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
if !gctx.compare_types(&TypeResolution::Handle(con.ty), &mod_info[con.init]) {
return Err(ConstantError::InvalidType);
}
@ -560,9 +558,8 @@ impl Validator {
return Err(OverrideError::NonConstructibleType);
}
let decl_ty = &gctx.types[o.ty].inner;
match decl_ty {
&crate::TypeInner::Scalar(
match gctx.types[o.ty].inner {
crate::TypeInner::Scalar(
crate::Scalar::BOOL
| crate::Scalar::I32
| crate::Scalar::U32
@ -574,8 +571,7 @@ impl Validator {
}
if let Some(init) = o.init {
let init_ty = mod_info[init].inner_with(gctx.types);
if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) {
return Err(OverrideError::InvalidType);
}
} else if self.overrides_resolved {

View File

@ -291,15 +291,15 @@ fn constructor_parameter_type_mismatch() {
_ = mat2x2<f32>(array(0, 1), vec2(2, 3));
}
"#,
r#"error: automatic conversions cannot convert `array<{AbstractInt}, 2>` to `vec2<f32>`
"error: automatic conversions cannot convert `array<{AbstractInt}, 2>` to `vec2<f32>`
wgsl:3:21
3 _ = mat2x2<f32>(array(0, 1), vec2(2, 3));
^^^^^^^^^^^ ^^^^^^^^^^^ this expression has type array<{AbstractInt}, 2>
\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20
a value of type vec2<f32> is required here
"#,
",
);
}
@ -1289,6 +1289,127 @@ fn invalid_structs() {
}
}
#[test]
fn struct_type_mismatch_in_assignment() {
check_validation!(
"
struct Foo { a: u32 };
struct Bar { a: u32 };
fn main() {
var x: Bar = Bar(1);
x = Foo(1);
}
":
Err(naga::valid::ValidationError::Function {
handle: _,
name: function_name,
source: naga::valid::FunctionError::InvalidStoreTypes { .. },
})
// The validation error is reported at the call, i.e., in `main`
if function_name == "main"
);
}
#[test]
fn struct_type_mismatch_in_let_decl() {
check(
"
struct Foo { a: u32 };
struct Bar { a: u32 };
fn main() {
let x: Bar = Foo(1);
}
",
"error: the type of `x` is expected to be `Bar`, but got `Foo`
wgsl:5:17
5 let x: Bar = Foo(1);
^ definition of `x`
",
);
}
#[test]
fn struct_type_mismatch_in_return_value() {
check_validation!(
"
struct Foo { a: u32 };
struct Bar { a: u32 };
fn bar() -> Bar {
return Foo(1);
}
":
Err(naga::valid::ValidationError::Function {
handle: _,
name: function_name,
source: naga::valid::FunctionError::InvalidReturnType { .. }
}) if function_name == "bar"
);
}
#[test]
fn struct_type_mismatch_in_argument() {
check_validation!(
"
struct Foo { a: u32 };
struct Bar { a: u32 };
fn bar(a: Bar) {}
fn main() {
bar(Foo(1));
}
":
Err(naga::valid::ValidationError::Function {
name: function_name,
source: naga::valid::FunctionError::InvalidCall {
function: _,
error: naga::valid::CallError::ArgumentType { index, .. },
},
..
})
// The validation error is reported at the call, i.e., in `main`
if function_name == "main" && *index == 0
);
}
#[test]
fn struct_type_mismatch_in_global_var() {
check(
"
struct Foo { a: u32 };
struct Bar { a: u32 };
var<uniform> foo: Foo = Bar(1);
",
"error: the type of `foo` is expected to be `Foo`, but got `Bar`
wgsl:5:22
5 var<uniform> foo: Foo = Bar(1);
^^^ definition of `foo`
",
);
}
#[test]
fn struct_type_mismatch_in_global_const() {
check(
"
struct Foo { a: u32 };
struct Bar { a: u32 };
const foo: Foo = Bar(1);
",
"error: the type of `foo` is expected to be `Foo`, but got `Bar`
wgsl:5:15
5 const foo: Foo = Bar(1);
^^^ definition of `foo`
",
);
}
#[test]
fn invalid_functions() {
check_validation! {
@ -1408,7 +1529,7 @@ fn invalid_return_type() {
check_validation! {
"fn invalid_return_type() -> i32 { return 0u; }":
Err(naga::valid::ValidationError::Function {
source: naga::valid::FunctionError::InvalidReturnType(Some(_)),
source: naga::valid::FunctionError::InvalidReturnType { .. },
..
})
};
@ -2576,15 +2697,15 @@ fn function_param_redefinition_as_param() {
"
fn x(a: f32, a: vec2<f32>) {}
",
r###"error: redefinition of `a`
"error: redefinition of `a`
wgsl:2:14
2 fn x(a: f32, a: vec2<f32>) {}
^ ^ redefinition of `a`
\x20\x20\x20\x20\x20\x20\x20\x20
previous definition of `a`
"###,
",
)
}
@ -2608,6 +2729,25 @@ fn function_param_redefinition_as_local() {
)
}
#[test]
fn struct_redefinition() {
check(
"
struct Foo { a: u32 };
struct Foo { a: u32 };
",
"error: redefinition of `Foo`
wgsl:2:16
2 struct Foo { a: u32 };
^^^ previous definition of `Foo`
3 struct Foo { a: u32 };
^^^ redefinition of `Foo`
",
);
}
#[test]
fn struct_member_redefinition() {
check(
@ -2635,7 +2775,7 @@ fn function_must_return_value() {
"fn func() -> i32 {
}":
Err(naga::valid::ValidationError::Function {
source: naga::valid::FunctionError::InvalidReturnType(_),
source: naga::valid::FunctionError::InvalidReturnType { .. },
..
})
);
@ -2644,7 +2784,7 @@ fn function_must_return_value() {
let y = x + 10;
}":
Err(naga::valid::ValidationError::Function {
source: naga::valid::FunctionError::InvalidReturnType(_),
source: naga::valid::FunctionError::InvalidReturnType { .. },
..
})
);
@ -2658,15 +2798,15 @@ fn constructor_type_error_span() {
var a: array<i32, 1> = array<i32, 1>(1.0);
}
",
r###"error: automatic conversions cannot convert `{AbstractFloat}` to `i32`
"error: automatic conversions cannot convert `{AbstractFloat}` to `i32`
wgsl:3:36
3 var a: array<i32, 1> = array<i32, 1>(1.0);
^^^^^^^^^^^^^ ^^^ this expression has type {AbstractFloat}
\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20\x20
a value of type i32 is required here
"###,
",
)
}