[naga] Support MathFunction overloads correctly.

Define a new trait, `proc::builtins::OverloadSet`, for types that
represent a Naga IR builtin function's set of overloads. The
`OverloadSet` trait includes operations needed to validate calls,
choose automatic type conversions, and generate diagnostics.

Add a new function, `ir::MathFunction::overloads`, which returns the
given `MathFunction`'s set of overloads as an `impl OverloadSet`
value. Use this in the WGSL front end, the validator, and the
typifier.

To support `MathFunction::overloads`, provide two implementations
of `OverloadSet`:

- `List` is flexible but verbose.

- `Regular` is concise but more restrictive.

Some snapshot output is affected because `TypeResolution::Handle`
values turn into `TypeResolution::Value`, since the function database
constructs the return type directly.

To work around #7405, avoid offering abstract-typed overloads of some
functions.

This addresses #6443 for `MathFunction`, although that issue covers
other categories of operations as well.
This commit is contained in:
Jim Blandy 2025-03-17 13:53:10 -07:00 committed by Connor Fitzgerald
parent 8292949478
commit 7da699608d
27 changed files with 3053 additions and 903 deletions

View File

@ -1,6 +1,6 @@
//! Displaying Naga IR terms in diagnostic output.
use crate::proc::GlobalCtx;
use crate::proc::{GlobalCtx, Rule};
use crate::{Handle, Scalar, Type, TypeInner};
#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
@ -83,6 +83,23 @@ impl fmt::Display for DiagnosticDisplay<(&TypeInner, GlobalCtx<'_>)> {
}
}
impl fmt::Display for DiagnosticDisplay<(&str, &Rule, GlobalCtx<'_>)> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (name, rule, ref ctx) = self.0;
#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
ctx.write_type_rule(name, rule, f)?;
#[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
{
let _ = ctx;
write!(f, "{name}({:?}) -> {:?}", rule.arguments, rule.conclusion)?;
}
Ok(())
}
}
impl fmt::Display for DiagnosticDisplay<Scalar> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let scalar = self.0;

View File

@ -133,6 +133,37 @@ pub trait TypeContext {
}
}
fn write_type_conclusion<W: Write>(
&self,
conclusion: &crate::proc::Conclusion,
out: &mut W,
) -> core::fmt::Result {
use crate::proc::Conclusion as Co;
match *conclusion {
Co::Value(ref inner) => self.write_type_inner(inner, out),
Co::Predeclared(ref predeclared) => out.write_str(&predeclared.struct_name()),
}
}
fn write_type_rule<W: Write>(
&self,
name: &str,
rule: &crate::proc::Rule,
out: &mut W,
) -> core::fmt::Result {
write!(out, "fn {name}(")?;
for (i, arg) in rule.arguments.iter().enumerate() {
if i > 0 {
out.write_str(", ")?;
}
self.write_type_resolution(arg, out)?
}
out.write_str(") -> ")?;
self.write_type_conclusion(&rule.conclusion, out)?;
Ok(())
}
fn type_to_string(&self, handle: Handle<crate::Type>) -> String {
let mut buf = String::new();
self.write_type(handle, &mut buf).unwrap();
@ -150,6 +181,12 @@ pub trait TypeContext {
self.write_type_resolution(resolution, &mut buf).unwrap();
buf
}
fn type_rule_to_string(&self, name: &str, rule: &crate::proc::Rule) -> String {
let mut buf = String::new();
self.write_type_rule(name, rule, &mut buf).unwrap();
buf
}
}
fn try_write_type_inner<C, W>(ctx: &C, inner: &TypeInner, out: &mut W) -> Result<(), WriteTypeError>

View File

@ -271,6 +271,75 @@ pub(crate) enum Error<'a> {
expected: Range<u32>,
found: u32,
},
/// No overload of this function accepts this many arguments.
TooManyArguments {
/// The name of the function being called.
function: String,
/// The function name in the call expression.
call_span: Span,
/// The first argument that is unacceptable.
arg_span: Span,
/// Maximum number of arguments accepted by any overload of
/// this function.
max_arguments: u32,
},
/// A value passed to a builtin function has a type that is not
/// accepted by any overload of the function.
WrongArgumentType {
/// The name of the function being called.
function: String,
/// The function name in the call expression.
call_span: Span,
/// The first argument whose type is unacceptable.
arg_span: Span,
/// The index of the first argument whose type is unacceptable.
arg_index: u32,
/// That argument's actual type.
arg_ty: String,
/// The set of argument types that would have been accepted for
/// this argument, given the prior arguments.
allowed: Vec<String>,
},
/// A value passed to a builtin function has a type that is not
/// accepted, given the earlier arguments' types.
InconsistentArgumentType {
/// The name of the function being called.
function: String,
/// The function name in the call expression.
call_span: Span,
/// The first unacceptable argument.
arg_span: Span,
/// The index of the first unacceptable argument.
arg_index: u32,
/// The actual type of the first unacceptable argument.
arg_ty: String,
/// The prior argument whose type made the `arg_span` argument
/// unacceptable.
inconsistent_span: Span,
/// The index of the `inconsistent_span` argument.
inconsistent_index: u32,
/// The type of the `inconsistent_span` argument.
inconsistent_ty: String,
/// The types that would have been accepted instead of the
/// first unacceptable argument.
allowed: Vec<String>,
},
FunctionReturnsVoid(Span),
FunctionMustUseUnused(Span),
FunctionMustUseReturnsVoid(Span, Span),
@ -402,7 +471,8 @@ impl<'a> Error<'a> {
"workgroup size separator (`,`) or a closing parenthesis".to_string()
}
ExpectedToken::GlobalItem => concat!(
"global item (`struct`, `const`, `var`, `alias`, `fn`, `diagnostic`, `enable`, `requires`, `;`) ",
"global item (`struct`, `const`, `var`, `alias`, ",
"`fn`, `diagnostic`, `enable`, `requires`, `;`) ",
"or the end of the file"
)
.to_string(),
@ -831,6 +901,74 @@ impl<'a> Error<'a> {
labels: vec![(span, "wrong number of arguments".into())],
notes: vec![],
},
Error::TooManyArguments {
ref function,
call_span,
arg_span,
max_arguments,
} => ParseError {
message: format!("too many arguments passed to `{function}`"),
labels: vec![
(call_span, "".into()),
(arg_span, format!("unexpected argument #{}", max_arguments + 1).into())
],
notes: vec![
format!("The `{function}` function accepts at most {max_arguments} argument(s)")
],
},
Error::WrongArgumentType {
ref function,
call_span,
arg_span,
arg_index,
ref arg_ty,
ref allowed,
} => {
let message = format!(
"wrong type passed as argument #{} to `{function}`",
arg_index + 1,
);
let labels = vec![
(call_span, "".into()),
(arg_span, format!("argument #{} has type `{arg_ty}`", arg_index + 1).into())
];
let mut notes = vec![];
notes.push(format!("`{function}` accepts the following types for argument #{}:", arg_index + 1));
notes.extend(allowed.iter().map(|ty| format!("allowed type: {ty}")));
ParseError { message, labels, notes }
},
Error::InconsistentArgumentType {
ref function,
call_span,
arg_span,
arg_index,
ref arg_ty,
inconsistent_span,
inconsistent_index,
ref inconsistent_ty,
ref allowed
} => {
let message = format!(
"inconsistent type passed as argument #{} to `{function}`",
arg_index + 1,
);
let labels = vec![
(call_span, "".into()),
(arg_span, format!("argument #{} has type {arg_ty}", arg_index + 1).into()),
(inconsistent_span, format!(
"this argument has type {inconsistent_ty}, which constrains subsequent arguments"
).into()),
];
let mut notes = vec![
format!("Because argument #{} has type {inconsistent_ty}, only the following types", inconsistent_index + 1),
format!("(or types that automatically convert to them) are accepted for argument #{}:", arg_index + 1),
];
notes.extend(allowed.iter().map(|ty| format!("allowed type: {ty}")));
ParseError { message, labels, notes }
}
Error::FunctionReturnsVoid(span) => ParseError {
message: "function does not return any value".to_string(),
labels: vec![(span, "".into())],

View File

@ -6,7 +6,7 @@ use alloc::{
};
use core::num::NonZeroU32;
use crate::common::wgsl::TypeContext;
use crate::common::wgsl::{TryToWgsl, TypeContext};
use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType};
use crate::front::wgsl::index::Index;
use crate::front::wgsl::parse::number::Number;
@ -491,6 +491,19 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
}
/// Return a wrapper around `value` suitable for formatting.
///
/// Return a wrapper around `value` that implements
/// [`core::fmt::Display`] in a form suitable for use in
/// diagnostic messages.
fn as_diagnostic_display<T>(
&self,
value: T,
) -> crate::common::DiagnosticDisplay<(T, crate::proc::GlobalCtx)> {
let ctx = self.module.to_ctx();
crate::common::DiagnosticDisplay((value, ctx))
}
fn append_expression(
&mut self,
expr: crate::Expression,
@ -2439,52 +2452,204 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
crate::Expression::Derivative { axis, ctrl, expr }
} else if let Some(fun) = conv::map_standard_fun(function.name) {
let expected = fun.argument_count() as _;
let mut args = ctx.prepare_args(arguments, expected, span);
use crate::proc::OverloadSet as _;
let arg = self.expression(args.next()?, ctx)?;
let arg1 = args
.next()
.map(|x| self.expression(x, ctx))
.ok()
.transpose()?;
let arg2 = args
.next()
.map(|x| self.expression(x, ctx))
.ok()
.transpose()?;
let arg3 = args
.next()
.map(|x| self.expression(x, ctx))
.ok()
.transpose()?;
let fun_overloads = fun.overloads();
let mut remaining_overloads = fun_overloads.clone();
let min_arguments = remaining_overloads.min_arguments();
let max_arguments = remaining_overloads.max_arguments();
if arguments.len() < min_arguments {
return Err(Box::new(Error::WrongArgumentCount {
span,
expected: min_arguments as u32..max_arguments as u32,
found: arguments.len() as u32,
}));
}
if arguments.len() > max_arguments {
return Err(Box::new(Error::TooManyArguments {
function: fun.to_wgsl_for_diagnostics(),
call_span: span,
arg_span: ctx.ast_expressions.get_span(arguments[max_arguments]),
max_arguments: max_arguments as _,
}));
}
args.finish()?;
log::debug!(
"Initial overloads: {:#?}",
remaining_overloads.for_debug(&ctx.module.types)
);
let mut unconverted_arguments = Vec::with_capacity(arguments.len());
for (arg_index, &arg) in arguments.iter().enumerate() {
let lowered = self.expression_for_abstract(arg, ctx)?;
let ty = resolve_inner!(ctx, lowered);
log::debug!(
"Supplying argument {arg_index} of type {}",
crate::common::DiagnosticDisplay((ty, ctx.module.to_ctx()))
);
let next_remaining_overloads =
remaining_overloads.arg(arg_index, ty, &ctx.module.types);
if fun == crate::MathFunction::Modf || fun == crate::MathFunction::Frexp {
if let Some((size, scalar)) = match *resolve_inner!(ctx, arg) {
crate::TypeInner::Scalar(scalar) => Some((None, scalar)),
crate::TypeInner::Vector { size, scalar, .. } => {
Some((Some(size), scalar))
// If any argument is not a constant expression, then no overloads
// that accept abstract values should be considered.
// (`OverloadSet::concrete_only` is supposed to help impose this
// restriction.) However, no `MathFunction` accepts a mix of
// abstract and concrete arguments, so we don't need to worry
// about that here.
log::debug!(
"Remaining overloads: {:#?}",
next_remaining_overloads.for_debug(&ctx.module.types)
);
// If the set of remaining overloads is empty, then this argument's type
// was unacceptable. Diagnose the problem and produce an error message.
if next_remaining_overloads.is_empty() {
let function = fun.to_wgsl_for_diagnostics();
let call_span = span;
let arg_span = ctx.ast_expressions.get_span(arg);
let arg_ty = ctx.as_diagnostic_display(ty).to_string();
// Is this type *ever* permitted for the arg_index'th argument?
// For example, `bool` is never permitted for `max`.
let only_this_argument =
fun_overloads.arg(arg_index, ty, &ctx.module.types);
if only_this_argument.is_empty() {
// No overload of `fun` accepts this type as the
// arg_index'th argument. Determine the set of types that
// would ever be allowed there.
let allowed: Vec<String> = fun_overloads
.allowed_args(arg_index, &ctx.module.to_ctx())
.iter()
.map(|ty| ctx.type_resolution_to_string(ty))
.collect();
if allowed.is_empty() {
// No overload of `fun` accepts any argument at this
// index, so it's a simple case of excess arguments.
// However, since each `MathFunction`'s overloads all
// have the same arity, we should have detected this
// earlier.
unreachable!("expected all overloads to have the same arity");
}
// Some overloads of `fun` do accept this many arguments,
// but none accept one of this type.
return Err(Box::new(Error::WrongArgumentType {
function,
call_span,
arg_span,
arg_index: arg_index as u32,
arg_ty,
allowed,
}));
}
_ => None,
} {
ctx.module.generate_predeclared_type(
if fun == crate::MathFunction::Modf {
crate::PredeclaredType::ModfResult { size, scalar }
} else {
crate::PredeclaredType::FrexpResult { size, scalar }
},
);
// This argument's type is accepted by some overloads---just
// not those overloads that remain, given the prior arguments.
// For example, `max` accepts `f32` as its second argument -
// but not if the first was `i32`.
// Build a list of the types that would have been accepted here,
// given the prior arguments.
let allowed: Vec<String> = remaining_overloads
.allowed_args(arg_index, &ctx.module.to_ctx())
.iter()
.map(|ty| ctx.type_resolution_to_string(ty))
.collect();
// Re-run the argument list to determine which prior argument
// made this one unacceptable.
let mut remaining_overloads = fun_overloads;
for (prior_index, &prior_expr) in
unconverted_arguments.iter().enumerate()
{
let prior_ty =
ctx.typifier()[prior_expr].inner_with(&ctx.module.types);
remaining_overloads = remaining_overloads.arg(
prior_index,
prior_ty,
&ctx.module.types,
);
if remaining_overloads
.arg(arg_index, ty, &ctx.module.types)
.is_empty()
{
// This is the argument that killed our dreams.
let inconsistent_span =
ctx.ast_expressions.get_span(arguments[prior_index]);
let inconsistent_ty =
ctx.as_diagnostic_display(prior_ty).to_string();
if allowed.is_empty() {
// Some overloads did accept `ty` at `arg_index`, but
// given the arguments up through `prior_expr`, we see
// no types acceptable at `arg_index`. This means that some
// overloads expect fewer arguments than others. However,
// each `MathFunction`'s overloads have the same arity, so this
// should be impossible.
unreachable!(
"expected all overloads to have the same arity"
);
}
// Report `arg`'s type as inconsistent with `prior_expr`'s
return Err(Box::new(Error::InconsistentArgumentType {
function,
call_span,
arg_span,
arg_index: arg_index as u32,
arg_ty,
inconsistent_span,
inconsistent_index: prior_index as u32,
inconsistent_ty,
allowed,
}));
}
}
unreachable!("Failed to eliminate argument type when re-tried");
}
remaining_overloads = next_remaining_overloads;
unconverted_arguments.push(lowered);
}
// Select the most preferred type rule for this call,
// given the argument types supplied above.
let rule = remaining_overloads.most_preferred();
let mut converted_arguments = Vec::with_capacity(arguments.len());
for (i, (&ast, unconverted)) in
arguments.iter().zip(unconverted_arguments).enumerate()
{
let goal_inner = rule.arguments[i].inner_with(&ctx.module.types);
let converted = match goal_inner.scalar_for_conversions(&ctx.module.types) {
Some(goal_scalar) => {
let arg_span = ctx.ast_expressions.get_span(ast);
ctx.try_automatic_conversion_for_leaf_scalar(
unconverted,
goal_scalar,
arg_span,
)?
}
// No conversion is necessary.
None => unconverted,
};
converted_arguments.push(converted);
}
// If this function returns a predeclared type, register it
// in `Module::special_types`. The typifier will expect to
// be able to find it there.
if let crate::proc::Conclusion::Predeclared(predeclared) = rule.conclusion {
ctx.module.generate_predeclared_type(predeclared);
}
crate::Expression::Math {
fun,
arg,
arg1,
arg2,
arg3,
arg: converted_arguments[0],
arg1: converted_arguments.get(1).cloned(),
arg2: converted_arguments.get(2).cloned(),
arg3: converted_arguments.get(3).cloned(),
}
} else if let Some(fun) = Texture::map(function.name) {
self.texture_sample_helper(fun, arguments, span, ctx)?

View File

@ -7,6 +7,7 @@ mod emitter;
pub mod index;
mod layouter;
mod namer;
mod overloads;
mod terminator;
mod type_methods;
mod typifier;
@ -18,6 +19,7 @@ pub use emitter::Emitter;
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
pub use namer::{EntryPointIndex, NameKey, Namer};
pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
pub use terminator::ensure_block_returns;
use thiserror::Error;
pub use type_methods::min_max_float_representable_by;

View File

@ -0,0 +1,120 @@
//! Dynamically dispatched [`OverloadSet`]s.
use crate::common::DiagnosticDebug;
use crate::ir;
use crate::proc::overloads::{list, regular, OverloadSet, Rule};
use crate::proc::{GlobalCtx, TypeResolution};
use alloc::vec::Vec;
use core::fmt;
macro_rules! define_any_overload_set {
{ $( $module:ident :: $name:ident, )* } => {
/// An [`OverloadSet`] that dynamically dispatches to concrete implementations.
#[derive(Clone)]
pub(in crate::proc::overloads) enum AnyOverloadSet {
$(
$name ( $module :: $name ),
)*
}
$(
impl From<$module::$name> for AnyOverloadSet {
fn from(concrete: $module::$name) -> Self {
AnyOverloadSet::$name(concrete)
}
}
)*
impl OverloadSet for AnyOverloadSet {
fn is_empty(&self) -> bool {
match *self {
$(
AnyOverloadSet::$name(ref x) => x.is_empty(),
)*
}
}
fn min_arguments(&self) -> usize {
match *self {
$(
AnyOverloadSet::$name(ref x) => x.min_arguments(),
)*
}
}
fn max_arguments(&self) -> usize {
match *self {
$(
AnyOverloadSet::$name(ref x) => x.max_arguments(),
)*
}
}
fn arg(
&self,
i: usize,
ty: &ir::TypeInner,
types: &crate::UniqueArena<ir::Type>,
) -> Self {
match *self {
$(
AnyOverloadSet::$name(ref x) => AnyOverloadSet::$name(x.arg(i, ty, types)),
)*
}
}
fn concrete_only(self, types: &crate::UniqueArena<ir::Type>) -> Self {
match self {
$(
AnyOverloadSet::$name(x) => AnyOverloadSet::$name(x.concrete_only(types)),
)*
}
}
fn most_preferred(&self) -> Rule {
match *self {
$(
AnyOverloadSet::$name(ref x) => x.most_preferred(),
)*
}
}
fn overload_list(&self, gctx: &GlobalCtx<'_>) -> Vec<Rule> {
match *self {
$(
AnyOverloadSet::$name(ref x) => x.overload_list(gctx),
)*
}
}
fn allowed_args(&self, i: usize, gctx: &GlobalCtx<'_>) -> Vec<TypeResolution> {
match *self {
$(
AnyOverloadSet::$name(ref x) => x.allowed_args(i, gctx),
)*
}
}
fn for_debug(&self, types: &crate::UniqueArena<ir::Type>) -> impl fmt::Debug {
DiagnosticDebug((self, types))
}
}
impl fmt::Debug for DiagnosticDebug<(&AnyOverloadSet, &crate::UniqueArena<ir::Type>)> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (set, types) = self.0;
match *set {
$(
AnyOverloadSet::$name(ref x) => DiagnosticDebug((x, types)).fmt(f),
)*
}
}
}
}
}
define_any_overload_set! {
list::List,
regular::Regular,
}

View File

@ -0,0 +1,174 @@
//! A set of type constructors, represented as a bitset.
use crate::ir;
use crate::proc::overloads::one_bits_iter::OneBitsIter;
bitflags::bitflags! {
/// A set of type constructors.
#[derive(Copy, Clone, Debug, PartialEq)]
pub(crate) struct ConstructorSet: u16 {
const SCALAR = 1 << 0;
const VEC2 = 1 << 1;
const VEC3 = 1 << 2;
const VEC4 = 1 << 3;
const MAT2X2 = 1 << 4;
const MAT2X3 = 1 << 5;
const MAT2X4 = 1 << 6;
const MAT3X2 = 1 << 7;
const MAT3X3 = 1 << 8;
const MAT3X4 = 1 << 9;
const MAT4X2 = 1 << 10;
const MAT4X3 = 1 << 11;
const MAT4X4 = 1 << 12;
const VECN = Self::VEC2.bits()
| Self::VEC3.bits()
| Self::VEC4.bits();
}
}
impl ConstructorSet {
/// Return the single-member set containing `inner`'s constructor.
pub const fn singleton(inner: &ir::TypeInner) -> ConstructorSet {
use ir::TypeInner as Ti;
use ir::VectorSize as Vs;
match *inner {
Ti::Scalar(_) => Self::SCALAR,
Ti::Vector { size, scalar: _ } => match size {
Vs::Bi => Self::VEC2,
Vs::Tri => Self::VEC3,
Vs::Quad => Self::VEC4,
},
Ti::Matrix {
columns,
rows,
scalar: _,
} => match (columns, rows) {
(Vs::Bi, Vs::Bi) => Self::MAT2X2,
(Vs::Bi, Vs::Tri) => Self::MAT2X3,
(Vs::Bi, Vs::Quad) => Self::MAT2X4,
(Vs::Tri, Vs::Bi) => Self::MAT3X2,
(Vs::Tri, Vs::Tri) => Self::MAT3X3,
(Vs::Tri, Vs::Quad) => Self::MAT3X4,
(Vs::Quad, Vs::Bi) => Self::MAT4X2,
(Vs::Quad, Vs::Tri) => Self::MAT4X3,
(Vs::Quad, Vs::Quad) => Self::MAT4X4,
},
_ => Self::empty(),
}
}
pub const fn is_singleton(self) -> bool {
self.bits().is_power_of_two()
}
/// Return an iterator over this set's members.
///
/// Members are produced as singleton, in order from most general to least.
pub fn members(self) -> impl Iterator<Item = ConstructorSet> {
OneBitsIter::new(self.bits() as u64).map(|bit| Self::from_bits(bit as u16).unwrap())
}
/// Return the size of the sole element of `self`.
///
/// # Panics
///
/// Panic if `self` is not a singleton.
pub fn size(self) -> ConstructorSize {
use ir::VectorSize as Vs;
use ConstructorSize as Cs;
match self {
ConstructorSet::SCALAR => Cs::Scalar,
ConstructorSet::VEC2 => Cs::Vector(Vs::Bi),
ConstructorSet::VEC3 => Cs::Vector(Vs::Tri),
ConstructorSet::VEC4 => Cs::Vector(Vs::Quad),
ConstructorSet::MAT2X2 => Cs::Matrix {
columns: Vs::Bi,
rows: Vs::Bi,
},
ConstructorSet::MAT2X3 => Cs::Matrix {
columns: Vs::Bi,
rows: Vs::Tri,
},
ConstructorSet::MAT2X4 => Cs::Matrix {
columns: Vs::Bi,
rows: Vs::Quad,
},
ConstructorSet::MAT3X2 => Cs::Matrix {
columns: Vs::Tri,
rows: Vs::Bi,
},
ConstructorSet::MAT3X3 => Cs::Matrix {
columns: Vs::Tri,
rows: Vs::Tri,
},
ConstructorSet::MAT3X4 => Cs::Matrix {
columns: Vs::Tri,
rows: Vs::Quad,
},
ConstructorSet::MAT4X2 => Cs::Matrix {
columns: Vs::Quad,
rows: Vs::Bi,
},
ConstructorSet::MAT4X3 => Cs::Matrix {
columns: Vs::Quad,
rows: Vs::Tri,
},
ConstructorSet::MAT4X4 => Cs::Matrix {
columns: Vs::Quad,
rows: Vs::Quad,
},
_ => unreachable!("ConstructorSet was not a singleton"),
}
}
}
/// The sizes a member of [`ConstructorSet`] might have.
#[derive(Clone, Copy)]
pub enum ConstructorSize {
/// The constructor is [`SCALAR`].
///
/// [`SCALAR`]: ConstructorSet::SCALAR
Scalar,
/// The constructor is `VECN` for some `N`.
Vector(ir::VectorSize),
/// The constructor is `MATCXR` for some `C` and `R`.
Matrix {
columns: ir::VectorSize,
rows: ir::VectorSize,
},
}
impl ConstructorSize {
/// Construct a [`TypeInner`] for a type with this size and the given `scalar`.
///
/// [`TypeInner`]: ir::TypeInner
pub const fn to_inner(self, scalar: ir::Scalar) -> ir::TypeInner {
match self {
Self::Scalar => ir::TypeInner::Scalar(scalar),
Self::Vector(size) => ir::TypeInner::Vector { size, scalar },
Self::Matrix { columns, rows } => ir::TypeInner::Matrix {
columns,
rows,
scalar,
},
}
}
}
macro_rules! constructor_set {
( $( $constr:ident )|* ) => {
{
use $crate::proc::overloads::constructor_set::ConstructorSet;
ConstructorSet::empty()
$(
.union(ConstructorSet::$constr)
)*
}
}
}
pub(in crate::proc::overloads) use constructor_set;

View File

@ -0,0 +1,182 @@
//! An [`OverloadSet`] represented as a vector of rules.
//!
//! [`OverloadSet`]: crate::proc::overloads::OverloadSet
use crate::common::{DiagnosticDebug, ForDebug, ForDebugWithTypes};
use crate::ir;
use crate::proc::overloads::one_bits_iter::OneBitsIter;
use crate::proc::overloads::Rule;
use crate::proc::{GlobalCtx, TypeResolution};
use alloc::rc::Rc;
use alloc::vec::Vec;
use core::fmt;
/// A simple list of overloads.
///
/// Note that this type is not quite as general as it looks, in that
/// the implementation of `most_preferred` doesn't work for arbitrary
/// lists of overloads. See the documentation for [`List::rules`] for
/// details.
#[derive(Clone)]
pub(in crate::proc::overloads) struct List {
/// A bitmask of which elements of `rules` are included in the set.
members: u64,
/// A list of type rules that are members of the set.
///
/// These must be listed in order such that every rule in the list
/// is always more preferred than all subsequent rules in the
/// list. If there is no such arrangement of rules, then you
/// cannot use `List` to represent the overload set.
rules: Rc<Vec<Rule>>,
}
impl List {
pub(in crate::proc::overloads) fn from_rules(rules: Vec<Rule>) -> List {
List {
members: len_to_full_mask(rules.len()),
rules: Rc::new(rules),
}
}
fn members(&self) -> impl Iterator<Item = (u64, &Rule)> {
OneBitsIter::new(self.members).map(|mask| {
let index = mask.trailing_zeros() as usize;
(mask, &self.rules[index])
})
}
fn filter<F>(&self, mut pred: F) -> List
where
F: FnMut(&Rule) -> bool,
{
let mut filtered_members = 0;
for (mask, rule) in self.members() {
if pred(rule) {
filtered_members |= mask;
}
}
List {
members: filtered_members,
rules: self.rules.clone(),
}
}
}
impl crate::proc::overloads::OverloadSet for List {
fn is_empty(&self) -> bool {
self.members == 0
}
fn min_arguments(&self) -> usize {
self.members()
.fold(None, |best, (_, rule)| {
// This is different from `max_arguments` because
// `<Option as PartialOrd>` doesn't work the way we'd like.
let len = rule.arguments.len();
Some(match best {
Some(best) => core::cmp::max(best, len),
None => len,
})
})
.unwrap()
}
fn max_arguments(&self) -> usize {
self.members()
.fold(None, |n, (_, rule)| {
core::cmp::max(n, Some(rule.arguments.len()))
})
.unwrap()
}
fn arg(&self, i: usize, arg_ty: &ir::TypeInner, types: &crate::UniqueArena<ir::Type>) -> Self {
log::debug!("arg {i} of type {:?}", arg_ty.for_debug(types));
self.filter(|rule| {
if log::log_enabled!(log::Level::Debug) {
log::debug!(" considering rule {:?}", rule.for_debug(types));
match rule.arguments.get(i) {
Some(rule_ty) => {
let rule_ty = rule_ty.inner_with(types);
if arg_ty.equivalent(rule_ty, types) {
log::debug!(" types are equivalent");
} else {
match arg_ty.automatically_converts_to(rule_ty, types) {
Some((from, to)) => {
log::debug!(
" converts automatically from {:?} to {:?}",
from.for_debug(),
to.for_debug()
);
}
None => {
log::debug!(" no conversion possible");
}
}
}
}
None => {
log::debug!(" argument index {i} out of bounds");
}
}
}
rule.arguments.get(i).is_some_and(|rule_ty| {
let rule_ty = rule_ty.inner_with(types);
arg_ty.equivalent(rule_ty, types)
|| arg_ty.automatically_converts_to(rule_ty, types).is_some()
})
})
}
fn concrete_only(self, types: &crate::UniqueArena<ir::Type>) -> Self {
self.filter(|rule| {
rule.arguments
.iter()
.all(|arg_ty| !arg_ty.inner_with(types).is_abstract(types))
})
}
fn most_preferred(&self) -> Rule {
// As documented for `Self::rules`, whatever rule is first is
// the most preferred. `OverloadSet` documents this method to
// panic if the set is empty.
let (_, rule) = self.members().next().unwrap();
rule.clone()
}
fn overload_list(&self, _gctx: &GlobalCtx<'_>) -> Vec<Rule> {
self.members().map(|(_, rule)| rule.clone()).collect()
}
fn allowed_args(&self, i: usize, _gctx: &GlobalCtx<'_>) -> Vec<TypeResolution> {
self.members()
.map(|(_, rule)| rule.arguments[i].clone())
.collect()
}
fn for_debug(&self, types: &crate::UniqueArena<ir::Type>) -> impl fmt::Debug {
DiagnosticDebug((self, types))
}
}
const fn len_to_full_mask(n: usize) -> u64 {
if n >= 64 {
panic!("List::rules can only hold up to 63 rules");
}
(1_u64 << n) - 1
}
impl ForDebugWithTypes for &List {}
impl fmt::Debug for DiagnosticDebug<(&List, &crate::UniqueArena<ir::Type>)> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (list, types) = self.0;
f.debug_list()
.entries(list.members().map(|(_mask, rule)| rule.for_debug(types)))
.finish()
}
}

View File

@ -0,0 +1,221 @@
//! Overload sets for [`ir::MathFunction`].
use crate::proc::overloads::any_overload_set::AnyOverloadSet;
use crate::proc::overloads::list::List;
use crate::proc::overloads::regular::regular;
use crate::proc::overloads::utils::{
concrete_int_scalars, float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule,
scalar_or_vecn, triples, vector_sizes,
};
use crate::proc::overloads::OverloadSet;
use crate::ir;
impl ir::MathFunction {
pub fn overloads(self) -> impl OverloadSet {
use ir::MathFunction as Mf;
let set: AnyOverloadSet = match self {
// Component-wise unary numeric operations
Mf::Abs | Mf::Sign => regular!(1, SCALAR|VECN of NUMERIC).into(),
// Component-wise binary numeric operations
Mf::Min | Mf::Max => regular!(2, SCALAR|VECN of NUMERIC).into(),
// Component-wise ternary numeric operations
Mf::Clamp => regular!(3, SCALAR|VECN of NUMERIC).into(),
// Component-wise unary floating-point operations
Mf::Sin
| Mf::Cos
| Mf::Tan
| Mf::Asin
| Mf::Acos
| Mf::Atan
| Mf::Sinh
| Mf::Cosh
| Mf::Tanh
| Mf::Asinh
| Mf::Acosh
| Mf::Atanh
| Mf::Saturate
| Mf::Radians
| Mf::Degrees
| Mf::Ceil
| Mf::Floor
| Mf::Round
| Mf::Fract
| Mf::Trunc
| Mf::Exp
| Mf::Exp2
| Mf::Log
| Mf::Log2
| Mf::Sqrt
| Mf::InverseSqrt => regular!(1, SCALAR|VECN of FLOAT).into(),
// Component-wise binary floating-point operations
Mf::Atan2 | Mf::Pow | Mf::Step => regular!(2, SCALAR|VECN of FLOAT).into(),
// Component-wise ternary floating-point operations
Mf::Fma | Mf::SmoothStep => regular!(3, SCALAR|VECN of FLOAT).into(),
// Component-wise unary concrete integer operations
Mf::CountTrailingZeros
| Mf::CountLeadingZeros
| Mf::CountOneBits
| Mf::ReverseBits
| Mf::FirstTrailingBit
| Mf::FirstLeadingBit => regular!(1, SCALAR|VECN of CONCRETE_INTEGER).into(),
// Packing functions
Mf::Pack4x8snorm | Mf::Pack4x8unorm => regular!(1, VEC4 of F32 -> U32).into(),
Mf::Pack2x16snorm | Mf::Pack2x16unorm | Mf::Pack2x16float => {
regular!(1, VEC2 of F32 -> U32).into()
}
Mf::Pack4xI8 => regular!(1, VEC4 of I32 -> U32).into(),
Mf::Pack4xU8 => regular!(1, VEC4 of U32 -> U32).into(),
// Unpacking functions
Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => regular!(1, SCALAR of U32 -> Vec4F).into(),
Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack2x16float => {
regular!(1, SCALAR of U32 -> Vec2F).into()
}
Mf::Unpack4xI8 => regular!(1, SCALAR of U32 -> Vec4I).into(),
Mf::Unpack4xU8 => regular!(1, SCALAR of U32 -> Vec4U).into(),
// One-off operations
Mf::Dot => regular!(2, VECN of NUMERIC -> Scalar).into(),
Mf::Modf => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Modf).into(),
Mf::Frexp => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Frexp).into(),
Mf::Ldexp => ldexp().into(),
Mf::Outer => outer().into(),
Mf::Cross => regular!(2, VEC3 of FLOAT).into(),
Mf::Distance => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
Mf::Length => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
Mf::Normalize => regular!(1, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
Mf::FaceForward => regular!(3, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
Mf::Reflect => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
Mf::Refract => refract().into(),
Mf::Mix => mix().into(),
Mf::Inverse => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT).into(),
Mf::Transpose => transpose().into(),
Mf::Determinant => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT -> Scalar).into(),
Mf::QuantizeToF16 => regular!(1, SCALAR|VECN of F32).into(),
Mf::ExtractBits => extract_bits().into(),
Mf::InsertBits => insert_bits().into(),
};
set
}
}
fn ldexp() -> List {
/// Construct the exponent scalar given the mantissa's inner.
fn exponent_from_mantissa(mantissa: ir::Scalar) -> ir::Scalar {
match mantissa.kind {
ir::ScalarKind::AbstractFloat => ir::Scalar::ABSTRACT_INT,
ir::ScalarKind::Float => ir::Scalar::I32,
_ => unreachable!("not a float scalar"),
}
}
list(
// The ldexp mantissa argument can be any floating-point type.
float_scalars_unimplemented_abstract().flat_map(|mantissa_scalar| {
// The exponent type is the integer counterpart of the mantissa type.
let exponent_scalar = exponent_from_mantissa(mantissa_scalar);
// There are scalar and vector component-wise overloads.
scalar_or_vecn(mantissa_scalar)
.zip(scalar_or_vecn(exponent_scalar))
.map(move |(mantissa, exponent)| {
let result = mantissa.clone();
rule([mantissa, exponent], result)
})
}),
)
}
fn outer() -> List {
list(
triples(
vector_sizes(),
vector_sizes(),
float_scalars_unimplemented_abstract(),
)
.map(|(cols, rows, scalar)| {
let left = ir::TypeInner::Vector { size: cols, scalar };
let right = ir::TypeInner::Vector { size: rows, scalar };
let result = ir::TypeInner::Matrix {
columns: cols,
rows,
scalar,
};
rule([left, right], result)
}),
)
}
fn refract() -> List {
list(
pairs(vector_sizes(), float_scalars_unimplemented_abstract()).map(|(size, scalar)| {
let incident = ir::TypeInner::Vector { size, scalar };
let normal = incident.clone();
let ratio = ir::TypeInner::Scalar(scalar);
let result = incident.clone();
rule([incident, normal, ratio], result)
}),
)
}
fn transpose() -> List {
list(
triples(vector_sizes(), vector_sizes(), float_scalars()).map(|(a, b, scalar)| {
let input = ir::TypeInner::Matrix {
columns: a,
rows: b,
scalar,
};
let output = ir::TypeInner::Matrix {
columns: b,
rows: a,
scalar,
};
rule([input], output)
}),
)
}
fn extract_bits() -> List {
list(concrete_int_scalars().flat_map(|scalar| {
scalar_or_vecn(scalar).map(|input| {
let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
let count = ir::TypeInner::Scalar(ir::Scalar::U32);
let output = input.clone();
rule([input, offset, count], output)
})
}))
}
fn insert_bits() -> List {
list(concrete_int_scalars().flat_map(|scalar| {
scalar_or_vecn(scalar).map(|input| {
let newbits = input.clone();
let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
let count = ir::TypeInner::Scalar(ir::Scalar::U32);
let output = input.clone();
rule([input, newbits, offset, count], output)
})
}))
}
fn mix() -> List {
list(float_scalars().flat_map(|scalar| {
scalar_or_vecn(scalar).flat_map(move |input| {
let scalar_ratio = ir::TypeInner::Scalar(scalar);
[
rule([input.clone(), input.clone(), input.clone()], input.clone()),
rule([input.clone(), input.clone(), scalar_ratio], input),
]
})
}))
}

View File

@ -0,0 +1,237 @@
/*! Overload resolution for builtin functions.
This module defines the [`OverloadSet`] trait, which provides methods the
validator and typifier can use to check the types to builtin functions,
determine their result types, and produce diagnostics that explain why a given
application is not allowed and suggest fixes.
You can call [`MathFunction::overloads`] to obtain an `impl OverloadSet`
representing the given `MathFunction`'s overloads.
[`MathFunction::overloads`]: crate::ir::MathFunction::overloads
*/
mod constructor_set;
mod regular;
mod scalar_set;
mod any_overload_set;
mod list;
mod mathfunction;
mod one_bits_iter;
mod rule;
mod utils;
pub use rule::{Conclusion, MissingSpecialType, Rule};
use crate::ir;
use crate::proc::TypeResolution;
use alloc::vec::Vec;
use core::fmt;
#[expect(rustdoc::private_intra_doc_links)]
/// A trait for types representing of a set of Naga IR type rules.
///
/// Given an expression like `max(x, y)`, there are multiple type rules that
/// could apply, depending on the types of `x` and `y`, like:
///
/// - `max(i32, i32) -> i32`
/// - `max(vec4<f32>, vec4<f32>) -> vec4<f32>`
///
/// and so on. Borrowing WGSL's terminology, Naga calls the full set of type
/// rules that might apply to a given expression its "overload candidates", or
/// "overloads" for short.
///
/// This trait is meant to be implemented by types that represent a set of
/// overload candidates. For example, [`MathFunction::overloads`] returns an
/// `impl OverloadSet` describing the overloads for the given Naga IR math
/// function. Naga's typifier, validator, and WGSL front end use this trait for
/// their work.
///
/// [`MathFunction::overloads`]: ir::MathFunction::overloads
///
/// # Automatic conversions
///
/// In principle, overload sets are easy: you simply list all the overloads the
/// function supports, and then when you're presented with a call to typecheck,
/// you just see if the actual argument types presented in the source code match
/// some overload from the list.
///
/// However, Naga supports languages like WGSL, which apply certain [automatic
/// conversions] if necessary to make a call fit some overload's requirements.
/// This means that the set of calls that are effectively allowed by a given set
/// of overloads can be quite large, since any combination of automatic
/// conversions might be applied.
///
/// For example, if `x` is a `u32`, and `100` is an abstract integer, then even
/// though `max` has no overload like `max(u32, AbstractInt) -> ...`, the
/// expression `max(x, 100)` is still allowed, because AbstractInt automatically
/// converts to `u32`.
///
/// [automatic conversions]: https://gpuweb.github.io/gpuweb/wgsl/#feasible-automatic-conversion
///
/// # How to use `OverloadSet`
///
/// The general process of using an `OverloadSet` is as follows:
///
/// - Obtain an `OverloadSet` for a given operation (say, by calling
/// [`MathFunction::overloads`]). This set is fixed by Naga IR's semantics.
///
/// - Call its [`arg`] method, supplying the type of the argument passed to the
/// operation at a certain index. This returns a new `OverloadSet` containing
/// only those overloads that could accept the given type for the given
/// argument. This includes overloads that only match if automatic conversions
/// are applied.
///
/// - If, at some point, the overload set becomes empty, then the set of
/// arguments is not allowed for this operation, and the program is invalid.
/// The `OverloadSet` trait provides an [`is_empty`] method.
///
/// - After all arguments have been supplied, if the overload set is still
/// non-empty, you can call its [`most_preferred`] method to find out which
/// overload has the least cost in terms of automatic conversions.
///
/// - If the call is rejected, you can use `OverloadSet` to help produce
/// diagnostic messages that explain exactly what went wrong. `OverloadSet`
/// implementations are meant to be cheap to [`Clone`], so it is fine to keep
/// the original overload set value around, and re-run the selection process,
/// attempting to supply the rejected argument at each step to see exactly
/// which preceding argument ruled it out. The [`overload_list`] and
/// [`allowed_args`] methods are helpful for this.
///
/// [`arg`]: OverloadSet::arg
/// [`is_empty`]: OverloadSet::is_empty
/// [`most_preferred`]: OverloadSet::most_preferred
/// [`overload_list`]: OverloadSet::overload_list
/// [`allowed_args`]: OverloadSet::allowed_args
///
/// # Concrete implementations
///
/// This module contains two private implementations of `OverloadSet`:
///
/// - The [`List`] type is a straightforward list of overloads. It is general,
/// but verbose to use. The [`utils`] module exports functions that construct
/// `List` overload sets for the functions that need this.
///
/// - The [`Regular`] type is a compact, efficient representation for the kinds
/// of overload sets commonly seen for Naga IR mathematical functions.
/// However, in return for its simplicity, it is not as flexible as [`List`].
/// This module use the [`regular!`] macro as a legible notation for `Regular`
/// sets.
///
/// [`List`]: list::List
/// [`Regular`]: regular::Regular
/// [`regular!`]: regular::regular
pub trait OverloadSet: Clone {
/// Return true if `self` is the empty set of overloads.
fn is_empty(&self) -> bool;
/// Return the smallest number of arguments in any type rule in the set.
///
/// # Panics
///
/// Panics if `self` is empty.
fn min_arguments(&self) -> usize;
/// Return the largest number of arguments in any type rule in the set.
///
/// # Panics
///
/// Panics if `self` is empty.
fn max_arguments(&self) -> usize;
/// Find the overloads that could accept a given argument.
///
/// Return a new overload set containing those members of `self` that could
/// accept a value of type `ty` for their `i`'th argument, once
/// feasible automatic conversions have been applied.
fn arg(&self, i: usize, ty: &ir::TypeInner, types: &crate::UniqueArena<ir::Type>) -> Self;
/// Limit `self` to overloads whose arguments are all concrete types.
///
/// Naga's overload resolution is based on WGSL's [overload resolution
/// algorithm][ora], which includes the following step:
///
/// > Eliminate any candidate where one of its subexpressions resolves to
/// > an abstract type after feasible automatic conversions, but another of
/// > the candidates subexpressions is not a const-expression.
/// >
/// > Note: As a consequence, if any subexpression in the phrase is not a
/// > const-expression, then all subexpressions in the phrase must have a
/// > concrete type.
///
/// Essentially, if any one of the arguments is not a constant expression,
/// then the operation is going to be evaluated at runtime, so all its
/// arguments must be converted to a concrete type. If you determine that an
/// argument is non-constant, you can use this trait method to toss out any
/// overloads that would accept abstract types.
///
/// In almost all cases, this operation has no effect. Only constant
/// expressions can have abstract types, so if any argument is not a
/// constant expression, it must have a concrete type. Although many
/// operations in Naga IR have overloads for both abstract types and
/// concrete types, few operations have overloads that accept a mix of
/// abstract and concrete types. Thus, a single concrete argument will
/// usually have eliminated all overloads that accept abstract types anyway.
/// (The exceptions are oddities like `Expression::Select`, where the
/// `condition` operand could be a runtime expression even as the `accept`
/// and `reject` operands have abstract types.)
///
/// Note that it is *not* correct to just [concretize] all arguments once
/// you've noticed that some argument is non-constant. The type to which
/// each argument is converted depends on the overloads available, not just
/// the argument's own type.
///
/// [ora]: https://gpuweb.github.io/gpuweb/wgsl/#overload-resolution-section
/// [concretize]: https://gpuweb.github.io/gpuweb/wgsl/#concretization
fn concrete_only(self, types: &crate::UniqueArena<ir::Type>) -> Self;
/// Return the most preferred candidate.
///
/// Rank the candidates in `self` as described in WGSL's [overload
/// resolution algorithm][ora], and return a singleton set containing the
/// most preferred candidate.
///
/// # Simplifications versus WGSL
///
/// Naga's process for selecting the most preferred candidate is simpler
/// than WGSL's:
///
/// - WGSL allows for the possibility of ambiguous calls, where multiple
/// overload candidates exist, no one candidate is clearly better than all
/// the others. For example, if a function has the two overloads `(i32,
/// f32) -> bool` and `(f32, i32) -> bool`, and the arguments are both
/// AbstractInt, neither overload is preferred over the other. Ambiguous
/// calls are errors.
///
/// However, Naga IR includes no operations whose overload sets allow such
/// situations to arise, so there is always a most preferred candidate.
/// Thus, this method infallibly returns a `Rule`, and has no way to
/// indicate ambiguity.
///
/// - WGSL says that the most preferred candidate depends on the conversion
/// rank for each argument, as determined by the types of the arguments
/// being passed.
///
/// However, the overloads of every operation in Naga IR can be ranked
/// even without knowing the argument types. So this method does not
/// require the argument types as a parameter.
///
/// # Panics
///
/// Panics if `self` is empty, or if no argument types have been supplied.
///
/// [ora]: https://gpuweb.github.io/gpuweb/wgsl/#overload-resolution-section
fn most_preferred(&self) -> Rule;
/// Return a type rule for each of the overloads in `self`.
fn overload_list(&self, gctx: &crate::proc::GlobalCtx<'_>) -> Vec<Rule>;
/// Return a list of the types allowed for argument `i`.
fn allowed_args(&self, i: usize, gctx: &crate::proc::GlobalCtx<'_>) -> Vec<TypeResolution>;
/// Return an object that can be formatted with [`core::fmt::Debug`].
fn for_debug(&self, types: &crate::UniqueArena<ir::Type>) -> impl fmt::Debug;
}

View File

@ -0,0 +1,89 @@
//! An iterator over bitmasks.
/// An iterator that produces the set bits in the given `u64`.
///
/// `OneBitsIter(n)` is an [`Iterator`] that produces each of the set bits in
/// `n`, as a bitmask, in order of increasing value. In other words, it produces
/// the unique sequence of distinct powers of two that adds up to `n`.
///
/// For example, iterating over `OneBitsIter(21)` produces the values `1`, `4`,
/// and `16`, in that order, because `21` is `0xb10101`.
///
/// When `n` is the bits of a bitmask, this iterates over the set bits in the
/// bitmask, in order of increasing bit value. `bitflags` does define an `iter`
/// method, but it's not well-specified or well-implemented.
///
/// The values produced are masks, not bit numbers. Use `u64::trailing_zeros` if
/// you need bit numbers.
pub struct OneBitsIter(u64);
impl OneBitsIter {
pub const fn new(bits: u64) -> Self {
Self(bits)
}
}
impl Iterator for OneBitsIter {
type Item = u64;
fn next(&mut self) -> Option<Self::Item> {
if self.0 == 0 {
return None;
}
// Subtracting one from a value in binary clears the lowest `1` bit
// (call it `B`), and sets all the bits below that.
let mask = self.0 - 1;
// Complementing that means that we've instead *set* `B`, *cleared*
// everything below it, and *complemented* everything above it.
//
// Masking with the original value clears everything above and below
// `B`, leaving only `B` set. This is the value we produce.
let item = self.0 & !mask;
// Now that we've produced this bit, remove it from `self.0`.
self.0 &= mask;
Some(item)
}
}
#[test]
fn empty() {
assert_eq!(OneBitsIter(0).next(), None);
}
#[test]
fn all() {
let mut obi = OneBitsIter(!0);
for bit in 0..64 {
assert_eq!(obi.next(), Some(1 << bit));
}
assert_eq!(obi.next(), None);
}
#[test]
fn first() {
let mut obi = OneBitsIter(1);
assert_eq!(obi.next(), Some(1));
assert_eq!(obi.next(), None);
}
#[test]
fn last() {
let mut obi = OneBitsIter(1 << 63);
assert_eq!(obi.next(), Some(1 << 63));
assert_eq!(obi.next(), None);
}
#[test]
fn in_order() {
let mut obi = OneBitsIter(0b11011000001);
assert_eq!(obi.next(), Some(1));
assert_eq!(obi.next(), Some(64));
assert_eq!(obi.next(), Some(128));
assert_eq!(obi.next(), Some(512));
assert_eq!(obi.next(), Some(1024));
assert_eq!(obi.next(), None);
}

View File

@ -0,0 +1,541 @@
/*! A representation for highly regular overload sets common in Naga IR.
Many Naga builtin functions' overload sets have a highly regular
structure. For example, many arithmetic functions can be applied to
any floating-point type, or any vector thereof. This module defines a
handful of types for representing such simple overload sets that is
simple and efficient.
*/
use crate::common::{DiagnosticDebug, ForDebugWithTypes};
use crate::ir;
use crate::proc::overloads::constructor_set::{ConstructorSet, ConstructorSize};
use crate::proc::overloads::rule::{Conclusion, Rule};
use crate::proc::overloads::scalar_set::ScalarSet;
use crate::proc::overloads::OverloadSet;
use crate::proc::{GlobalCtx, TypeResolution};
use crate::UniqueArena;
use alloc::vec::Vec;
use core::fmt;
/// Overload sets represented as sets of scalars and constructors.
///
/// This type represents an [`OverloadSet`] using a bitset of scalar
/// types and a bitset of type constructors that might be applied to
/// those scalars. The overload set contains a rule for every possible
/// combination of scalars and constructors, essentially the cartesian
/// product of the two sets.
///
/// For example, if the arity is 2, set of scalars is { AbstractFloat,
/// `f32` }, and the set of constructors is { `vec2`, `vec3` }, then
/// that represents the set of overloads:
///
/// - (`vec2<AbstractFloat>`, `vec2<AbstractFloat>`) -> `vec2<AbstractFloat>`
/// - (`vec2<f32>`, `vec2<f32>`) -> `vec2<f32>`
/// - (`vec3<AbstractFloat>`, `vec3<AbstractFloat>`) -> `vec3<AbstractFloat>`
/// - (`vec3<f32>`, `vec3<f32>`) -> `vec3<f32>`
///
/// The `conclude` value says how to determine the return type from
/// the argument type.
///
/// Restrictions:
///
/// - All overloads must take the same number of arguments.
///
/// - For any given overload, all its arguments must have the same
/// type.
#[derive(Clone)]
pub(in crate::proc::overloads) struct Regular {
/// The number of arguments in the rules.
pub arity: usize,
/// The set of type constructors to apply.
pub constructors: ConstructorSet,
/// The set of scalars to apply them to.
pub scalars: ScalarSet,
/// How to determine a member rule's return type given the type of
/// its arguments.
pub conclude: ConclusionRule,
}
impl Regular {
pub(in crate::proc::overloads) const EMPTY: Regular = Regular {
arity: 0,
constructors: ConstructorSet::empty(),
scalars: ScalarSet::empty(),
conclude: ConclusionRule::ArgumentType,
};
/// Return an iterator over all the argument types allowed by `self`.
///
/// Return an iterator that produces, for each overload in `self`, the
/// constructor and scalar of its argument types and return type.
///
/// A [`Regular`] value can only represent overload sets where, in
/// each overload, all the arguments have the same type, and the
/// return type is always going to be a determined by the argument
/// types, so giving the constructor and scalar is sufficient to
/// characterize the entire rule.
fn members(&self) -> impl Iterator<Item = (ConstructorSize, ir::Scalar)> {
let scalars = self.scalars;
self.constructors.members().flat_map(move |constructor| {
let size = constructor.size();
// Technically, we don't need the "most general" `TypeInner` here,
// but since `ScalarSet::members` only produces singletons anyway,
// the effect is the same.
scalars
.members()
.map(move |singleton| (size, singleton.most_general_scalar()))
})
}
fn rules(&self) -> impl Iterator<Item = Rule> {
let arity = self.arity;
let conclude = self.conclude;
self.members()
.map(move |(size, scalar)| make_rule(arity, size, scalar, conclude))
}
}
impl OverloadSet for Regular {
fn is_empty(&self) -> bool {
self.constructors.is_empty() || self.scalars.is_empty()
}
fn min_arguments(&self) -> usize {
assert!(!self.is_empty());
self.arity
}
fn max_arguments(&self) -> usize {
assert!(!self.is_empty());
self.arity
}
fn arg(&self, i: usize, ty: &ir::TypeInner, types: &UniqueArena<ir::Type>) -> Self {
if i >= self.arity {
return Self::EMPTY;
}
let constructor = ConstructorSet::singleton(ty);
let scalars = match ty.scalar_for_conversions(types) {
Some(ty_scalar) => ScalarSet::convertible_from(ty_scalar),
None => ScalarSet::empty(),
};
Self {
arity: self.arity,
// Constrain all member rules' constructors to match `ty`'s.
constructors: self.constructors & constructor,
// Constrain all member rules' arguments to be something
// that `ty` can be converted to.
scalars: self.scalars & scalars,
conclude: self.conclude,
}
}
fn concrete_only(self, _types: &UniqueArena<ir::Type>) -> Self {
Self {
scalars: self.scalars & ScalarSet::CONCRETE,
..self
}
}
fn most_preferred(&self) -> Rule {
assert!(!self.is_empty());
// If there is more than one constructor allowed, then we must
// not have had any arguments supplied at all. In any case, we
// don't have any unambiguously preferred candidate.
assert!(self.constructors.is_singleton());
let size = self.constructors.size();
let scalar = self.scalars.most_general_scalar();
make_rule(self.arity, size, scalar, self.conclude)
}
fn overload_list(&self, _gctx: &GlobalCtx<'_>) -> Vec<Rule> {
self.rules().collect()
}
fn allowed_args(&self, i: usize, _gctx: &GlobalCtx<'_>) -> Vec<TypeResolution> {
if i >= self.arity {
return Vec::new();
}
self.members()
.map(|(size, scalar)| TypeResolution::Value(size.to_inner(scalar)))
.collect()
}
fn for_debug(&self, types: &UniqueArena<ir::Type>) -> impl fmt::Debug {
DiagnosticDebug((self, types))
}
}
/// Construct a [`Regular`] member [`Rule`] for the given arity and type.
///
/// [`Regular`] can only represent rules where all the argument types and the
/// return type are the same, so just knowing `arity` and `inner` is sufficient.
///
/// [`Rule`]: crate::proc::overloads::Rule
fn make_rule(
arity: usize,
size: ConstructorSize,
scalar: ir::Scalar,
conclusion_rule: ConclusionRule,
) -> Rule {
let inner = size.to_inner(scalar);
let arg = TypeResolution::Value(inner.clone());
Rule {
arguments: core::iter::repeat(arg.clone()).take(arity).collect(),
conclusion: conclusion_rule.conclude(size, scalar),
}
}
/// Conclusion-computing rules.
#[derive(Clone, Copy, Debug)]
#[repr(u8)]
pub(in crate::proc::overloads) enum ConclusionRule {
ArgumentType,
Scalar,
Frexp,
Modf,
U32,
Vec2F,
Vec4F,
Vec4I,
Vec4U,
}
impl ConclusionRule {
fn conclude(self, size: ConstructorSize, scalar: ir::Scalar) -> Conclusion {
match self {
Self::ArgumentType => Conclusion::Value(size.to_inner(scalar)),
Self::Scalar => Conclusion::Value(ir::TypeInner::Scalar(scalar)),
Self::Frexp => Conclusion::for_frexp_modf(ir::MathFunction::Frexp, size, scalar),
Self::Modf => Conclusion::for_frexp_modf(ir::MathFunction::Modf, size, scalar),
Self::U32 => Conclusion::Value(ir::TypeInner::Scalar(ir::Scalar::U32)),
Self::Vec2F => Conclusion::Value(ir::TypeInner::Vector {
size: ir::VectorSize::Bi,
scalar: ir::Scalar::F32,
}),
Self::Vec4F => Conclusion::Value(ir::TypeInner::Vector {
size: ir::VectorSize::Quad,
scalar: ir::Scalar::F32,
}),
Self::Vec4I => Conclusion::Value(ir::TypeInner::Vector {
size: ir::VectorSize::Quad,
scalar: ir::Scalar::I32,
}),
Self::Vec4U => Conclusion::Value(ir::TypeInner::Vector {
size: ir::VectorSize::Quad,
scalar: ir::Scalar::U32,
}),
}
}
}
impl fmt::Debug for DiagnosticDebug<(&Regular, &UniqueArena<ir::Type>)> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (regular, types) = self.0;
let rules: Vec<Rule> = regular.rules().collect();
f.debug_struct("List")
.field("rules", &rules.for_debug(types))
.field("conclude", &regular.conclude)
.finish()
}
}
impl ForDebugWithTypes for &Regular {}
impl fmt::Debug for DiagnosticDebug<(&[Rule], &UniqueArena<ir::Type>)> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (rules, types) = self.0;
f.debug_list()
.entries(rules.iter().map(|rule| rule.for_debug(types)))
.finish()
}
}
impl ForDebugWithTypes for &[Rule] {}
/// Construct a [`Regular`] [`OverloadSet`].
///
/// Examples:
///
/// - `regular!(2, SCALAR|VECN of FLOAT)`: An overload set whose rules take two
/// arguments of the same type: a floating-point scalar (possibly abstract) or
/// a vector of such. The return type is the same as the argument type.
///
/// - `regular!(1, VECN of FLOAT -> Scalar)`: An overload set whose rules take
/// one argument that is a vector of floats, and whose return type is the leaf
/// scalar type of the argument type.
///
/// The constructor values (before the `<` angle brackets `>`) are
/// constants from [`ConstructorSet`].
///
/// The scalar values (inside the `<` angle brackets `>`) are
/// constants from [`ScalarSet`].
///
/// When a return type identifier is given, it is treated as a variant
/// of the the [`ConclusionRule`] enum.
macro_rules! regular {
// regular!(ARITY, CONSTRUCTOR of SCALAR)
( $arity:literal , $( $constr:ident )|* of $( $scalar:ident )|*) => {
{
use $crate::proc::overloads;
use overloads::constructor_set::constructor_set;
use overloads::regular::{Regular, ConclusionRule};
use overloads::scalar_set::scalar_set;
Regular {
arity: $arity,
constructors: constructor_set!( $( $constr )|* ),
scalars: scalar_set!( $( $scalar )|* ),
conclude: ConclusionRule::ArgumentType,
}
}
};
// regular!(ARITY, CONSTRUCTOR of SCALAR -> CONCLUSION_RULE)
( $arity:literal , $( $constr:ident )|* of $( $scalar:ident )|* -> $conclude:ident) => {
{
use $crate::proc::overloads;
use overloads::constructor_set::constructor_set;
use overloads::regular::{Regular, ConclusionRule};
use overloads::scalar_set::scalar_set;
Regular {
arity: $arity,
constructors:constructor_set!( $( $constr )|* ),
scalars: scalar_set!( $( $scalar )|* ),
conclude: ConclusionRule::$conclude,
}
}
};
}
pub(in crate::proc::overloads) use regular;
#[cfg(test)]
mod test {
use super::*;
use crate::ir;
const fn scalar(scalar: ir::Scalar) -> ir::TypeInner {
ir::TypeInner::Scalar(scalar)
}
const fn vec2(scalar: ir::Scalar) -> ir::TypeInner {
ir::TypeInner::Vector {
scalar,
size: ir::VectorSize::Bi,
}
}
const fn vec3(scalar: ir::Scalar) -> ir::TypeInner {
ir::TypeInner::Vector {
scalar,
size: ir::VectorSize::Tri,
}
}
/// Assert that `set` has a most preferred candidate whose type
/// conclusion is `expected`.
#[track_caller]
fn check_return_type(set: &Regular, expected: &ir::TypeInner, arena: &UniqueArena<ir::Type>) {
assert!(!set.is_empty());
let special_types = ir::SpecialTypes::default();
let preferred = set.most_preferred();
let conclusion = preferred.conclusion;
let resolution = conclusion
.into_resolution(&special_types)
.expect("special types should have been pre-registered");
let inner = resolution.inner_with(arena);
assert!(
inner.equivalent(expected, arena),
"Expected {:?}, got {:?}",
expected.for_debug(arena),
inner.for_debug(arena),
);
}
#[test]
fn unary_vec_or_scalar_numeric_scalar() {
let arena = UniqueArena::default();
let builtin = regular!(1, SCALAR of NUMERIC);
let ok = builtin.arg(0, &scalar(ir::Scalar::U32), &arena);
check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);
let err = builtin.arg(0, &scalar(ir::Scalar::BOOL), &arena);
assert!(err.is_empty());
}
#[test]
fn unary_vec_or_scalar_numeric_vector() {
let arena = UniqueArena::default();
let builtin = regular!(1, VECN|SCALAR of NUMERIC);
let ok = builtin.arg(0, &vec3(ir::Scalar::F64), &arena);
check_return_type(&ok, &vec3(ir::Scalar::F64), &arena);
let err = builtin.arg(0, &vec3(ir::Scalar::BOOL), &arena);
assert!(err.is_empty());
}
#[test]
fn unary_vec_or_scalar_numeric_matrix() {
let arena = UniqueArena::default();
let builtin = regular!(1, VECN|SCALAR of NUMERIC);
let err = builtin.arg(
0,
&ir::TypeInner::Matrix {
columns: ir::VectorSize::Tri,
rows: ir::VectorSize::Tri,
scalar: ir::Scalar::F32,
},
&arena,
);
assert!(err.is_empty());
}
#[test]
#[rustfmt::skip]
fn binary_vec_or_scalar_numeric_scalar() {
let arena = UniqueArena::default();
let builtin = regular!(2, VECN|SCALAR of NUMERIC);
let ok = builtin
.arg(0, &scalar(ir::Scalar::F32), &arena)
.arg(1, &scalar(ir::Scalar::F32), &arena);
check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);
let ok = builtin
.arg(0, &scalar(ir::Scalar::ABSTRACT_FLOAT), &arena)
.arg(1, &scalar(ir::Scalar::F32), &arena);
check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);
let ok = builtin
.arg(0, &scalar(ir::Scalar::F32), &arena)
.arg(1, &scalar(ir::Scalar::ABSTRACT_INT), &arena);
check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);
let ok = builtin
.arg(0, &scalar(ir::Scalar::U32), &arena)
.arg(1, &scalar(ir::Scalar::U32), &arena);
check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);
let ok = builtin
.arg(0, &scalar(ir::Scalar::U32), &arena)
.arg(1, &scalar(ir::Scalar::ABSTRACT_INT), &arena);
check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);
let ok = builtin
.arg(0, &scalar(ir::Scalar::ABSTRACT_INT), &arena)
.arg(1, &scalar(ir::Scalar::U32), &arena);
check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);
// Not numeric.
let err = builtin
.arg(0, &scalar(ir::Scalar::BOOL), &arena)
.arg(1, &scalar(ir::Scalar::BOOL), &arena);
assert!(err.is_empty());
// Different floating-point types.
let err = builtin
.arg(0, &scalar(ir::Scalar::F32), &arena)
.arg(1, &scalar(ir::Scalar::F64), &arena);
assert!(err.is_empty());
// Different constructor.
let err = builtin
.arg(0, &scalar(ir::Scalar::F32), &arena)
.arg(1, &vec2(ir::Scalar::F32), &arena);
assert!(err.is_empty());
// Different vector size
let err = builtin
.arg(0, &vec2(ir::Scalar::F32), &arena)
.arg(1, &vec3(ir::Scalar::F32), &arena);
assert!(err.is_empty());
}
#[test]
#[rustfmt::skip]
fn binary_vec_or_scalar_numeric_vector() {
let arena = UniqueArena::default();
let builtin = regular!(2, VECN|SCALAR of NUMERIC);
let ok = builtin
.arg(0, &vec3(ir::Scalar::F32), &arena)
.arg(1, &vec3(ir::Scalar::F32), &arena);
check_return_type(&ok, &vec3(ir::Scalar::F32), &arena);
// Different vector sizes.
let err = builtin
.arg(0, &vec2(ir::Scalar::F32), &arena)
.arg(1, &vec3(ir::Scalar::F32), &arena);
assert!(err.is_empty());
// Different vector scalars.
let err = builtin
.arg(0, &vec3(ir::Scalar::F32), &arena)
.arg(1, &vec3(ir::Scalar::F64), &arena);
assert!(err.is_empty());
// Mix of vectors and scalars.
let err = builtin
.arg(0, &scalar(ir::Scalar::F32), &arena)
.arg(1, &vec3(ir::Scalar::F32), &arena);
assert!(err.is_empty());
}
#[test]
#[rustfmt::skip]
fn binary_vec_or_scalar_numeric_vector_abstract() {
let arena = UniqueArena::default();
let builtin = regular!(2, VECN|SCALAR of NUMERIC);
let ok = builtin
.arg(0, &vec2(ir::Scalar::ABSTRACT_INT), &arena)
.arg(1, &vec2(ir::Scalar::U32), &arena);
check_return_type(&ok, &vec2(ir::Scalar::U32), &arena);
let ok = builtin
.arg(0, &vec3(ir::Scalar::ABSTRACT_INT), &arena)
.arg(1, &vec3(ir::Scalar::F32), &arena);
check_return_type(&ok, &vec3(ir::Scalar::F32), &arena);
let ok = builtin
.arg(0, &scalar(ir::Scalar::ABSTRACT_FLOAT), &arena)
.arg(1, &scalar(ir::Scalar::F32), &arena);
check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);
let err = builtin
.arg(0, &scalar(ir::Scalar::ABSTRACT_FLOAT), &arena)
.arg(1, &scalar(ir::Scalar::U32), &arena);
assert!(err.is_empty());
let err = builtin
.arg(0, &scalar(ir::Scalar::I32), &arena)
.arg(1, &scalar(ir::Scalar::U32), &arena);
assert!(err.is_empty());
}
}

View File

@ -0,0 +1,146 @@
/*! Type rules.
An implementation of [`OverloadSet`] represents a set of type rules, each of
which has a list of types for its arguments, and a conclusion about the
type of the expression as a whole.
This module defines the [`Rule`] type, representing a type rule from an
[`OverloadSet`], and the [`Conclusion`] type, a specialized enum for
representing a type rule's conclusion.
[`OverloadSet`]: crate::proc::overloads::OverloadSet
*/
use crate::common::{DiagnosticDebug, ForDebugWithTypes};
use crate::ir;
use crate::proc::overloads::constructor_set::ConstructorSize;
use crate::proc::TypeResolution;
use crate::UniqueArena;
use alloc::vec::Vec;
use core::fmt;
use core::result::Result;
/// A single type rule.
#[derive(Clone)]
pub struct Rule {
pub arguments: Vec<TypeResolution>,
pub conclusion: Conclusion,
}
/// The result type of a [`Rule`].
///
/// A `Conclusion` value represents the return type of some operation
/// in the builtin function database.
///
/// This is very similar to [`TypeInner`], except that it represents
/// predeclared types using [`PredeclaredType`], so that overload
/// resolution can delegate registering predeclared types to its users.
///
/// [`TypeInner`]: ir::TypeInner
/// [`PredeclaredType`]: ir::PredeclaredType
#[derive(Clone, Debug)]
pub enum Conclusion {
/// A type that can be entirely characterized by a [`TypeInner`] value.
///
/// [`TypeInner`]: ir::TypeInner
Value(ir::TypeInner),
/// A type that should be registered in the module's
/// [`SpecialTypes::predeclared_types`] table.
///
/// This is used for operations like [`Frexp`] and [`Modf`].
///
/// [`SpecialTypes::predeclared_types`]: ir::SpecialTypes::predeclared_types
/// [`Frexp`]: crate::ir::MathFunction::Frexp
/// [`Modf`]: crate::ir::MathFunction::Modf
Predeclared(ir::PredeclaredType),
}
impl Conclusion {
pub fn for_frexp_modf(
function: ir::MathFunction,
size: ConstructorSize,
scalar: ir::Scalar,
) -> Self {
use ir::MathFunction as Mf;
use ir::PredeclaredType as Pt;
let size = match size {
ConstructorSize::Scalar => None,
ConstructorSize::Vector(size) => Some(size),
ConstructorSize::Matrix { .. } => {
unreachable!("FrexpModf only supports scalars and vectors");
}
};
let predeclared = match function {
Mf::Frexp => Pt::FrexpResult { size, scalar },
Mf::Modf => Pt::ModfResult { size, scalar },
_ => {
unreachable!("FrexpModf only supports Frexp and Modf");
}
};
Conclusion::Predeclared(predeclared)
}
pub fn into_resolution(
self,
special_types: &ir::SpecialTypes,
) -> Result<TypeResolution, MissingSpecialType> {
match self {
Conclusion::Value(inner) => Ok(TypeResolution::Value(inner)),
Conclusion::Predeclared(predeclared) => {
let handle = *special_types
.predeclared_types
.get(&predeclared)
.ok_or(MissingSpecialType)?;
Ok(TypeResolution::Handle(handle))
}
}
}
}
#[derive(Debug, thiserror::Error)]
#[error("Special type is not registered within the module")]
pub struct MissingSpecialType;
impl ForDebugWithTypes for &Rule {}
impl fmt::Debug for DiagnosticDebug<(&Rule, &UniqueArena<ir::Type>)> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (rule, arena) = self.0;
f.write_str("(")?;
for (i, argument) in rule.arguments.iter().enumerate() {
if i > 0 {
f.write_str(", ")?;
}
write!(f, "{:?}", argument.for_debug(arena))?;
}
write!(f, ") -> {:?}", rule.conclusion.for_debug(arena))
}
}
impl ForDebugWithTypes for &Conclusion {}
impl fmt::Debug for DiagnosticDebug<(&Conclusion, &UniqueArena<ir::Type>)> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (conclusion, ctx) = self.0;
#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
{
use crate::common::wgsl::TypeContext;
ctx.write_type_conclusion(conclusion, f)?;
}
#[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
{
let _ = ctx;
write!(f, "{conclusion:?}")?;
}
Ok(())
}
}

View File

@ -0,0 +1,141 @@
//! A set of scalar types, represented as a bitset.
use crate::ir::Scalar;
use crate::proc::overloads::one_bits_iter::OneBitsIter;
macro_rules! define_scalar_set {
{ $( $scalar:ident, )* } => {
/// An enum used to assign distinct bit numbers to [`ScalarSet`] elements.
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
#[repr(u32)]
enum ScalarSetBits {
$( $scalar, )*
Count,
}
/// A table mapping bit numbers to the [`Scalar`] values they represent.
static SCALARS_FOR_BITS: [Scalar; ScalarSetBits::Count as usize] = [
$(
Scalar::$scalar,
)*
];
bitflags::bitflags! {
/// A set of scalar types.
///
/// This represents a set of [`Scalar`] types.
///
/// The Naga IR conversion rules arrange scalar types into a
/// lattice. The scalar types' bit values are chosen such that, if
/// A is convertible to B, then A's bit value is less than B's.
#[derive(Copy, Clone, Debug)]
pub(crate) struct ScalarSet: u16 {
$(
const $scalar = 1 << (ScalarSetBits::$scalar as u32);
)*
}
}
impl ScalarSet {
/// Return the set of scalars containing only `scalar`.
#[expect(dead_code)]
pub const fn singleton(scalar: Scalar) -> Self {
match scalar {
$(
Scalar::$scalar => Self::$scalar,
)*
_ => Self::empty(),
}
}
}
}
}
define_scalar_set! {
// Scalar types must be listed here in an order such that, if A is
// convertible to B, then A appears before B.
//
// In the concrete types, the 32-bit types *must* appear before
// other sizes, since that is how we represent conversion rank.
ABSTRACT_INT, ABSTRACT_FLOAT,
I32, I64,
U32, U64,
F32, F16, F64,
BOOL,
}
impl ScalarSet {
/// Return the set of scalars to which `scalar` can be automatically
/// converted.
pub fn convertible_from(scalar: Scalar) -> Self {
use Scalar as Sc;
match scalar {
Sc::I32 => Self::I32,
Sc::I64 => Self::I64,
Sc::U32 => Self::U32,
Sc::U64 => Self::U64,
Sc::F16 => Self::F16,
Sc::F32 => Self::F32,
Sc::F64 => Self::F64,
Sc::BOOL => Self::BOOL,
Sc::ABSTRACT_INT => Self::INTEGER | Self::FLOAT,
Sc::ABSTRACT_FLOAT => Self::FLOAT,
_ => Self::empty(),
}
}
/// Return the lowest-ranked member of `self` as a [`Scalar`].
///
/// # Panics
///
/// Panics if `self` is empty.
pub fn most_general_scalar(self) -> Scalar {
// If the set is empty, this returns the full bit-length of
// `self.bits()`, an index which is out of bounds for
// `SCALARS_FOR_BITS`.
let lowest = self.bits().trailing_zeros();
*SCALARS_FOR_BITS.get(lowest as usize).unwrap()
}
/// Return an iterator over this set's members.
///
/// Members are produced as singleton, in order from most general to least.
pub fn members(self) -> impl Iterator<Item = ScalarSet> {
OneBitsIter::new(self.bits() as u64).map(|bit| Self::from_bits(bit as u16).unwrap())
}
pub const FLOAT: Self = Self::ABSTRACT_FLOAT
.union(Self::F16)
.union(Self::F32)
.union(Self::F64);
pub const INTEGER: Self = Self::ABSTRACT_INT
.union(Self::I32)
.union(Self::I64)
.union(Self::U32)
.union(Self::U64);
pub const NUMERIC: Self = Self::FLOAT.union(Self::INTEGER);
pub const ABSTRACT: Self = Self::ABSTRACT_INT.union(Self::ABSTRACT_FLOAT);
pub const CONCRETE: Self = Self::all().difference(Self::ABSTRACT);
pub const CONCRETE_INTEGER: Self = Self::INTEGER.intersection(Self::CONCRETE);
pub const CONCRETE_FLOAT: Self = Self::FLOAT.intersection(Self::CONCRETE);
/// Floating-point scalars, with the abstract floats omitted for
/// #7405.
pub const FLOAT_ABSTRACT_UNIMPLEMENTED: Self = Self::CONCRETE_FLOAT;
}
macro_rules! scalar_set {
( $( $scalar:ident )|* ) => {
{
use $crate::proc::overloads::scalar_set::ScalarSet;
ScalarSet::empty()
$(
.union(ScalarSet::$scalar)
)*
}
}
}
pub(in crate::proc::overloads) use scalar_set;

View File

@ -0,0 +1,113 @@
//! Utility functions for constructing [`List`] overload sets.
//!
//! [`List`]: crate::proc::overloads::list::List
use crate::ir;
use crate::proc::overloads::list::List;
use crate::proc::overloads::rule::{Conclusion, Rule};
use crate::proc::TypeResolution;
use alloc::vec::Vec;
/// Produce all vector sizes.
pub fn vector_sizes() -> impl Iterator<Item = ir::VectorSize> + Clone {
static SIZES: [ir::VectorSize; 3] = [
ir::VectorSize::Bi,
ir::VectorSize::Tri,
ir::VectorSize::Quad,
];
SIZES.iter().cloned()
}
/// Produce all the floating-point [`ir::Scalar`]s.
///
/// Note that `F32` must appear before other sizes; this is how we
/// represent conversion rank.
pub fn float_scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
[
ir::Scalar::ABSTRACT_FLOAT,
ir::Scalar::F32,
ir::Scalar::F16,
ir::Scalar::F64,
]
.into_iter()
}
/// Produce all the floating-point [`ir::Scalar`]s, but omit
/// abstract types, for #7405.
pub fn float_scalars_unimplemented_abstract() -> impl Iterator<Item = ir::Scalar> + Clone {
[ir::Scalar::F32, ir::Scalar::F16, ir::Scalar::F64].into_iter()
}
/// Produce all concrete integer [`ir::Scalar`]s.
///
/// Note that `I32` and `U32` must come first; this is how we
/// represent conversion rank.
pub fn concrete_int_scalars() -> impl Iterator<Item = ir::Scalar> {
[
ir::Scalar::I32,
ir::Scalar::U32,
ir::Scalar::I64,
ir::Scalar::U64,
]
.into_iter()
}
/// Produce the scalar and vector [`ir::TypeInner`]s that have `s` as
/// their scalar.
pub fn scalar_or_vecn(scalar: ir::Scalar) -> impl Iterator<Item = ir::TypeInner> {
[
ir::TypeInner::Scalar(scalar),
ir::TypeInner::Vector {
size: ir::VectorSize::Bi,
scalar,
},
ir::TypeInner::Vector {
size: ir::VectorSize::Tri,
scalar,
},
ir::TypeInner::Vector {
size: ir::VectorSize::Quad,
scalar,
},
]
.into_iter()
}
/// Construct a [`Rule`] for an operation with the given
/// argument types and return type.
pub fn rule<const N: usize>(args: [ir::TypeInner; N], ret: ir::TypeInner) -> Rule {
Rule {
arguments: Vec::from_iter(args.into_iter().map(TypeResolution::Value)),
conclusion: Conclusion::Value(ret),
}
}
/// Construct a [`List`] from the given rules.
pub fn list(rules: impl Iterator<Item = Rule>) -> List {
List::from_rules(rules.collect())
}
/// Return the cartesian product of two iterators.
pub fn pairs<T: Clone, U>(
left: impl Iterator<Item = T>,
right: impl Iterator<Item = U> + Clone,
) -> impl Iterator<Item = (T, U)> {
left.flat_map(move |t| right.clone().map(move |u| (t.clone(), u)))
}
/// Return the cartesian product of three iterators.
pub fn triples<T: Clone, U: Clone, V>(
left: impl Iterator<Item = T>,
middle: impl Iterator<Item = U> + Clone,
right: impl Iterator<Item = V> + Clone,
) -> impl Iterator<Item = (T, U, V)> {
left.flat_map(move |t| {
let right = right.clone();
middle.clone().flat_map(move |u| {
let t = t.clone();
right.clone().map(move |v| (t.clone(), u.clone(), v))
})
})
}

View File

@ -3,6 +3,7 @@ use alloc::{format, string::String};
use thiserror::Error;
use crate::arena::{Arena, Handle, UniqueArena};
use crate::common::ForDebugWithTypes;
/// The result of computing an expression's type.
///
@ -191,6 +192,14 @@ pub enum ResolveError {
FunctionArgumentNotFound(u32),
#[error("Special type is not registered within the module")]
MissingSpecialType,
#[error("Call to builtin {0} has incorrect or ambiguous arguments")]
BuiltinArgumentsInvalid(String),
}
impl From<crate::proc::MissingSpecialType> for ResolveError {
fn from(_unit_struct: crate::proc::MissingSpecialType) -> Self {
ResolveError::MissingSpecialType
}
}
pub struct ResolveContext<'a> {
@ -635,210 +644,46 @@ impl<'a> ResolveContext<'a> {
arg2: _,
arg3: _,
} => {
use crate::MathFunction as Mf;
use crate::proc::OverloadSet as _;
let mut overloads = fun.overloads();
log::debug!(
"initial overloads for {fun:?}, {:#?}",
overloads.for_debug(types)
);
// If any argument is not a constant expression, then no
// overloads that accept abstract values should be considered.
// `OverloadSet::concrete_only` is supposed to help impose this
// restriction. However, no `MathFunction` accepts a mix of
// abstract and concrete arguments, so we don't need to worry
// about that here.
let res_arg = past(arg)?;
match fun {
Mf::Abs
| Mf::Min
| Mf::Max
| Mf::Clamp
| Mf::Saturate
| Mf::Cos
| Mf::Cosh
| Mf::Sin
| Mf::Sinh
| Mf::Tan
| Mf::Tanh
| Mf::Acos
| Mf::Asin
| Mf::Atan
| Mf::Atan2
| Mf::Asinh
| Mf::Acosh
| Mf::Atanh
| Mf::Radians
| Mf::Degrees
| Mf::Ceil
| Mf::Floor
| Mf::Round
| Mf::Fract
| Mf::Trunc
| Mf::Ldexp
| Mf::Exp
| Mf::Exp2
| Mf::Log
| Mf::Log2
| Mf::Pow
| Mf::QuantizeToF16 => res_arg.clone(),
Mf::Modf | Mf::Frexp => {
let (size, scalar) = match res_arg.inner_with(types) {
&Ti::Scalar(scalar) => (None, scalar),
&Ti::Vector { scalar, size } => (Some(size), scalar),
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{fun:?}({other:?}, _)"
)))
}
};
let result = self
.special_types
.predeclared_types
.get(&if fun == Mf::Modf {
crate::PredeclaredType::ModfResult { size, scalar }
} else {
crate::PredeclaredType::FrexpResult { size, scalar }
})
.ok_or(ResolveError::MissingSpecialType)?;
TypeResolution::Handle(*result)
}
Mf::Dot => match *res_arg.inner_with(types) {
Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)),
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{fun:?}({other:?}, _)"
)))
}
},
Mf::Outer => {
let arg1 = arg1.ok_or_else(|| {
ResolveError::IncompatibleOperands(format!("{fun:?}(_, None)"))
})?;
match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) {
(
&Ti::Vector {
size: columns,
scalar,
},
&Ti::Vector { size: rows, .. },
) => TypeResolution::Value(Ti::Matrix {
columns,
rows,
scalar,
}),
(left, right) => {
return Err(ResolveError::IncompatibleOperands(format!(
"{fun:?}({left:?}, {right:?})"
)))
}
}
}
Mf::Cross => res_arg.clone(),
Mf::Distance | Mf::Length => match *res_arg.inner_with(types) {
Ti::Scalar(scalar) | Ti::Vector { scalar, size: _ } => {
TypeResolution::Value(Ti::Scalar(scalar))
}
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{fun:?}({other:?})"
)))
}
},
Mf::Normalize | Mf::FaceForward | Mf::Reflect | Mf::Refract => res_arg.clone(),
// computational
Mf::Sign
| Mf::Fma
| Mf::Mix
| Mf::Step
| Mf::SmoothStep
| Mf::Sqrt
| Mf::InverseSqrt => res_arg.clone(),
Mf::Transpose => match *res_arg.inner_with(types) {
Ti::Matrix {
columns,
rows,
scalar,
} => TypeResolution::Value(Ti::Matrix {
columns: rows,
rows: columns,
scalar,
}),
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{fun:?}({other:?})"
)))
}
},
Mf::Inverse => match *res_arg.inner_with(types) {
Ti::Matrix {
columns,
rows,
scalar,
} if columns == rows => TypeResolution::Value(Ti::Matrix {
columns,
rows,
scalar,
}),
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{fun:?}({other:?})"
)))
}
},
Mf::Determinant => match *res_arg.inner_with(types) {
Ti::Matrix { scalar, .. } => TypeResolution::Value(Ti::Scalar(scalar)),
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{fun:?}({other:?})"
)))
}
},
// bits
Mf::CountTrailingZeros
| Mf::CountLeadingZeros
| Mf::CountOneBits
| Mf::ReverseBits
| Mf::ExtractBits
| Mf::InsertBits
| Mf::FirstTrailingBit
| Mf::FirstLeadingBit => match *res_arg.inner_with(types) {
Ti::Scalar(
scalar @ crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
..
},
) => TypeResolution::Value(Ti::Scalar(scalar)),
Ti::Vector {
size,
scalar:
scalar @ crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
..
},
} => TypeResolution::Value(Ti::Vector { size, scalar }),
ref other => {
return Err(ResolveError::IncompatibleOperands(format!(
"{fun:?}({other:?})"
)))
}
},
// data packing
Mf::Pack4x8snorm
| Mf::Pack4x8unorm
| Mf::Pack2x16snorm
| Mf::Pack2x16unorm
| Mf::Pack2x16float
| Mf::Pack4xI8
| Mf::Pack4xU8 => TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)),
// data unpacking
Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector {
size: crate::VectorSize::Quad,
scalar: crate::Scalar::F32,
}),
Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack2x16float => {
TypeResolution::Value(Ti::Vector {
size: crate::VectorSize::Bi,
scalar: crate::Scalar::F32,
})
}
Mf::Unpack4xI8 => TypeResolution::Value(Ti::Vector {
size: crate::VectorSize::Quad,
scalar: crate::Scalar::I32,
}),
Mf::Unpack4xU8 => TypeResolution::Value(Ti::Vector {
size: crate::VectorSize::Quad,
scalar: crate::Scalar::U32,
}),
overloads = overloads.arg(0, res_arg.inner_with(types), types);
log::debug!(
"overloads after arg 0 of type {:?}: {:#?}",
res_arg.for_debug(types),
overloads.for_debug(types)
);
if let Some(arg1) = arg1 {
let res_arg1 = past(arg1)?;
overloads = overloads.arg(1, res_arg1.inner_with(types), types);
log::debug!(
"overloads after arg 1 of type {:?}: {:#?}",
res_arg1.for_debug(types),
overloads.for_debug(types)
);
}
if overloads.is_empty() {
return Err(ResolveError::BuiltinArgumentsInvalid(format!("{fun:?}")));
}
let rule = overloads.most_preferred();
rule.conclusion.into_resolution(self.special_types)?
}
crate::Expression::As {
expr,

View File

@ -2,6 +2,7 @@ use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, T
use crate::arena::UniqueArena;
use crate::{
arena::Handle,
proc::OverloadSet as _,
proc::{IndexableLengthError, ResolveError},
};
@ -1016,659 +1017,59 @@ impl super::Validator {
arg2,
arg3,
} => {
use crate::MathFunction as Mf;
let actuals: &[_] = match (arg1, arg2, arg3) {
(None, None, None) => &[arg],
(Some(arg1), None, None) => &[arg, arg1],
(Some(arg1), Some(arg2), None) => &[arg, arg1, arg2],
(Some(arg1), Some(arg2), Some(arg3)) => &[arg, arg1, arg2, arg3],
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
let resolve = |arg| &resolver[arg];
let arg_ty = resolve(arg);
let arg1_ty = arg1.map(resolve);
let arg2_ty = arg2.map(resolve);
let arg3_ty = arg3.map(resolve);
match fun {
Mf::Abs => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
let good = match *arg_ty {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
scalar.kind != Sk::Bool
}
_ => false,
};
if !good {
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
}
let actual_types: &[_] = match *actuals {
[arg0] => &[resolve(arg0)],
[arg0, arg1] => &[resolve(arg0), resolve(arg1)],
[arg0, arg1, arg2] => &[resolve(arg0), resolve(arg1), resolve(arg2)],
[arg0, arg1, arg2, arg3] => {
&[resolve(arg0), resolve(arg1), resolve(arg2), resolve(arg3)]
}
Mf::Min | Mf::Max => {
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), None, None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
let good = match *arg_ty {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
scalar.kind != Sk::Bool
}
_ => false,
};
if !good {
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Clamp => {
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), Some(ty2), None) => (ty1, ty2),
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
let good = match *arg_ty {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
scalar.kind != Sk::Bool
}
_ => false,
};
if !good {
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
if arg2_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
2,
arg2.unwrap(),
));
}
}
Mf::Saturate
| Mf::Cos
| Mf::Cosh
| Mf::Sin
| Mf::Sinh
| Mf::Tan
| Mf::Tanh
| Mf::Acos
| Mf::Asin
| Mf::Atan
| Mf::Asinh
| Mf::Acosh
| Mf::Atanh
| Mf::Radians
| Mf::Degrees
| Mf::Ceil
| Mf::Floor
| Mf::Round
| Mf::Fract
| Mf::Trunc
| Mf::Exp
| Mf::Exp2
| Mf::Log
| Mf::Log2
| Mf::Length
| Mf::Sqrt
| Mf::InverseSqrt => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
if scalar.kind == Sk::Float => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::Sign => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Float | Sk::Sint,
..
})
| Ti::Vector {
scalar:
Sc {
kind: Sk::Float | Sk::Sint,
..
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), None, None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
if scalar.kind == Sk::Float => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Modf | Mf::Frexp => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
if !matches!(*arg_ty,
Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
if scalar.kind == Sk::Float)
{
return Err(ExpressionError::InvalidArgumentType(fun, 1, arg));
}
}
Mf::Ldexp => {
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), None, None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
let size0 = match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Float, ..
}) => None,
Ti::Vector {
scalar:
Sc {
kind: Sk::Float, ..
},
size,
} => Some(size),
_ => {
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
}
};
let good = match *arg1_ty {
Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true,
Ti::Vector {
size,
scalar: Sc { kind: Sk::Sint, .. },
} if Some(size) == size0 => true,
_ => false,
};
if !good {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Dot => {
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), None, None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Vector {
scalar:
Sc {
kind: Sk::Float | Sk::Sint | Sk::Uint,
..
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Outer | Mf::Reflect => {
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), None, None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Vector {
scalar:
Sc {
kind: Sk::Float, ..
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Cross => {
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), None, None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Vector {
scalar:
Sc {
kind: Sk::Float, ..
},
size: crate::VectorSize::Tri,
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Refract => {
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), Some(ty2), None) => (ty1, ty2),
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
_ => unreachable!(),
};
match *arg_ty {
Ti::Vector {
scalar:
Sc {
kind: Sk::Float, ..
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
// Start with the set of all overloads available for `fun`.
let mut overloads = fun.overloads();
log::debug!(
"initial overloads for {:?}: {:#?}",
fun,
overloads.for_debug(&module.types)
);
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
// If any argument is not a constant expression, then no
// overloads that accept abstract values should be considered.
// `OverloadSet::concrete_only` is supposed to help impose this
// restriction. However, no `MathFunction` accepts a mix of
// abstract and concrete arguments, so we don't need to worry
// about that here.
match (arg_ty, arg2_ty) {
(
&Ti::Vector {
scalar:
Sc {
width: vector_width,
..
},
..
},
&Ti::Scalar(Sc {
width: scalar_width,
kind: Sk::Float,
}),
) if vector_width == scalar_width => {}
_ => {
return Err(ExpressionError::InvalidArgumentType(
fun,
2,
arg2.unwrap(),
))
}
}
}
Mf::Normalize => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Vector {
scalar:
Sc {
kind: Sk::Float, ..
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), Some(ty2), None) => (ty1, ty2),
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Float, ..
})
| Ti::Vector {
scalar:
Sc {
kind: Sk::Float, ..
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
if arg2_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
2,
arg2.unwrap(),
));
}
}
Mf::Mix => {
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), Some(ty2), None) => (ty1, ty2),
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
let arg_width = match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Float,
width,
})
| Ti::Vector {
scalar:
Sc {
kind: Sk::Float,
width,
},
..
} => width,
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
};
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
// the last argument can always be a scalar
match *arg2_ty {
Ti::Scalar(Sc {
kind: Sk::Float,
width,
}) if width == arg_width => {}
_ if arg2_ty == arg_ty => {}
_ => {
return Err(ExpressionError::InvalidArgumentType(
fun,
2,
arg2.unwrap(),
));
}
}
}
Mf::Inverse | Mf::Determinant => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
let good = match *arg_ty {
Ti::Matrix { columns, rows, .. } => columns == rows,
_ => false,
};
if !good {
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
}
}
Mf::Transpose => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Matrix { .. } => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::QuantizeToF16 => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Float,
width: 4,
})
| Ti::Vector {
scalar:
Sc {
kind: Sk::Float,
width: 4,
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
// Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
Mf::CountLeadingZeros
| Mf::CountTrailingZeros
| Mf::CountOneBits
| Mf::ReverseBits
| Mf::FirstLeadingBit
| Mf::FirstTrailingBit => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
Sk::Sint | Sk::Uint => {
if scalar.width != 4 {
return Err(ExpressionError::UnsupportedWidth(
fun,
scalar.kind,
scalar.width,
));
}
}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
},
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::InsertBits => {
let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Sint | Sk::Uint,
..
})
| Ti::Vector {
scalar:
Sc {
kind: Sk::Sint | Sk::Uint,
..
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
match *arg2_ty {
Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
_ => {
return Err(ExpressionError::InvalidArgumentType(
fun,
2,
arg2.unwrap(),
))
}
}
match *arg3_ty {
Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
_ => {
return Err(ExpressionError::InvalidArgumentType(
fun,
2,
arg3.unwrap(),
))
}
}
// Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
for &arg in [arg_ty, arg1_ty, arg2_ty, arg3_ty].iter() {
match *arg {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
if scalar.width != 4 {
return Err(ExpressionError::UnsupportedWidth(
fun,
scalar.kind,
scalar.width,
));
}
}
_ => {}
}
}
}
Mf::ExtractBits => {
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), Some(ty2), None) => (ty1, ty2),
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Sint | Sk::Uint,
..
})
| Ti::Vector {
scalar:
Sc {
kind: Sk::Sint | Sk::Uint,
..
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
match *arg1_ty {
Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
_ => {
return Err(ExpressionError::InvalidArgumentType(
fun,
2,
arg1.unwrap(),
))
}
}
match *arg2_ty {
Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
_ => {
return Err(ExpressionError::InvalidArgumentType(
fun,
2,
arg2.unwrap(),
))
}
}
// Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
for &arg in [arg_ty, arg1_ty, arg2_ty].iter() {
match *arg {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
if scalar.width != 4 {
return Err(ExpressionError::UnsupportedWidth(
fun,
scalar.kind,
scalar.width,
));
}
}
_ => {}
}
}
}
Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Vector {
size: crate::VectorSize::Bi,
scalar:
Sc {
kind: Sk::Float, ..
},
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Vector {
size: crate::VectorSize::Quad,
scalar:
Sc {
kind: Sk::Float, ..
},
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
mf @ (Mf::Pack4xI8 | Mf::Pack4xU8) => {
let scalar_kind = match mf {
Mf::Pack4xI8 => Sk::Sint,
Mf::Pack4xU8 => Sk::Uint,
_ => unreachable!(),
};
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Vector {
size: crate::VectorSize::Quad,
scalar: Sc { kind, .. },
} if kind == scalar_kind => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::Unpack2x16float
| Mf::Unpack2x16snorm
| Mf::Unpack2x16unorm
| Mf::Unpack4x8snorm
| Mf::Unpack4x8unorm
| Mf::Unpack4xI8
| Mf::Unpack4xU8 => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() {
// Remove overloads that cannot accept an `i`'th
// argument arguments of type `ty`.
overloads = overloads.arg(i, ty, &module.types);
log::debug!(
"overloads after arg {i}: {:#?}",
overloads.for_debug(&module.types)
);
if overloads.is_empty() {
log::debug!("all overloads eliminated");
return Err(ExpressionError::InvalidArgumentType(fun, i as u32, expr));
}
}
if actuals.len() < overloads.min_arguments() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
ShaderStages::all()
}
E::As {

View File

@ -0,0 +1 @@
targets = "SPIRV | METAL | GLSL | WGSL"

View File

@ -0,0 +1,102 @@
fn f() {
// For calls that return abstract types, we can't write the type we
// actually expect, but we can at least require an automatic
// conversion.
//
// Error cases are covered in `wgsl_errors::more_inconsistent_type`.
// start
var clamp_aiaiai: i32 = clamp(1, 1, 1);
var clamp_aiaiaf: f32 = clamp(1, 1, 1.0);
var clamp_aiaii: i32 = clamp(1, 1, 1i);
var clamp_aiaif: f32 = clamp(1, 1, 1f);
var clamp_aiafai: f32 = clamp(1, 1.0, 1);
var clamp_aiafaf: f32 = clamp(1, 1.0, 1.0);
//var clamp_aiafi: f32 = clamp(1, 1.0, 1i); error
var clamp_aiaff: f32 = clamp(1, 1.0, 1f);
var clamp_aiiai: i32 = clamp(1, 1i, 1);
//var clamp_aiiaf: f32 = clamp(1, 1i, 1.0); error
var clamp_aiii: i32 = clamp(1, 1i, 1i);
//var clamp_aiif: f32 = clamp(1, 1i, 1f); error
var clamp_aifai: f32 = clamp(1, 1f, 1);
var clamp_aifaf: f32 = clamp(1, 1f, 1.0);
//var clamp_aifi: f32 = clamp(1, 1f, 1i); error
var clamp_aiff: f32 = clamp(1, 1f, 1f);
var clamp_afaiai: f32 = clamp(1.0, 1, 1);
var clamp_afaiaf: f32 = clamp(1.0, 1, 1.0);
//var clamp_afaii: f32 = clamp(1.0, 1, 1i); error
var clamp_afaif: f32 = clamp(1.0, 1, 1f);
var clamp_afafai: f32 = clamp(1.0, 1.0, 1);
var clamp_afafaf: f32 = clamp(1.0, 1.0, 1.0);
//var clamp_afafi: f32 = clamp(1.0, 1.0, 1i); error
var clamp_afaff: f32 = clamp(1.0, 1.0, 1f);
//var clamp_afiai: f32 = clamp(1.0, 1i, 1); error
//var clamp_afiaf: f32 = clamp(1.0, 1i, 1.0); error
//var clamp_afii: f32 = clamp(1.0, 1i, 1i); error
//var clamp_afif: f32 = clamp(1.0, 1i, 1f); error
var clamp_affai: f32 = clamp(1.0, 1f, 1);
var clamp_affaf: f32 = clamp(1.0, 1f, 1.0);
//var clamp_affi: f32 = clamp(1.0, 1f, 1i); error
var clamp_afff: f32 = clamp(1.0, 1f, 1f);
var clamp_iaiai: i32 = clamp(1i, 1, 1);
//var clamp_iaiaf: f32 = clamp(1i, 1, 1.0); error
var clamp_iaii: i32 = clamp(1i, 1, 1i);
//var clamp_iaif: f32 = clamp(1i, 1, 1f); error
//var clamp_iafai: f32 = clamp(1i, 1.0, 1); error
//var clamp_iafaf: f32 = clamp(1i, 1.0, 1.0); error
//var clamp_iafi: f32 = clamp(1i, 1.0, 1i); error
//var clamp_iaff: f32 = clamp(1i, 1.0, 1f); error
var clamp_iiai: i32 = clamp(1i, 1i, 1);
//var clamp_iiaf: f32 = clamp(1i, 1i, 1.0); error
var clamp_iii: i32 = clamp(1i, 1i, 1i);
//var clamp_iif: f32 = clamp(1i, 1i, 1f); error
//var clamp_ifai: f32 = clamp(1i, 1f, 1); error
//var clamp_ifaf: f32 = clamp(1i, 1f, 1.0); error
//var clamp_ifi: f32 = clamp(1i, 1f, 1i); error
//var clamp_iff: f32 = clamp(1i, 1f, 1f); error
var clamp_faiai: f32 = clamp(1f, 1, 1);
var clamp_faiaf: f32 = clamp(1f, 1, 1.0);
//var clamp_faii: f32 = clamp(1f, 1, 1i); error
var clamp_faif: f32 = clamp(1f, 1, 1f);
var clamp_fafai: f32 = clamp(1f, 1.0, 1);
var clamp_fafaf: f32 = clamp(1f, 1.0, 1.0);
//var clamp_fafi: f32 = clamp(1f, 1.0, 1i); error
var clamp_faff: f32 = clamp(1f, 1.0, 1f);
//var clamp_fiai: f32 = clamp(1f, 1i, 1); error
//var clamp_fiaf: f32 = clamp(1f, 1i, 1.0); error
//var clamp_fii: f32 = clamp(1f, 1i, 1i); error
//var clamp_fif: f32 = clamp(1f, 1i, 1f); error
var clamp_ffai: f32 = clamp(1f, 1f, 1);
var clamp_ffaf: f32 = clamp(1f, 1f, 1.0);
//var clamp_ffi: f32 = clamp(1f, 1f, 1i); error
var clamp_fff: f32 = clamp(1f, 1f, 1f);
// end
var min_aiai: i32 = min(1, 1);
var min_aiaf: f32 = min(1, 1.0);
var min_aii: i32 = min(1, 1i);
var min_aif: f32 = min(1, 1f);
var min_afai: f32 = min(1.0, 1);
var min_afaf: f32 = min(1.0, 1.0);
//var min_afi: f32 = min(1.0, 1i); error
var min_aff: f32 = min(1.0, 1f);
var min_iai: i32 = min(1i, 1);
//var min_iaf: f32 = min(1i, 1.0); error
var min_ii: i32 = min(1i, 1i);
//var min_if: f32 = min(1i, 1f); error
var min_fai: f32 = min(1f, 1);
var min_faf: f32 = min(1f, 1.0);
//var min_fi: f32 = min(1f, 1i); error
var min_ff: f32 = min(1f, 1f);
var pow_aiai = pow(1, 1);
var pow_aiaf = pow(1, 1.0);
var pow_aif = pow(1, 1f);
var pow_afai = pow(1.0, 1);
var pow_afaf = pow(1.0, 1.0);
var pow_aff = pow(1.0, 1f);
var pow_fai = pow(1f, 1);
var pow_faf = pow(1f, 1.0);
var pow_ff = pow(1f, 1f);
}

View File

@ -266,13 +266,15 @@ fn builtin_cross_product_args() {
use naga::{MathFunction, Module, Type, TypeInner, VectorSize};
// We want to ensure that the *only* problem with the code is the
// arity of the vectors passed to `cross`. So validate two
// versions of the module varying only in that aspect.
// arity of the call, or the size of the vectors passed to
// `cross`. So validate different versions of the module varying
// only in those aspects.
//
// Looking at uses of the `wg_load` makes it easy to identify the
// differences between the two variants.
// Looking at uses of `size` and `arity` makes it easy to identify
// the differences between the variants.
fn variant(
size: VectorSize,
arity: usize,
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
let span = naga::Span::default();
let mut module = Module::default();
@ -311,9 +313,9 @@ fn builtin_cross_product_args() {
Expression::Math {
fun: MathFunction::Cross,
arg: ex_zero,
arg1: Some(ex_zero),
arg2: None,
arg3: None,
arg1: (arity >= 2).then_some(ex_zero),
arg2: (arity >= 3).then_some(ex_zero),
arg3: (arity >= 4).then_some(ex_zero),
},
span,
);
@ -338,9 +340,14 @@ fn builtin_cross_product_args() {
.validate(&module)
}
assert!(variant(VectorSize::Bi).is_err());
variant(VectorSize::Tri).expect("module should validate");
assert!(variant(VectorSize::Quad).is_err());
assert!(variant(VectorSize::Bi, 2).is_err());
assert!(variant(VectorSize::Tri, 1).is_err());
variant(VectorSize::Tri, 2).expect("module should validate");
assert!(variant(VectorSize::Tri, 3).is_err());
assert!(variant(VectorSize::Tri, 4).is_err());
assert!(variant(VectorSize::Quad, 2).is_err());
}
#[cfg(feature = "wgsl-in")]
@ -758,3 +765,61 @@ fn bad_texture_dimensions_level() {
assert!(validate("1i").is_ok());
assert!(validate("1").is_ok());
}
#[test]
fn arity_check() {
use ir::MathFunction as Mf;
use naga::Span;
use naga::{ir, valid};
let _ = env_logger::builder().is_test(true).try_init();
type Result = core::result::Result<naga::valid::ModuleInfo, naga::valid::ValidationError>;
fn validate(fun: ir::MathFunction, args: &[usize]) -> Result {
let nowhere = Span::default();
let mut module = ir::Module::default();
let ty_f32 = module.types.insert(
ir::Type {
name: Some("f32".to_string()),
inner: ir::TypeInner::Scalar(ir::Scalar::F32),
},
nowhere,
);
let mut f = ir::Function {
result: Some(ir::FunctionResult {
ty: ty_f32,
binding: None,
}),
..ir::Function::default()
};
let ex_zero = f
.expressions
.append(ir::Expression::ZeroValue(ty_f32), nowhere);
let ex_pow = f.expressions.append(
dbg!(ir::Expression::Math {
fun,
arg: ex_zero,
arg1: args.contains(&1).then_some(ex_zero),
arg2: args.contains(&2).then_some(ex_zero),
arg3: args.contains(&3).then_some(ex_zero),
}),
nowhere,
);
f.body = ir::Block::from_vec(vec![
ir::Statement::Emit(naga::Range::new_from_bounds(ex_pow, ex_pow)),
ir::Statement::Return {
value: Some(ex_pow),
},
]);
module.functions.append(f, nowhere);
valid::Validator::new(Default::default(), valid::Capabilities::all())
.validate(&module)
.map_err(|err| err.into_inner()) // discard spans
}
assert!(validate(Mf::Sin, &[]).is_ok());
assert!(validate(Mf::Sin, &[1]).is_err());
assert!(validate(Mf::Sin, &[3]).is_err());
assert!(validate(Mf::Pow, &[1]).is_ok());
assert!(validate(Mf::Pow, &[3]).is_err());
}

View File

@ -179,6 +179,56 @@ fn bad_type_cast() {
);
}
#[test]
fn cross_vec2() {
check(
r#"
fn x() -> f32 {
return cross(vec2(0., 1.), vec2(0., 1.));
}
"#,
"\
error: wrong type passed as argument #1 to `cross`
wgsl:3:24
3 return cross(vec2(0., 1.), vec2(0., 1.));
^^^^^ ^^^^^^^^^^^^ argument #1 has type `vec2<{AbstractFloat}>`
= note: `cross` accepts the following types for argument #1:
= note: allowed type: vec3<{AbstractFloat}>
= note: allowed type: vec3<f32>
= note: allowed type: vec3<f16>
= note: allowed type: vec3<f64>
",
);
}
#[test]
fn cross_vec4() {
check(
r#"
fn x() -> f32 {
return cross(vec4(0., 1., 2., 3.), vec4(0., 1., 2., 3.));
}
"#,
"\
error: wrong type passed as argument #1 to `cross`
wgsl:3:24
3 return cross(vec4(0., 1., 2., 3.), vec4(0., 1., 2., 3.));
^^^^^ ^^^^^^^^^^^^^^^^^^^^ argument #1 has type `vec4<{AbstractFloat}>`
= note: `cross` accepts the following types for argument #1:
= note: allowed type: vec3<{AbstractFloat}>
= note: allowed type: vec3<f32>
= note: allowed type: vec3<f16>
= note: allowed type: vec3<f64>
",
);
}
#[test]
fn type_not_constructible() {
check(
@ -3136,3 +3186,141 @@ fn issue7165() {
// rendering an error if the module contained spans.
let _location = err.location("");
}
#[test]
fn wrong_argument_count() {
check(
"fn foo() -> f32 {
return sin();
}",
r#"error: wrong number of arguments: expected 1, found 0
wgsl:2:20
2 return sin();
^^^ wrong number of arguments
"#,
);
}
#[test]
fn too_many_arguments() {
check(
"fn foo() -> f32 {
return sin(1.0, 2.0);
}",
r#"error: too many arguments passed to `sin`
wgsl:2:20
2 return sin(1.0, 2.0);
^^^ ^^^ unexpected argument #2
= note: The `sin` function accepts at most 1 argument(s)
"#,
);
}
#[test]
fn too_many_arguments_2() {
check(
"fn foo() -> f32 {
return distance(vec2<f32>(), 0i);
}",
r#"error: wrong type passed as argument #2 to `distance`
wgsl:2:20
2 return distance(vec2<f32>(), 0i);
^^^^^^^^ ^^ argument #2 has type `i32`
= note: `distance` accepts the following types for argument #2:
= note: allowed type: vec2<f32>
= note: allowed type: vec2<f16>
= note: allowed type: vec2<f64>
= note: allowed type: vec3<f32>
= note: allowed type: vec3<f16>
= note: allowed type: vec3<f64>
= note: allowed type: vec4<f32>
= note: allowed type: vec4<f16>
= note: allowed type: vec4<f64>
"#,
);
}
#[test]
fn inconsistent_type() {
check(
"fn foo() -> f32 {
return dot(vec4<f32>(), vec3<f32>());
}",
r#"error: inconsistent type passed as argument #2 to `dot`
wgsl:2:20
2 return dot(vec4<f32>(), vec3<f32>());
^^^ ^^^^^^^^^^ ^^^^^^^^^^ argument #2 has type vec3<f32>
this argument has type vec4<f32>, which constrains subsequent arguments
= note: Because argument #1 has type vec4<f32>, only the following types
= note: (or types that automatically convert to them) are accepted for argument #2:
= note: allowed type: vec4<f32>
"#,
);
}
#[test]
fn more_inconsistent_type() {
#[track_caller]
fn variant(call: &str) {
let input = format!(
r#"
fn f() {{ var x = {call}; }}
"#
);
let result = naga::front::wgsl::parse_str(&input);
let Err(ref err) = result else {
panic!("expected ParseError, got {result:#?}");
};
if !err.message().contains("inconsistent type") {
panic!("expected 'inconsistent type' error, got {result:#?}");
}
}
variant("min(1.0, 1i)");
variant("min(1i, 1.0)");
variant("min(1i, 1f)");
variant("min(1f, 1i)");
variant("clamp(1, 1.0, 1i)");
variant("clamp(1, 1i, 1.0)");
variant("clamp(1, 1i, 1f)");
variant("clamp(1, 1f, 1i)");
variant("clamp(1.0, 1, 1i)");
variant("clamp(1.0, 1.0, 1i)");
variant("clamp(1.0, 1i, 1)");
variant("clamp(1.0, 1i, 1.0)");
variant("clamp(1.0, 1i, 1i)");
variant("clamp(1.0, 1i, 1f)");
variant("clamp(1.0, 1f, 1i)");
variant("clamp(1i, 1, 1.0)");
variant("clamp(1i, 1, 1f)");
variant("clamp(1i, 1.0, 1)");
variant("clamp(1i, 1.0, 1.0)");
variant("clamp(1i, 1.0, 1i)");
variant("clamp(1i, 1.0, 1f)");
variant("clamp(1i, 1i, 1.0)");
variant("clamp(1i, 1i, 1f)");
variant("clamp(1i, 1f, 1)");
variant("clamp(1i, 1f, 1.0)");
variant("clamp(1i, 1f, 1i)");
variant("clamp(1i, 1f, 1f)");
variant("clamp(1f, 1, 1i)");
variant("clamp(1f, 1.0, 1i)");
variant("clamp(1f, 1i, 1)");
variant("clamp(1f, 1i, 1.0)");
variant("clamp(1f, 1i, 1i)");
variant("clamp(1f, 1i, 1f)");
variant("clamp(1f, 1f, 1i)");
}

View File

@ -868,7 +868,13 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
ty: Value(Vector(
size: Tri,
scalar: (
kind: Float,
width: 4,
),
)),
),
(
uniformity: (
@ -1231,7 +1237,13 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
ty: Value(Vector(
size: Tri,
scalar: (
kind: Float,
width: 4,
),
)),
),
(
uniformity: (
@ -1252,7 +1264,10 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(0),
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
@ -1261,7 +1276,10 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(0),
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (

View File

@ -0,0 +1,66 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
void f(
) {
int clamp_aiaiai = 1;
float clamp_aiaiaf = 1.0;
int clamp_aiaii = 1;
float clamp_aiaif = 1.0;
float clamp_aiafai = 1.0;
float clamp_aiafaf = 1.0;
float clamp_aiaff = 1.0;
int clamp_aiiai = 1;
int clamp_aiii = 1;
float clamp_aifai = 1.0;
float clamp_aifaf = 1.0;
float clamp_aiff = 1.0;
float clamp_afaiai = 1.0;
float clamp_afaiaf = 1.0;
float clamp_afaif = 1.0;
float clamp_afafai = 1.0;
float clamp_afafaf = 1.0;
float clamp_afaff = 1.0;
float clamp_affai = 1.0;
float clamp_affaf = 1.0;
float clamp_afff = 1.0;
int clamp_iaiai = 1;
int clamp_iaii = 1;
int clamp_iiai = 1;
int clamp_iii = 1;
float clamp_faiai = 1.0;
float clamp_faiaf = 1.0;
float clamp_faif = 1.0;
float clamp_fafai = 1.0;
float clamp_fafaf = 1.0;
float clamp_faff = 1.0;
float clamp_ffai = 1.0;
float clamp_ffaf = 1.0;
float clamp_fff = 1.0;
int min_aiai = 1;
float min_aiaf = 1.0;
int min_aii = 1;
float min_aif = 1.0;
float min_afai = 1.0;
float min_afaf = 1.0;
float min_aff = 1.0;
int min_iai = 1;
int min_ii = 1;
float min_fai = 1.0;
float min_faf = 1.0;
float min_ff = 1.0;
float pow_aiai = 1.0;
float pow_aiaf = 1.0;
float pow_aif = 1.0;
float pow_afai = 1.0;
float pow_afaf = 1.0;
float pow_aff = 1.0;
float pow_fai = 1.0;
float pow_faf = 1.0;
float pow_ff = 1.0;
return;
}

View File

@ -0,0 +1,77 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 68
OpCapability Shader
OpCapability Linkage
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
%2 = OpTypeVoid
%3 = OpTypeInt 32 1
%4 = OpTypeFloat 32
%7 = OpTypeFunction %2
%8 = OpConstant %3 1
%9 = OpConstant %4 1.0
%11 = OpTypePointer Function %3
%13 = OpTypePointer Function %4
%6 = OpFunction %2 None %7
%5 = OpLabel
%66 = OpVariable %13 Function %9
%63 = OpVariable %13 Function %9
%60 = OpVariable %13 Function %9
%57 = OpVariable %13 Function %9
%54 = OpVariable %11 Function %8
%51 = OpVariable %13 Function %9
%48 = OpVariable %11 Function %8
%45 = OpVariable %13 Function %9
%42 = OpVariable %13 Function %9
%39 = OpVariable %13 Function %9
%36 = OpVariable %11 Function %8
%33 = OpVariable %11 Function %8
%30 = OpVariable %13 Function %9
%27 = OpVariable %13 Function %9
%24 = OpVariable %13 Function %9
%21 = OpVariable %13 Function %9
%18 = OpVariable %13 Function %9
%15 = OpVariable %13 Function %9
%10 = OpVariable %11 Function %8
%64 = OpVariable %13 Function %9
%61 = OpVariable %13 Function %9
%58 = OpVariable %13 Function %9
%55 = OpVariable %13 Function %9
%52 = OpVariable %13 Function %9
%49 = OpVariable %13 Function %9
%46 = OpVariable %11 Function %8
%43 = OpVariable %13 Function %9
%40 = OpVariable %13 Function %9
%37 = OpVariable %13 Function %9
%34 = OpVariable %11 Function %8
%31 = OpVariable %13 Function %9
%28 = OpVariable %13 Function %9
%25 = OpVariable %13 Function %9
%22 = OpVariable %13 Function %9
%19 = OpVariable %11 Function %8
%16 = OpVariable %13 Function %9
%12 = OpVariable %13 Function %9
%65 = OpVariable %13 Function %9
%62 = OpVariable %13 Function %9
%59 = OpVariable %13 Function %9
%56 = OpVariable %13 Function %9
%53 = OpVariable %11 Function %8
%50 = OpVariable %13 Function %9
%47 = OpVariable %13 Function %9
%44 = OpVariable %13 Function %9
%41 = OpVariable %13 Function %9
%38 = OpVariable %13 Function %9
%35 = OpVariable %11 Function %8
%32 = OpVariable %13 Function %9
%29 = OpVariable %13 Function %9
%26 = OpVariable %13 Function %9
%23 = OpVariable %13 Function %9
%20 = OpVariable %11 Function %8
%17 = OpVariable %13 Function %9
%14 = OpVariable %11 Function %8
OpBranch %67
%67 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -492,8 +492,9 @@ OpLine %3 93 25
%200 = OpVectorShuffle %11 %199 %199 0 1 2
%201 = OpFSub %11 %198 %200
%202 = OpExtInst %11 %1 Normalize %201
OpLine %3 94 23
OpLine %3 94 32
%203 = OpDot %4 %155 %202
OpLine %3 94 23
%204 = OpExtInst %4 %1 FMax %48 %203
OpLine %3 96 9
%205 = OpFMul %4 %196 %204
@ -599,8 +600,9 @@ OpLine %3 112 25
%276 = OpVectorShuffle %11 %275 %275 0 1 2
%277 = OpFSub %11 %274 %276
%278 = OpExtInst %11 %1 Normalize %277
OpLine %3 113 23
OpLine %3 113 32
%279 = OpDot %4 %239 %278
OpLine %3 113 23
%280 = OpExtInst %4 %1 FMax %48 %279
OpLine %3 114 9
%281 = OpFMul %4 %272 %280

View File

@ -0,0 +1,60 @@
fn f() {
var clamp_aiaiai: i32 = 1i;
var clamp_aiaiaf: f32 = 1f;
var clamp_aiaii: i32 = 1i;
var clamp_aiaif: f32 = 1f;
var clamp_aiafai: f32 = 1f;
var clamp_aiafaf: f32 = 1f;
var clamp_aiaff: f32 = 1f;
var clamp_aiiai: i32 = 1i;
var clamp_aiii: i32 = 1i;
var clamp_aifai: f32 = 1f;
var clamp_aifaf: f32 = 1f;
var clamp_aiff: f32 = 1f;
var clamp_afaiai: f32 = 1f;
var clamp_afaiaf: f32 = 1f;
var clamp_afaif: f32 = 1f;
var clamp_afafai: f32 = 1f;
var clamp_afafaf: f32 = 1f;
var clamp_afaff: f32 = 1f;
var clamp_affai: f32 = 1f;
var clamp_affaf: f32 = 1f;
var clamp_afff: f32 = 1f;
var clamp_iaiai: i32 = 1i;
var clamp_iaii: i32 = 1i;
var clamp_iiai: i32 = 1i;
var clamp_iii: i32 = 1i;
var clamp_faiai: f32 = 1f;
var clamp_faiaf: f32 = 1f;
var clamp_faif: f32 = 1f;
var clamp_fafai: f32 = 1f;
var clamp_fafaf: f32 = 1f;
var clamp_faff: f32 = 1f;
var clamp_ffai: f32 = 1f;
var clamp_ffaf: f32 = 1f;
var clamp_fff: f32 = 1f;
var min_aiai: i32 = 1i;
var min_aiaf: f32 = 1f;
var min_aii: i32 = 1i;
var min_aif: f32 = 1f;
var min_afai: f32 = 1f;
var min_afaf: f32 = 1f;
var min_aff: f32 = 1f;
var min_iai: i32 = 1i;
var min_ii: i32 = 1i;
var min_fai: f32 = 1f;
var min_faf: f32 = 1f;
var min_ff: f32 = 1f;
var pow_aiai: f32 = 1f;
var pow_aiaf: f32 = 1f;
var pow_aif: f32 = 1f;
var pow_afai: f32 = 1f;
var pow_afaf: f32 = 1f;
var pow_aff: f32 = 1f;
var pow_fai: f32 = 1f;
var pow_faf: f32 = 1f;
var pow_ff: f32 = 1f;
return;
}

View File

@ -6,6 +6,8 @@ extend-exclude = [
# spirv-asm isn't real source code
'*.spvasm',
'docs/big-picture.xml',
# This test has weird pattern-derived variable names.
'naga/tests/in/wgsl/abstract-types-builtins.wgsl',
]
# Corrections take the form of a key/value pair. The key is the incorrect word