mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-04-13 20:46:34 +00:00
[naga] Support MathFunction
overloads correctly.
Define a new trait, `proc::builtins::OverloadSet`, for types that represent a Naga IR builtin function's set of overloads. The `OverloadSet` trait includes operations needed to validate calls, choose automatic type conversions, and generate diagnostics. Add a new function, `ir::MathFunction::overloads`, which returns the given `MathFunction`'s set of overloads as an `impl OverloadSet` value. Use this in the WGSL front end, the validator, and the typifier. To support `MathFunction::overloads`, provide two implementations of `OverloadSet`: - `List` is flexible but verbose. - `Regular` is concise but more restrictive. Some snapshot output is affected because `TypeResolution::Handle` values turn into `TypeResolution::Value`, since the function database constructs the return type directly. To work around #7405, avoid offering abstract-typed overloads of some functions. This addresses #6443 for `MathFunction`, although that issue covers other categories of operations as well.
This commit is contained in:
parent
8292949478
commit
7da699608d
@ -1,6 +1,6 @@
|
||||
//! Displaying Naga IR terms in diagnostic output.
|
||||
|
||||
use crate::proc::GlobalCtx;
|
||||
use crate::proc::{GlobalCtx, Rule};
|
||||
use crate::{Handle, Scalar, Type, TypeInner};
|
||||
|
||||
#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
|
||||
@ -83,6 +83,23 @@ impl fmt::Display for DiagnosticDisplay<(&TypeInner, GlobalCtx<'_>)> {
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for DiagnosticDisplay<(&str, &Rule, GlobalCtx<'_>)> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let (name, rule, ref ctx) = self.0;
|
||||
|
||||
#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
|
||||
ctx.write_type_rule(name, rule, f)?;
|
||||
|
||||
#[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
|
||||
{
|
||||
let _ = ctx;
|
||||
write!(f, "{name}({:?}) -> {:?}", rule.arguments, rule.conclusion)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for DiagnosticDisplay<Scalar> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let scalar = self.0;
|
||||
|
@ -133,6 +133,37 @@ pub trait TypeContext {
|
||||
}
|
||||
}
|
||||
|
||||
fn write_type_conclusion<W: Write>(
|
||||
&self,
|
||||
conclusion: &crate::proc::Conclusion,
|
||||
out: &mut W,
|
||||
) -> core::fmt::Result {
|
||||
use crate::proc::Conclusion as Co;
|
||||
|
||||
match *conclusion {
|
||||
Co::Value(ref inner) => self.write_type_inner(inner, out),
|
||||
Co::Predeclared(ref predeclared) => out.write_str(&predeclared.struct_name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_type_rule<W: Write>(
|
||||
&self,
|
||||
name: &str,
|
||||
rule: &crate::proc::Rule,
|
||||
out: &mut W,
|
||||
) -> core::fmt::Result {
|
||||
write!(out, "fn {name}(")?;
|
||||
for (i, arg) in rule.arguments.iter().enumerate() {
|
||||
if i > 0 {
|
||||
out.write_str(", ")?;
|
||||
}
|
||||
self.write_type_resolution(arg, out)?
|
||||
}
|
||||
out.write_str(") -> ")?;
|
||||
self.write_type_conclusion(&rule.conclusion, out)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn type_to_string(&self, handle: Handle<crate::Type>) -> String {
|
||||
let mut buf = String::new();
|
||||
self.write_type(handle, &mut buf).unwrap();
|
||||
@ -150,6 +181,12 @@ pub trait TypeContext {
|
||||
self.write_type_resolution(resolution, &mut buf).unwrap();
|
||||
buf
|
||||
}
|
||||
|
||||
fn type_rule_to_string(&self, name: &str, rule: &crate::proc::Rule) -> String {
|
||||
let mut buf = String::new();
|
||||
self.write_type_rule(name, rule, &mut buf).unwrap();
|
||||
buf
|
||||
}
|
||||
}
|
||||
|
||||
fn try_write_type_inner<C, W>(ctx: &C, inner: &TypeInner, out: &mut W) -> Result<(), WriteTypeError>
|
||||
|
@ -271,6 +271,75 @@ pub(crate) enum Error<'a> {
|
||||
expected: Range<u32>,
|
||||
found: u32,
|
||||
},
|
||||
/// No overload of this function accepts this many arguments.
|
||||
TooManyArguments {
|
||||
/// The name of the function being called.
|
||||
function: String,
|
||||
|
||||
/// The function name in the call expression.
|
||||
call_span: Span,
|
||||
|
||||
/// The first argument that is unacceptable.
|
||||
arg_span: Span,
|
||||
|
||||
/// Maximum number of arguments accepted by any overload of
|
||||
/// this function.
|
||||
max_arguments: u32,
|
||||
},
|
||||
/// A value passed to a builtin function has a type that is not
|
||||
/// accepted by any overload of the function.
|
||||
WrongArgumentType {
|
||||
/// The name of the function being called.
|
||||
function: String,
|
||||
|
||||
/// The function name in the call expression.
|
||||
call_span: Span,
|
||||
|
||||
/// The first argument whose type is unacceptable.
|
||||
arg_span: Span,
|
||||
|
||||
/// The index of the first argument whose type is unacceptable.
|
||||
arg_index: u32,
|
||||
|
||||
/// That argument's actual type.
|
||||
arg_ty: String,
|
||||
|
||||
/// The set of argument types that would have been accepted for
|
||||
/// this argument, given the prior arguments.
|
||||
allowed: Vec<String>,
|
||||
},
|
||||
/// A value passed to a builtin function has a type that is not
|
||||
/// accepted, given the earlier arguments' types.
|
||||
InconsistentArgumentType {
|
||||
/// The name of the function being called.
|
||||
function: String,
|
||||
|
||||
/// The function name in the call expression.
|
||||
call_span: Span,
|
||||
|
||||
/// The first unacceptable argument.
|
||||
arg_span: Span,
|
||||
|
||||
/// The index of the first unacceptable argument.
|
||||
arg_index: u32,
|
||||
|
||||
/// The actual type of the first unacceptable argument.
|
||||
arg_ty: String,
|
||||
|
||||
/// The prior argument whose type made the `arg_span` argument
|
||||
/// unacceptable.
|
||||
inconsistent_span: Span,
|
||||
|
||||
/// The index of the `inconsistent_span` argument.
|
||||
inconsistent_index: u32,
|
||||
|
||||
/// The type of the `inconsistent_span` argument.
|
||||
inconsistent_ty: String,
|
||||
|
||||
/// The types that would have been accepted instead of the
|
||||
/// first unacceptable argument.
|
||||
allowed: Vec<String>,
|
||||
},
|
||||
FunctionReturnsVoid(Span),
|
||||
FunctionMustUseUnused(Span),
|
||||
FunctionMustUseReturnsVoid(Span, Span),
|
||||
@ -402,7 +471,8 @@ impl<'a> Error<'a> {
|
||||
"workgroup size separator (`,`) or a closing parenthesis".to_string()
|
||||
}
|
||||
ExpectedToken::GlobalItem => concat!(
|
||||
"global item (`struct`, `const`, `var`, `alias`, `fn`, `diagnostic`, `enable`, `requires`, `;`) ",
|
||||
"global item (`struct`, `const`, `var`, `alias`, ",
|
||||
"`fn`, `diagnostic`, `enable`, `requires`, `;`) ",
|
||||
"or the end of the file"
|
||||
)
|
||||
.to_string(),
|
||||
@ -831,6 +901,74 @@ impl<'a> Error<'a> {
|
||||
labels: vec![(span, "wrong number of arguments".into())],
|
||||
notes: vec![],
|
||||
},
|
||||
Error::TooManyArguments {
|
||||
ref function,
|
||||
call_span,
|
||||
arg_span,
|
||||
max_arguments,
|
||||
} => ParseError {
|
||||
message: format!("too many arguments passed to `{function}`"),
|
||||
labels: vec![
|
||||
(call_span, "".into()),
|
||||
(arg_span, format!("unexpected argument #{}", max_arguments + 1).into())
|
||||
],
|
||||
notes: vec![
|
||||
format!("The `{function}` function accepts at most {max_arguments} argument(s)")
|
||||
],
|
||||
},
|
||||
Error::WrongArgumentType {
|
||||
ref function,
|
||||
call_span,
|
||||
arg_span,
|
||||
arg_index,
|
||||
ref arg_ty,
|
||||
ref allowed,
|
||||
} => {
|
||||
let message = format!(
|
||||
"wrong type passed as argument #{} to `{function}`",
|
||||
arg_index + 1,
|
||||
);
|
||||
let labels = vec![
|
||||
(call_span, "".into()),
|
||||
(arg_span, format!("argument #{} has type `{arg_ty}`", arg_index + 1).into())
|
||||
];
|
||||
|
||||
let mut notes = vec![];
|
||||
notes.push(format!("`{function}` accepts the following types for argument #{}:", arg_index + 1));
|
||||
notes.extend(allowed.iter().map(|ty| format!("allowed type: {ty}")));
|
||||
|
||||
ParseError { message, labels, notes }
|
||||
},
|
||||
Error::InconsistentArgumentType {
|
||||
ref function,
|
||||
call_span,
|
||||
arg_span,
|
||||
arg_index,
|
||||
ref arg_ty,
|
||||
inconsistent_span,
|
||||
inconsistent_index,
|
||||
ref inconsistent_ty,
|
||||
ref allowed
|
||||
} => {
|
||||
let message = format!(
|
||||
"inconsistent type passed as argument #{} to `{function}`",
|
||||
arg_index + 1,
|
||||
);
|
||||
let labels = vec![
|
||||
(call_span, "".into()),
|
||||
(arg_span, format!("argument #{} has type {arg_ty}", arg_index + 1).into()),
|
||||
(inconsistent_span, format!(
|
||||
"this argument has type {inconsistent_ty}, which constrains subsequent arguments"
|
||||
).into()),
|
||||
];
|
||||
let mut notes = vec![
|
||||
format!("Because argument #{} has type {inconsistent_ty}, only the following types", inconsistent_index + 1),
|
||||
format!("(or types that automatically convert to them) are accepted for argument #{}:", arg_index + 1),
|
||||
];
|
||||
notes.extend(allowed.iter().map(|ty| format!("allowed type: {ty}")));
|
||||
|
||||
ParseError { message, labels, notes }
|
||||
}
|
||||
Error::FunctionReturnsVoid(span) => ParseError {
|
||||
message: "function does not return any value".to_string(),
|
||||
labels: vec![(span, "".into())],
|
||||
|
@ -6,7 +6,7 @@ use alloc::{
|
||||
};
|
||||
use core::num::NonZeroU32;
|
||||
|
||||
use crate::common::wgsl::TypeContext;
|
||||
use crate::common::wgsl::{TryToWgsl, TypeContext};
|
||||
use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType};
|
||||
use crate::front::wgsl::index::Index;
|
||||
use crate::front::wgsl::parse::number::Number;
|
||||
@ -491,6 +491,19 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a wrapper around `value` suitable for formatting.
|
||||
///
|
||||
/// Return a wrapper around `value` that implements
|
||||
/// [`core::fmt::Display`] in a form suitable for use in
|
||||
/// diagnostic messages.
|
||||
fn as_diagnostic_display<T>(
|
||||
&self,
|
||||
value: T,
|
||||
) -> crate::common::DiagnosticDisplay<(T, crate::proc::GlobalCtx)> {
|
||||
let ctx = self.module.to_ctx();
|
||||
crate::common::DiagnosticDisplay((value, ctx))
|
||||
}
|
||||
|
||||
fn append_expression(
|
||||
&mut self,
|
||||
expr: crate::Expression,
|
||||
@ -2439,52 +2452,204 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
|
||||
crate::Expression::Derivative { axis, ctrl, expr }
|
||||
} else if let Some(fun) = conv::map_standard_fun(function.name) {
|
||||
let expected = fun.argument_count() as _;
|
||||
let mut args = ctx.prepare_args(arguments, expected, span);
|
||||
use crate::proc::OverloadSet as _;
|
||||
|
||||
let arg = self.expression(args.next()?, ctx)?;
|
||||
let arg1 = args
|
||||
.next()
|
||||
.map(|x| self.expression(x, ctx))
|
||||
.ok()
|
||||
.transpose()?;
|
||||
let arg2 = args
|
||||
.next()
|
||||
.map(|x| self.expression(x, ctx))
|
||||
.ok()
|
||||
.transpose()?;
|
||||
let arg3 = args
|
||||
.next()
|
||||
.map(|x| self.expression(x, ctx))
|
||||
.ok()
|
||||
.transpose()?;
|
||||
let fun_overloads = fun.overloads();
|
||||
let mut remaining_overloads = fun_overloads.clone();
|
||||
let min_arguments = remaining_overloads.min_arguments();
|
||||
let max_arguments = remaining_overloads.max_arguments();
|
||||
if arguments.len() < min_arguments {
|
||||
return Err(Box::new(Error::WrongArgumentCount {
|
||||
span,
|
||||
expected: min_arguments as u32..max_arguments as u32,
|
||||
found: arguments.len() as u32,
|
||||
}));
|
||||
}
|
||||
if arguments.len() > max_arguments {
|
||||
return Err(Box::new(Error::TooManyArguments {
|
||||
function: fun.to_wgsl_for_diagnostics(),
|
||||
call_span: span,
|
||||
arg_span: ctx.ast_expressions.get_span(arguments[max_arguments]),
|
||||
max_arguments: max_arguments as _,
|
||||
}));
|
||||
}
|
||||
|
||||
args.finish()?;
|
||||
log::debug!(
|
||||
"Initial overloads: {:#?}",
|
||||
remaining_overloads.for_debug(&ctx.module.types)
|
||||
);
|
||||
let mut unconverted_arguments = Vec::with_capacity(arguments.len());
|
||||
for (arg_index, &arg) in arguments.iter().enumerate() {
|
||||
let lowered = self.expression_for_abstract(arg, ctx)?;
|
||||
let ty = resolve_inner!(ctx, lowered);
|
||||
log::debug!(
|
||||
"Supplying argument {arg_index} of type {}",
|
||||
crate::common::DiagnosticDisplay((ty, ctx.module.to_ctx()))
|
||||
);
|
||||
let next_remaining_overloads =
|
||||
remaining_overloads.arg(arg_index, ty, &ctx.module.types);
|
||||
|
||||
if fun == crate::MathFunction::Modf || fun == crate::MathFunction::Frexp {
|
||||
if let Some((size, scalar)) = match *resolve_inner!(ctx, arg) {
|
||||
crate::TypeInner::Scalar(scalar) => Some((None, scalar)),
|
||||
crate::TypeInner::Vector { size, scalar, .. } => {
|
||||
Some((Some(size), scalar))
|
||||
// If any argument is not a constant expression, then no overloads
|
||||
// that accept abstract values should be considered.
|
||||
// (`OverloadSet::concrete_only` is supposed to help impose this
|
||||
// restriction.) However, no `MathFunction` accepts a mix of
|
||||
// abstract and concrete arguments, so we don't need to worry
|
||||
// about that here.
|
||||
|
||||
log::debug!(
|
||||
"Remaining overloads: {:#?}",
|
||||
next_remaining_overloads.for_debug(&ctx.module.types)
|
||||
);
|
||||
|
||||
// If the set of remaining overloads is empty, then this argument's type
|
||||
// was unacceptable. Diagnose the problem and produce an error message.
|
||||
if next_remaining_overloads.is_empty() {
|
||||
let function = fun.to_wgsl_for_diagnostics();
|
||||
let call_span = span;
|
||||
let arg_span = ctx.ast_expressions.get_span(arg);
|
||||
let arg_ty = ctx.as_diagnostic_display(ty).to_string();
|
||||
|
||||
// Is this type *ever* permitted for the arg_index'th argument?
|
||||
// For example, `bool` is never permitted for `max`.
|
||||
let only_this_argument =
|
||||
fun_overloads.arg(arg_index, ty, &ctx.module.types);
|
||||
if only_this_argument.is_empty() {
|
||||
// No overload of `fun` accepts this type as the
|
||||
// arg_index'th argument. Determine the set of types that
|
||||
// would ever be allowed there.
|
||||
let allowed: Vec<String> = fun_overloads
|
||||
.allowed_args(arg_index, &ctx.module.to_ctx())
|
||||
.iter()
|
||||
.map(|ty| ctx.type_resolution_to_string(ty))
|
||||
.collect();
|
||||
|
||||
if allowed.is_empty() {
|
||||
// No overload of `fun` accepts any argument at this
|
||||
// index, so it's a simple case of excess arguments.
|
||||
// However, since each `MathFunction`'s overloads all
|
||||
// have the same arity, we should have detected this
|
||||
// earlier.
|
||||
unreachable!("expected all overloads to have the same arity");
|
||||
}
|
||||
|
||||
// Some overloads of `fun` do accept this many arguments,
|
||||
// but none accept one of this type.
|
||||
return Err(Box::new(Error::WrongArgumentType {
|
||||
function,
|
||||
call_span,
|
||||
arg_span,
|
||||
arg_index: arg_index as u32,
|
||||
arg_ty,
|
||||
allowed,
|
||||
}));
|
||||
}
|
||||
_ => None,
|
||||
} {
|
||||
ctx.module.generate_predeclared_type(
|
||||
if fun == crate::MathFunction::Modf {
|
||||
crate::PredeclaredType::ModfResult { size, scalar }
|
||||
} else {
|
||||
crate::PredeclaredType::FrexpResult { size, scalar }
|
||||
},
|
||||
);
|
||||
|
||||
// This argument's type is accepted by some overloads---just
|
||||
// not those overloads that remain, given the prior arguments.
|
||||
// For example, `max` accepts `f32` as its second argument -
|
||||
// but not if the first was `i32`.
|
||||
|
||||
// Build a list of the types that would have been accepted here,
|
||||
// given the prior arguments.
|
||||
let allowed: Vec<String> = remaining_overloads
|
||||
.allowed_args(arg_index, &ctx.module.to_ctx())
|
||||
.iter()
|
||||
.map(|ty| ctx.type_resolution_to_string(ty))
|
||||
.collect();
|
||||
|
||||
// Re-run the argument list to determine which prior argument
|
||||
// made this one unacceptable.
|
||||
let mut remaining_overloads = fun_overloads;
|
||||
for (prior_index, &prior_expr) in
|
||||
unconverted_arguments.iter().enumerate()
|
||||
{
|
||||
let prior_ty =
|
||||
ctx.typifier()[prior_expr].inner_with(&ctx.module.types);
|
||||
remaining_overloads = remaining_overloads.arg(
|
||||
prior_index,
|
||||
prior_ty,
|
||||
&ctx.module.types,
|
||||
);
|
||||
if remaining_overloads
|
||||
.arg(arg_index, ty, &ctx.module.types)
|
||||
.is_empty()
|
||||
{
|
||||
// This is the argument that killed our dreams.
|
||||
let inconsistent_span =
|
||||
ctx.ast_expressions.get_span(arguments[prior_index]);
|
||||
let inconsistent_ty =
|
||||
ctx.as_diagnostic_display(prior_ty).to_string();
|
||||
|
||||
if allowed.is_empty() {
|
||||
// Some overloads did accept `ty` at `arg_index`, but
|
||||
// given the arguments up through `prior_expr`, we see
|
||||
// no types acceptable at `arg_index`. This means that some
|
||||
// overloads expect fewer arguments than others. However,
|
||||
// each `MathFunction`'s overloads have the same arity, so this
|
||||
// should be impossible.
|
||||
unreachable!(
|
||||
"expected all overloads to have the same arity"
|
||||
);
|
||||
}
|
||||
|
||||
// Report `arg`'s type as inconsistent with `prior_expr`'s
|
||||
return Err(Box::new(Error::InconsistentArgumentType {
|
||||
function,
|
||||
call_span,
|
||||
arg_span,
|
||||
arg_index: arg_index as u32,
|
||||
arg_ty,
|
||||
inconsistent_span,
|
||||
inconsistent_index: prior_index as u32,
|
||||
inconsistent_ty,
|
||||
allowed,
|
||||
}));
|
||||
}
|
||||
}
|
||||
unreachable!("Failed to eliminate argument type when re-tried");
|
||||
}
|
||||
remaining_overloads = next_remaining_overloads;
|
||||
unconverted_arguments.push(lowered);
|
||||
}
|
||||
|
||||
// Select the most preferred type rule for this call,
|
||||
// given the argument types supplied above.
|
||||
let rule = remaining_overloads.most_preferred();
|
||||
|
||||
let mut converted_arguments = Vec::with_capacity(arguments.len());
|
||||
for (i, (&ast, unconverted)) in
|
||||
arguments.iter().zip(unconverted_arguments).enumerate()
|
||||
{
|
||||
let goal_inner = rule.arguments[i].inner_with(&ctx.module.types);
|
||||
let converted = match goal_inner.scalar_for_conversions(&ctx.module.types) {
|
||||
Some(goal_scalar) => {
|
||||
let arg_span = ctx.ast_expressions.get_span(ast);
|
||||
ctx.try_automatic_conversion_for_leaf_scalar(
|
||||
unconverted,
|
||||
goal_scalar,
|
||||
arg_span,
|
||||
)?
|
||||
}
|
||||
// No conversion is necessary.
|
||||
None => unconverted,
|
||||
};
|
||||
|
||||
converted_arguments.push(converted);
|
||||
}
|
||||
|
||||
// If this function returns a predeclared type, register it
|
||||
// in `Module::special_types`. The typifier will expect to
|
||||
// be able to find it there.
|
||||
if let crate::proc::Conclusion::Predeclared(predeclared) = rule.conclusion {
|
||||
ctx.module.generate_predeclared_type(predeclared);
|
||||
}
|
||||
|
||||
crate::Expression::Math {
|
||||
fun,
|
||||
arg,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
arg: converted_arguments[0],
|
||||
arg1: converted_arguments.get(1).cloned(),
|
||||
arg2: converted_arguments.get(2).cloned(),
|
||||
arg3: converted_arguments.get(3).cloned(),
|
||||
}
|
||||
} else if let Some(fun) = Texture::map(function.name) {
|
||||
self.texture_sample_helper(fun, arguments, span, ctx)?
|
||||
|
@ -7,6 +7,7 @@ mod emitter;
|
||||
pub mod index;
|
||||
mod layouter;
|
||||
mod namer;
|
||||
mod overloads;
|
||||
mod terminator;
|
||||
mod type_methods;
|
||||
mod typifier;
|
||||
@ -18,6 +19,7 @@ pub use emitter::Emitter;
|
||||
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
|
||||
pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
|
||||
pub use namer::{EntryPointIndex, NameKey, Namer};
|
||||
pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
|
||||
pub use terminator::ensure_block_returns;
|
||||
use thiserror::Error;
|
||||
pub use type_methods::min_max_float_representable_by;
|
||||
|
120
naga/src/proc/overloads/any_overload_set.rs
Normal file
120
naga/src/proc/overloads/any_overload_set.rs
Normal file
@ -0,0 +1,120 @@
|
||||
//! Dynamically dispatched [`OverloadSet`]s.
|
||||
|
||||
use crate::common::DiagnosticDebug;
|
||||
use crate::ir;
|
||||
use crate::proc::overloads::{list, regular, OverloadSet, Rule};
|
||||
use crate::proc::{GlobalCtx, TypeResolution};
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt;
|
||||
|
||||
macro_rules! define_any_overload_set {
|
||||
{ $( $module:ident :: $name:ident, )* } => {
|
||||
/// An [`OverloadSet`] that dynamically dispatches to concrete implementations.
|
||||
#[derive(Clone)]
|
||||
pub(in crate::proc::overloads) enum AnyOverloadSet {
|
||||
$(
|
||||
$name ( $module :: $name ),
|
||||
)*
|
||||
}
|
||||
|
||||
$(
|
||||
impl From<$module::$name> for AnyOverloadSet {
|
||||
fn from(concrete: $module::$name) -> Self {
|
||||
AnyOverloadSet::$name(concrete)
|
||||
}
|
||||
}
|
||||
)*
|
||||
|
||||
impl OverloadSet for AnyOverloadSet {
|
||||
fn is_empty(&self) -> bool {
|
||||
match *self {
|
||||
$(
|
||||
AnyOverloadSet::$name(ref x) => x.is_empty(),
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
fn min_arguments(&self) -> usize {
|
||||
match *self {
|
||||
$(
|
||||
AnyOverloadSet::$name(ref x) => x.min_arguments(),
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
fn max_arguments(&self) -> usize {
|
||||
match *self {
|
||||
$(
|
||||
AnyOverloadSet::$name(ref x) => x.max_arguments(),
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
fn arg(
|
||||
&self,
|
||||
i: usize,
|
||||
ty: &ir::TypeInner,
|
||||
types: &crate::UniqueArena<ir::Type>,
|
||||
) -> Self {
|
||||
match *self {
|
||||
$(
|
||||
AnyOverloadSet::$name(ref x) => AnyOverloadSet::$name(x.arg(i, ty, types)),
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
fn concrete_only(self, types: &crate::UniqueArena<ir::Type>) -> Self {
|
||||
match self {
|
||||
$(
|
||||
AnyOverloadSet::$name(x) => AnyOverloadSet::$name(x.concrete_only(types)),
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
fn most_preferred(&self) -> Rule {
|
||||
match *self {
|
||||
$(
|
||||
AnyOverloadSet::$name(ref x) => x.most_preferred(),
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
fn overload_list(&self, gctx: &GlobalCtx<'_>) -> Vec<Rule> {
|
||||
match *self {
|
||||
$(
|
||||
AnyOverloadSet::$name(ref x) => x.overload_list(gctx),
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
fn allowed_args(&self, i: usize, gctx: &GlobalCtx<'_>) -> Vec<TypeResolution> {
|
||||
match *self {
|
||||
$(
|
||||
AnyOverloadSet::$name(ref x) => x.allowed_args(i, gctx),
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
fn for_debug(&self, types: &crate::UniqueArena<ir::Type>) -> impl fmt::Debug {
|
||||
DiagnosticDebug((self, types))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for DiagnosticDebug<(&AnyOverloadSet, &crate::UniqueArena<ir::Type>)> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let (set, types) = self.0;
|
||||
match *set {
|
||||
$(
|
||||
AnyOverloadSet::$name(ref x) => DiagnosticDebug((x, types)).fmt(f),
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
define_any_overload_set! {
|
||||
list::List,
|
||||
regular::Regular,
|
||||
}
|
174
naga/src/proc/overloads/constructor_set.rs
Normal file
174
naga/src/proc/overloads/constructor_set.rs
Normal file
@ -0,0 +1,174 @@
|
||||
//! A set of type constructors, represented as a bitset.
|
||||
|
||||
use crate::ir;
|
||||
use crate::proc::overloads::one_bits_iter::OneBitsIter;
|
||||
|
||||
bitflags::bitflags! {
|
||||
/// A set of type constructors.
|
||||
#[derive(Copy, Clone, Debug, PartialEq)]
|
||||
pub(crate) struct ConstructorSet: u16 {
|
||||
const SCALAR = 1 << 0;
|
||||
const VEC2 = 1 << 1;
|
||||
const VEC3 = 1 << 2;
|
||||
const VEC4 = 1 << 3;
|
||||
const MAT2X2 = 1 << 4;
|
||||
const MAT2X3 = 1 << 5;
|
||||
const MAT2X4 = 1 << 6;
|
||||
const MAT3X2 = 1 << 7;
|
||||
const MAT3X3 = 1 << 8;
|
||||
const MAT3X4 = 1 << 9;
|
||||
const MAT4X2 = 1 << 10;
|
||||
const MAT4X3 = 1 << 11;
|
||||
const MAT4X4 = 1 << 12;
|
||||
|
||||
const VECN = Self::VEC2.bits()
|
||||
| Self::VEC3.bits()
|
||||
| Self::VEC4.bits();
|
||||
}
|
||||
}
|
||||
|
||||
impl ConstructorSet {
|
||||
/// Return the single-member set containing `inner`'s constructor.
|
||||
pub const fn singleton(inner: &ir::TypeInner) -> ConstructorSet {
|
||||
use ir::TypeInner as Ti;
|
||||
use ir::VectorSize as Vs;
|
||||
match *inner {
|
||||
Ti::Scalar(_) => Self::SCALAR,
|
||||
Ti::Vector { size, scalar: _ } => match size {
|
||||
Vs::Bi => Self::VEC2,
|
||||
Vs::Tri => Self::VEC3,
|
||||
Vs::Quad => Self::VEC4,
|
||||
},
|
||||
Ti::Matrix {
|
||||
columns,
|
||||
rows,
|
||||
scalar: _,
|
||||
} => match (columns, rows) {
|
||||
(Vs::Bi, Vs::Bi) => Self::MAT2X2,
|
||||
(Vs::Bi, Vs::Tri) => Self::MAT2X3,
|
||||
(Vs::Bi, Vs::Quad) => Self::MAT2X4,
|
||||
(Vs::Tri, Vs::Bi) => Self::MAT3X2,
|
||||
(Vs::Tri, Vs::Tri) => Self::MAT3X3,
|
||||
(Vs::Tri, Vs::Quad) => Self::MAT3X4,
|
||||
(Vs::Quad, Vs::Bi) => Self::MAT4X2,
|
||||
(Vs::Quad, Vs::Tri) => Self::MAT4X3,
|
||||
(Vs::Quad, Vs::Quad) => Self::MAT4X4,
|
||||
},
|
||||
_ => Self::empty(),
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn is_singleton(self) -> bool {
|
||||
self.bits().is_power_of_two()
|
||||
}
|
||||
|
||||
/// Return an iterator over this set's members.
|
||||
///
|
||||
/// Members are produced as singleton, in order from most general to least.
|
||||
pub fn members(self) -> impl Iterator<Item = ConstructorSet> {
|
||||
OneBitsIter::new(self.bits() as u64).map(|bit| Self::from_bits(bit as u16).unwrap())
|
||||
}
|
||||
|
||||
/// Return the size of the sole element of `self`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panic if `self` is not a singleton.
|
||||
pub fn size(self) -> ConstructorSize {
|
||||
use ir::VectorSize as Vs;
|
||||
use ConstructorSize as Cs;
|
||||
|
||||
match self {
|
||||
ConstructorSet::SCALAR => Cs::Scalar,
|
||||
ConstructorSet::VEC2 => Cs::Vector(Vs::Bi),
|
||||
ConstructorSet::VEC3 => Cs::Vector(Vs::Tri),
|
||||
ConstructorSet::VEC4 => Cs::Vector(Vs::Quad),
|
||||
ConstructorSet::MAT2X2 => Cs::Matrix {
|
||||
columns: Vs::Bi,
|
||||
rows: Vs::Bi,
|
||||
},
|
||||
ConstructorSet::MAT2X3 => Cs::Matrix {
|
||||
columns: Vs::Bi,
|
||||
rows: Vs::Tri,
|
||||
},
|
||||
ConstructorSet::MAT2X4 => Cs::Matrix {
|
||||
columns: Vs::Bi,
|
||||
rows: Vs::Quad,
|
||||
},
|
||||
ConstructorSet::MAT3X2 => Cs::Matrix {
|
||||
columns: Vs::Tri,
|
||||
rows: Vs::Bi,
|
||||
},
|
||||
ConstructorSet::MAT3X3 => Cs::Matrix {
|
||||
columns: Vs::Tri,
|
||||
rows: Vs::Tri,
|
||||
},
|
||||
ConstructorSet::MAT3X4 => Cs::Matrix {
|
||||
columns: Vs::Tri,
|
||||
rows: Vs::Quad,
|
||||
},
|
||||
ConstructorSet::MAT4X2 => Cs::Matrix {
|
||||
columns: Vs::Quad,
|
||||
rows: Vs::Bi,
|
||||
},
|
||||
ConstructorSet::MAT4X3 => Cs::Matrix {
|
||||
columns: Vs::Quad,
|
||||
rows: Vs::Tri,
|
||||
},
|
||||
ConstructorSet::MAT4X4 => Cs::Matrix {
|
||||
columns: Vs::Quad,
|
||||
rows: Vs::Quad,
|
||||
},
|
||||
_ => unreachable!("ConstructorSet was not a singleton"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The sizes a member of [`ConstructorSet`] might have.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ConstructorSize {
|
||||
/// The constructor is [`SCALAR`].
|
||||
///
|
||||
/// [`SCALAR`]: ConstructorSet::SCALAR
|
||||
Scalar,
|
||||
|
||||
/// The constructor is `VECN` for some `N`.
|
||||
Vector(ir::VectorSize),
|
||||
|
||||
/// The constructor is `MATCXR` for some `C` and `R`.
|
||||
Matrix {
|
||||
columns: ir::VectorSize,
|
||||
rows: ir::VectorSize,
|
||||
},
|
||||
}
|
||||
|
||||
impl ConstructorSize {
|
||||
/// Construct a [`TypeInner`] for a type with this size and the given `scalar`.
|
||||
///
|
||||
/// [`TypeInner`]: ir::TypeInner
|
||||
pub const fn to_inner(self, scalar: ir::Scalar) -> ir::TypeInner {
|
||||
match self {
|
||||
Self::Scalar => ir::TypeInner::Scalar(scalar),
|
||||
Self::Vector(size) => ir::TypeInner::Vector { size, scalar },
|
||||
Self::Matrix { columns, rows } => ir::TypeInner::Matrix {
|
||||
columns,
|
||||
rows,
|
||||
scalar,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! constructor_set {
|
||||
( $( $constr:ident )|* ) => {
|
||||
{
|
||||
use $crate::proc::overloads::constructor_set::ConstructorSet;
|
||||
ConstructorSet::empty()
|
||||
$(
|
||||
.union(ConstructorSet::$constr)
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(in crate::proc::overloads) use constructor_set;
|
182
naga/src/proc/overloads/list.rs
Normal file
182
naga/src/proc/overloads/list.rs
Normal file
@ -0,0 +1,182 @@
|
||||
//! An [`OverloadSet`] represented as a vector of rules.
|
||||
//!
|
||||
//! [`OverloadSet`]: crate::proc::overloads::OverloadSet
|
||||
|
||||
use crate::common::{DiagnosticDebug, ForDebug, ForDebugWithTypes};
|
||||
use crate::ir;
|
||||
use crate::proc::overloads::one_bits_iter::OneBitsIter;
|
||||
use crate::proc::overloads::Rule;
|
||||
use crate::proc::{GlobalCtx, TypeResolution};
|
||||
|
||||
use alloc::rc::Rc;
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt;
|
||||
|
||||
/// A simple list of overloads.
|
||||
///
|
||||
/// Note that this type is not quite as general as it looks, in that
|
||||
/// the implementation of `most_preferred` doesn't work for arbitrary
|
||||
/// lists of overloads. See the documentation for [`List::rules`] for
|
||||
/// details.
|
||||
#[derive(Clone)]
|
||||
pub(in crate::proc::overloads) struct List {
|
||||
/// A bitmask of which elements of `rules` are included in the set.
|
||||
members: u64,
|
||||
|
||||
/// A list of type rules that are members of the set.
|
||||
///
|
||||
/// These must be listed in order such that every rule in the list
|
||||
/// is always more preferred than all subsequent rules in the
|
||||
/// list. If there is no such arrangement of rules, then you
|
||||
/// cannot use `List` to represent the overload set.
|
||||
rules: Rc<Vec<Rule>>,
|
||||
}
|
||||
|
||||
impl List {
|
||||
pub(in crate::proc::overloads) fn from_rules(rules: Vec<Rule>) -> List {
|
||||
List {
|
||||
members: len_to_full_mask(rules.len()),
|
||||
rules: Rc::new(rules),
|
||||
}
|
||||
}
|
||||
|
||||
fn members(&self) -> impl Iterator<Item = (u64, &Rule)> {
|
||||
OneBitsIter::new(self.members).map(|mask| {
|
||||
let index = mask.trailing_zeros() as usize;
|
||||
(mask, &self.rules[index])
|
||||
})
|
||||
}
|
||||
|
||||
fn filter<F>(&self, mut pred: F) -> List
|
||||
where
|
||||
F: FnMut(&Rule) -> bool,
|
||||
{
|
||||
let mut filtered_members = 0;
|
||||
for (mask, rule) in self.members() {
|
||||
if pred(rule) {
|
||||
filtered_members |= mask;
|
||||
}
|
||||
}
|
||||
|
||||
List {
|
||||
members: filtered_members,
|
||||
rules: self.rules.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::proc::overloads::OverloadSet for List {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.members == 0
|
||||
}
|
||||
|
||||
fn min_arguments(&self) -> usize {
|
||||
self.members()
|
||||
.fold(None, |best, (_, rule)| {
|
||||
// This is different from `max_arguments` because
|
||||
// `<Option as PartialOrd>` doesn't work the way we'd like.
|
||||
let len = rule.arguments.len();
|
||||
Some(match best {
|
||||
Some(best) => core::cmp::max(best, len),
|
||||
None => len,
|
||||
})
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn max_arguments(&self) -> usize {
|
||||
self.members()
|
||||
.fold(None, |n, (_, rule)| {
|
||||
core::cmp::max(n, Some(rule.arguments.len()))
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn arg(&self, i: usize, arg_ty: &ir::TypeInner, types: &crate::UniqueArena<ir::Type>) -> Self {
|
||||
log::debug!("arg {i} of type {:?}", arg_ty.for_debug(types));
|
||||
self.filter(|rule| {
|
||||
if log::log_enabled!(log::Level::Debug) {
|
||||
log::debug!(" considering rule {:?}", rule.for_debug(types));
|
||||
match rule.arguments.get(i) {
|
||||
Some(rule_ty) => {
|
||||
let rule_ty = rule_ty.inner_with(types);
|
||||
if arg_ty.equivalent(rule_ty, types) {
|
||||
log::debug!(" types are equivalent");
|
||||
} else {
|
||||
match arg_ty.automatically_converts_to(rule_ty, types) {
|
||||
Some((from, to)) => {
|
||||
log::debug!(
|
||||
" converts automatically from {:?} to {:?}",
|
||||
from.for_debug(),
|
||||
to.for_debug()
|
||||
);
|
||||
}
|
||||
None => {
|
||||
log::debug!(" no conversion possible");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
log::debug!(" argument index {i} out of bounds");
|
||||
}
|
||||
}
|
||||
}
|
||||
rule.arguments.get(i).is_some_and(|rule_ty| {
|
||||
let rule_ty = rule_ty.inner_with(types);
|
||||
arg_ty.equivalent(rule_ty, types)
|
||||
|| arg_ty.automatically_converts_to(rule_ty, types).is_some()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn concrete_only(self, types: &crate::UniqueArena<ir::Type>) -> Self {
|
||||
self.filter(|rule| {
|
||||
rule.arguments
|
||||
.iter()
|
||||
.all(|arg_ty| !arg_ty.inner_with(types).is_abstract(types))
|
||||
})
|
||||
}
|
||||
|
||||
fn most_preferred(&self) -> Rule {
|
||||
// As documented for `Self::rules`, whatever rule is first is
|
||||
// the most preferred. `OverloadSet` documents this method to
|
||||
// panic if the set is empty.
|
||||
let (_, rule) = self.members().next().unwrap();
|
||||
rule.clone()
|
||||
}
|
||||
|
||||
fn overload_list(&self, _gctx: &GlobalCtx<'_>) -> Vec<Rule> {
|
||||
self.members().map(|(_, rule)| rule.clone()).collect()
|
||||
}
|
||||
|
||||
fn allowed_args(&self, i: usize, _gctx: &GlobalCtx<'_>) -> Vec<TypeResolution> {
|
||||
self.members()
|
||||
.map(|(_, rule)| rule.arguments[i].clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn for_debug(&self, types: &crate::UniqueArena<ir::Type>) -> impl fmt::Debug {
|
||||
DiagnosticDebug((self, types))
|
||||
}
|
||||
}
|
||||
|
||||
const fn len_to_full_mask(n: usize) -> u64 {
|
||||
if n >= 64 {
|
||||
panic!("List::rules can only hold up to 63 rules");
|
||||
}
|
||||
|
||||
(1_u64 << n) - 1
|
||||
}
|
||||
|
||||
impl ForDebugWithTypes for &List {}
|
||||
|
||||
impl fmt::Debug for DiagnosticDebug<(&List, &crate::UniqueArena<ir::Type>)> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let (list, types) = self.0;
|
||||
|
||||
f.debug_list()
|
||||
.entries(list.members().map(|(_mask, rule)| rule.for_debug(types)))
|
||||
.finish()
|
||||
}
|
||||
}
|
221
naga/src/proc/overloads/mathfunction.rs
Normal file
221
naga/src/proc/overloads/mathfunction.rs
Normal file
@ -0,0 +1,221 @@
|
||||
//! Overload sets for [`ir::MathFunction`].
|
||||
|
||||
use crate::proc::overloads::any_overload_set::AnyOverloadSet;
|
||||
use crate::proc::overloads::list::List;
|
||||
use crate::proc::overloads::regular::regular;
|
||||
use crate::proc::overloads::utils::{
|
||||
concrete_int_scalars, float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule,
|
||||
scalar_or_vecn, triples, vector_sizes,
|
||||
};
|
||||
use crate::proc::overloads::OverloadSet;
|
||||
|
||||
use crate::ir;
|
||||
|
||||
impl ir::MathFunction {
|
||||
pub fn overloads(self) -> impl OverloadSet {
|
||||
use ir::MathFunction as Mf;
|
||||
|
||||
let set: AnyOverloadSet = match self {
|
||||
// Component-wise unary numeric operations
|
||||
Mf::Abs | Mf::Sign => regular!(1, SCALAR|VECN of NUMERIC).into(),
|
||||
|
||||
// Component-wise binary numeric operations
|
||||
Mf::Min | Mf::Max => regular!(2, SCALAR|VECN of NUMERIC).into(),
|
||||
|
||||
// Component-wise ternary numeric operations
|
||||
Mf::Clamp => regular!(3, SCALAR|VECN of NUMERIC).into(),
|
||||
|
||||
// Component-wise unary floating-point operations
|
||||
Mf::Sin
|
||||
| Mf::Cos
|
||||
| Mf::Tan
|
||||
| Mf::Asin
|
||||
| Mf::Acos
|
||||
| Mf::Atan
|
||||
| Mf::Sinh
|
||||
| Mf::Cosh
|
||||
| Mf::Tanh
|
||||
| Mf::Asinh
|
||||
| Mf::Acosh
|
||||
| Mf::Atanh
|
||||
| Mf::Saturate
|
||||
| Mf::Radians
|
||||
| Mf::Degrees
|
||||
| Mf::Ceil
|
||||
| Mf::Floor
|
||||
| Mf::Round
|
||||
| Mf::Fract
|
||||
| Mf::Trunc
|
||||
| Mf::Exp
|
||||
| Mf::Exp2
|
||||
| Mf::Log
|
||||
| Mf::Log2
|
||||
| Mf::Sqrt
|
||||
| Mf::InverseSqrt => regular!(1, SCALAR|VECN of FLOAT).into(),
|
||||
|
||||
// Component-wise binary floating-point operations
|
||||
Mf::Atan2 | Mf::Pow | Mf::Step => regular!(2, SCALAR|VECN of FLOAT).into(),
|
||||
|
||||
// Component-wise ternary floating-point operations
|
||||
Mf::Fma | Mf::SmoothStep => regular!(3, SCALAR|VECN of FLOAT).into(),
|
||||
|
||||
// Component-wise unary concrete integer operations
|
||||
Mf::CountTrailingZeros
|
||||
| Mf::CountLeadingZeros
|
||||
| Mf::CountOneBits
|
||||
| Mf::ReverseBits
|
||||
| Mf::FirstTrailingBit
|
||||
| Mf::FirstLeadingBit => regular!(1, SCALAR|VECN of CONCRETE_INTEGER).into(),
|
||||
|
||||
// Packing functions
|
||||
Mf::Pack4x8snorm | Mf::Pack4x8unorm => regular!(1, VEC4 of F32 -> U32).into(),
|
||||
Mf::Pack2x16snorm | Mf::Pack2x16unorm | Mf::Pack2x16float => {
|
||||
regular!(1, VEC2 of F32 -> U32).into()
|
||||
}
|
||||
Mf::Pack4xI8 => regular!(1, VEC4 of I32 -> U32).into(),
|
||||
Mf::Pack4xU8 => regular!(1, VEC4 of U32 -> U32).into(),
|
||||
|
||||
// Unpacking functions
|
||||
Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => regular!(1, SCALAR of U32 -> Vec4F).into(),
|
||||
Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack2x16float => {
|
||||
regular!(1, SCALAR of U32 -> Vec2F).into()
|
||||
}
|
||||
Mf::Unpack4xI8 => regular!(1, SCALAR of U32 -> Vec4I).into(),
|
||||
Mf::Unpack4xU8 => regular!(1, SCALAR of U32 -> Vec4U).into(),
|
||||
|
||||
// One-off operations
|
||||
Mf::Dot => regular!(2, VECN of NUMERIC -> Scalar).into(),
|
||||
Mf::Modf => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Modf).into(),
|
||||
Mf::Frexp => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Frexp).into(),
|
||||
Mf::Ldexp => ldexp().into(),
|
||||
Mf::Outer => outer().into(),
|
||||
Mf::Cross => regular!(2, VEC3 of FLOAT).into(),
|
||||
Mf::Distance => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
|
||||
Mf::Length => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
|
||||
Mf::Normalize => regular!(1, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
|
||||
Mf::FaceForward => regular!(3, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
|
||||
Mf::Reflect => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
|
||||
Mf::Refract => refract().into(),
|
||||
Mf::Mix => mix().into(),
|
||||
Mf::Inverse => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT).into(),
|
||||
Mf::Transpose => transpose().into(),
|
||||
Mf::Determinant => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT -> Scalar).into(),
|
||||
Mf::QuantizeToF16 => regular!(1, SCALAR|VECN of F32).into(),
|
||||
Mf::ExtractBits => extract_bits().into(),
|
||||
Mf::InsertBits => insert_bits().into(),
|
||||
};
|
||||
|
||||
set
|
||||
}
|
||||
}
|
||||
|
||||
fn ldexp() -> List {
|
||||
/// Construct the exponent scalar given the mantissa's inner.
|
||||
fn exponent_from_mantissa(mantissa: ir::Scalar) -> ir::Scalar {
|
||||
match mantissa.kind {
|
||||
ir::ScalarKind::AbstractFloat => ir::Scalar::ABSTRACT_INT,
|
||||
ir::ScalarKind::Float => ir::Scalar::I32,
|
||||
_ => unreachable!("not a float scalar"),
|
||||
}
|
||||
}
|
||||
|
||||
list(
|
||||
// The ldexp mantissa argument can be any floating-point type.
|
||||
float_scalars_unimplemented_abstract().flat_map(|mantissa_scalar| {
|
||||
// The exponent type is the integer counterpart of the mantissa type.
|
||||
let exponent_scalar = exponent_from_mantissa(mantissa_scalar);
|
||||
// There are scalar and vector component-wise overloads.
|
||||
scalar_or_vecn(mantissa_scalar)
|
||||
.zip(scalar_or_vecn(exponent_scalar))
|
||||
.map(move |(mantissa, exponent)| {
|
||||
let result = mantissa.clone();
|
||||
rule([mantissa, exponent], result)
|
||||
})
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn outer() -> List {
|
||||
list(
|
||||
triples(
|
||||
vector_sizes(),
|
||||
vector_sizes(),
|
||||
float_scalars_unimplemented_abstract(),
|
||||
)
|
||||
.map(|(cols, rows, scalar)| {
|
||||
let left = ir::TypeInner::Vector { size: cols, scalar };
|
||||
let right = ir::TypeInner::Vector { size: rows, scalar };
|
||||
let result = ir::TypeInner::Matrix {
|
||||
columns: cols,
|
||||
rows,
|
||||
scalar,
|
||||
};
|
||||
rule([left, right], result)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn refract() -> List {
|
||||
list(
|
||||
pairs(vector_sizes(), float_scalars_unimplemented_abstract()).map(|(size, scalar)| {
|
||||
let incident = ir::TypeInner::Vector { size, scalar };
|
||||
let normal = incident.clone();
|
||||
let ratio = ir::TypeInner::Scalar(scalar);
|
||||
let result = incident.clone();
|
||||
rule([incident, normal, ratio], result)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn transpose() -> List {
|
||||
list(
|
||||
triples(vector_sizes(), vector_sizes(), float_scalars()).map(|(a, b, scalar)| {
|
||||
let input = ir::TypeInner::Matrix {
|
||||
columns: a,
|
||||
rows: b,
|
||||
scalar,
|
||||
};
|
||||
let output = ir::TypeInner::Matrix {
|
||||
columns: b,
|
||||
rows: a,
|
||||
scalar,
|
||||
};
|
||||
rule([input], output)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn extract_bits() -> List {
|
||||
list(concrete_int_scalars().flat_map(|scalar| {
|
||||
scalar_or_vecn(scalar).map(|input| {
|
||||
let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
|
||||
let count = ir::TypeInner::Scalar(ir::Scalar::U32);
|
||||
let output = input.clone();
|
||||
rule([input, offset, count], output)
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
fn insert_bits() -> List {
|
||||
list(concrete_int_scalars().flat_map(|scalar| {
|
||||
scalar_or_vecn(scalar).map(|input| {
|
||||
let newbits = input.clone();
|
||||
let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
|
||||
let count = ir::TypeInner::Scalar(ir::Scalar::U32);
|
||||
let output = input.clone();
|
||||
rule([input, newbits, offset, count], output)
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
fn mix() -> List {
|
||||
list(float_scalars().flat_map(|scalar| {
|
||||
scalar_or_vecn(scalar).flat_map(move |input| {
|
||||
let scalar_ratio = ir::TypeInner::Scalar(scalar);
|
||||
[
|
||||
rule([input.clone(), input.clone(), input.clone()], input.clone()),
|
||||
rule([input.clone(), input.clone(), scalar_ratio], input),
|
||||
]
|
||||
})
|
||||
}))
|
||||
}
|
237
naga/src/proc/overloads/mod.rs
Normal file
237
naga/src/proc/overloads/mod.rs
Normal file
@ -0,0 +1,237 @@
|
||||
/*! Overload resolution for builtin functions.
|
||||
|
||||
This module defines the [`OverloadSet`] trait, which provides methods the
|
||||
validator and typifier can use to check the types to builtin functions,
|
||||
determine their result types, and produce diagnostics that explain why a given
|
||||
application is not allowed and suggest fixes.
|
||||
|
||||
You can call [`MathFunction::overloads`] to obtain an `impl OverloadSet`
|
||||
representing the given `MathFunction`'s overloads.
|
||||
|
||||
[`MathFunction::overloads`]: crate::ir::MathFunction::overloads
|
||||
|
||||
*/
|
||||
|
||||
mod constructor_set;
|
||||
mod regular;
|
||||
mod scalar_set;
|
||||
|
||||
mod any_overload_set;
|
||||
mod list;
|
||||
mod mathfunction;
|
||||
mod one_bits_iter;
|
||||
mod rule;
|
||||
mod utils;
|
||||
|
||||
pub use rule::{Conclusion, MissingSpecialType, Rule};
|
||||
|
||||
use crate::ir;
|
||||
use crate::proc::TypeResolution;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt;
|
||||
|
||||
#[expect(rustdoc::private_intra_doc_links)]
|
||||
/// A trait for types representing of a set of Naga IR type rules.
|
||||
///
|
||||
/// Given an expression like `max(x, y)`, there are multiple type rules that
|
||||
/// could apply, depending on the types of `x` and `y`, like:
|
||||
///
|
||||
/// - `max(i32, i32) -> i32`
|
||||
/// - `max(vec4<f32>, vec4<f32>) -> vec4<f32>`
|
||||
///
|
||||
/// and so on. Borrowing WGSL's terminology, Naga calls the full set of type
|
||||
/// rules that might apply to a given expression its "overload candidates", or
|
||||
/// "overloads" for short.
|
||||
///
|
||||
/// This trait is meant to be implemented by types that represent a set of
|
||||
/// overload candidates. For example, [`MathFunction::overloads`] returns an
|
||||
/// `impl OverloadSet` describing the overloads for the given Naga IR math
|
||||
/// function. Naga's typifier, validator, and WGSL front end use this trait for
|
||||
/// their work.
|
||||
///
|
||||
/// [`MathFunction::overloads`]: ir::MathFunction::overloads
|
||||
///
|
||||
/// # Automatic conversions
|
||||
///
|
||||
/// In principle, overload sets are easy: you simply list all the overloads the
|
||||
/// function supports, and then when you're presented with a call to typecheck,
|
||||
/// you just see if the actual argument types presented in the source code match
|
||||
/// some overload from the list.
|
||||
///
|
||||
/// However, Naga supports languages like WGSL, which apply certain [automatic
|
||||
/// conversions] if necessary to make a call fit some overload's requirements.
|
||||
/// This means that the set of calls that are effectively allowed by a given set
|
||||
/// of overloads can be quite large, since any combination of automatic
|
||||
/// conversions might be applied.
|
||||
///
|
||||
/// For example, if `x` is a `u32`, and `100` is an abstract integer, then even
|
||||
/// though `max` has no overload like `max(u32, AbstractInt) -> ...`, the
|
||||
/// expression `max(x, 100)` is still allowed, because AbstractInt automatically
|
||||
/// converts to `u32`.
|
||||
///
|
||||
/// [automatic conversions]: https://gpuweb.github.io/gpuweb/wgsl/#feasible-automatic-conversion
|
||||
///
|
||||
/// # How to use `OverloadSet`
|
||||
///
|
||||
/// The general process of using an `OverloadSet` is as follows:
|
||||
///
|
||||
/// - Obtain an `OverloadSet` for a given operation (say, by calling
|
||||
/// [`MathFunction::overloads`]). This set is fixed by Naga IR's semantics.
|
||||
///
|
||||
/// - Call its [`arg`] method, supplying the type of the argument passed to the
|
||||
/// operation at a certain index. This returns a new `OverloadSet` containing
|
||||
/// only those overloads that could accept the given type for the given
|
||||
/// argument. This includes overloads that only match if automatic conversions
|
||||
/// are applied.
|
||||
///
|
||||
/// - If, at some point, the overload set becomes empty, then the set of
|
||||
/// arguments is not allowed for this operation, and the program is invalid.
|
||||
/// The `OverloadSet` trait provides an [`is_empty`] method.
|
||||
///
|
||||
/// - After all arguments have been supplied, if the overload set is still
|
||||
/// non-empty, you can call its [`most_preferred`] method to find out which
|
||||
/// overload has the least cost in terms of automatic conversions.
|
||||
///
|
||||
/// - If the call is rejected, you can use `OverloadSet` to help produce
|
||||
/// diagnostic messages that explain exactly what went wrong. `OverloadSet`
|
||||
/// implementations are meant to be cheap to [`Clone`], so it is fine to keep
|
||||
/// the original overload set value around, and re-run the selection process,
|
||||
/// attempting to supply the rejected argument at each step to see exactly
|
||||
/// which preceding argument ruled it out. The [`overload_list`] and
|
||||
/// [`allowed_args`] methods are helpful for this.
|
||||
///
|
||||
/// [`arg`]: OverloadSet::arg
|
||||
/// [`is_empty`]: OverloadSet::is_empty
|
||||
/// [`most_preferred`]: OverloadSet::most_preferred
|
||||
/// [`overload_list`]: OverloadSet::overload_list
|
||||
/// [`allowed_args`]: OverloadSet::allowed_args
|
||||
///
|
||||
/// # Concrete implementations
|
||||
///
|
||||
/// This module contains two private implementations of `OverloadSet`:
|
||||
///
|
||||
/// - The [`List`] type is a straightforward list of overloads. It is general,
|
||||
/// but verbose to use. The [`utils`] module exports functions that construct
|
||||
/// `List` overload sets for the functions that need this.
|
||||
///
|
||||
/// - The [`Regular`] type is a compact, efficient representation for the kinds
|
||||
/// of overload sets commonly seen for Naga IR mathematical functions.
|
||||
/// However, in return for its simplicity, it is not as flexible as [`List`].
|
||||
/// This module use the [`regular!`] macro as a legible notation for `Regular`
|
||||
/// sets.
|
||||
///
|
||||
/// [`List`]: list::List
|
||||
/// [`Regular`]: regular::Regular
|
||||
/// [`regular!`]: regular::regular
|
||||
pub trait OverloadSet: Clone {
|
||||
/// Return true if `self` is the empty set of overloads.
|
||||
fn is_empty(&self) -> bool;
|
||||
|
||||
/// Return the smallest number of arguments in any type rule in the set.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `self` is empty.
|
||||
fn min_arguments(&self) -> usize;
|
||||
|
||||
/// Return the largest number of arguments in any type rule in the set.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `self` is empty.
|
||||
fn max_arguments(&self) -> usize;
|
||||
|
||||
/// Find the overloads that could accept a given argument.
|
||||
///
|
||||
/// Return a new overload set containing those members of `self` that could
|
||||
/// accept a value of type `ty` for their `i`'th argument, once
|
||||
/// feasible automatic conversions have been applied.
|
||||
fn arg(&self, i: usize, ty: &ir::TypeInner, types: &crate::UniqueArena<ir::Type>) -> Self;
|
||||
|
||||
/// Limit `self` to overloads whose arguments are all concrete types.
|
||||
///
|
||||
/// Naga's overload resolution is based on WGSL's [overload resolution
|
||||
/// algorithm][ora], which includes the following step:
|
||||
///
|
||||
/// > Eliminate any candidate where one of its subexpressions resolves to
|
||||
/// > an abstract type after feasible automatic conversions, but another of
|
||||
/// > the candidate’s subexpressions is not a const-expression.
|
||||
/// >
|
||||
/// > Note: As a consequence, if any subexpression in the phrase is not a
|
||||
/// > const-expression, then all subexpressions in the phrase must have a
|
||||
/// > concrete type.
|
||||
///
|
||||
/// Essentially, if any one of the arguments is not a constant expression,
|
||||
/// then the operation is going to be evaluated at runtime, so all its
|
||||
/// arguments must be converted to a concrete type. If you determine that an
|
||||
/// argument is non-constant, you can use this trait method to toss out any
|
||||
/// overloads that would accept abstract types.
|
||||
///
|
||||
/// In almost all cases, this operation has no effect. Only constant
|
||||
/// expressions can have abstract types, so if any argument is not a
|
||||
/// constant expression, it must have a concrete type. Although many
|
||||
/// operations in Naga IR have overloads for both abstract types and
|
||||
/// concrete types, few operations have overloads that accept a mix of
|
||||
/// abstract and concrete types. Thus, a single concrete argument will
|
||||
/// usually have eliminated all overloads that accept abstract types anyway.
|
||||
/// (The exceptions are oddities like `Expression::Select`, where the
|
||||
/// `condition` operand could be a runtime expression even as the `accept`
|
||||
/// and `reject` operands have abstract types.)
|
||||
///
|
||||
/// Note that it is *not* correct to just [concretize] all arguments once
|
||||
/// you've noticed that some argument is non-constant. The type to which
|
||||
/// each argument is converted depends on the overloads available, not just
|
||||
/// the argument's own type.
|
||||
///
|
||||
/// [ora]: https://gpuweb.github.io/gpuweb/wgsl/#overload-resolution-section
|
||||
/// [concretize]: https://gpuweb.github.io/gpuweb/wgsl/#concretization
|
||||
fn concrete_only(self, types: &crate::UniqueArena<ir::Type>) -> Self;
|
||||
|
||||
/// Return the most preferred candidate.
|
||||
///
|
||||
/// Rank the candidates in `self` as described in WGSL's [overload
|
||||
/// resolution algorithm][ora], and return a singleton set containing the
|
||||
/// most preferred candidate.
|
||||
///
|
||||
/// # Simplifications versus WGSL
|
||||
///
|
||||
/// Naga's process for selecting the most preferred candidate is simpler
|
||||
/// than WGSL's:
|
||||
///
|
||||
/// - WGSL allows for the possibility of ambiguous calls, where multiple
|
||||
/// overload candidates exist, no one candidate is clearly better than all
|
||||
/// the others. For example, if a function has the two overloads `(i32,
|
||||
/// f32) -> bool` and `(f32, i32) -> bool`, and the arguments are both
|
||||
/// AbstractInt, neither overload is preferred over the other. Ambiguous
|
||||
/// calls are errors.
|
||||
///
|
||||
/// However, Naga IR includes no operations whose overload sets allow such
|
||||
/// situations to arise, so there is always a most preferred candidate.
|
||||
/// Thus, this method infallibly returns a `Rule`, and has no way to
|
||||
/// indicate ambiguity.
|
||||
///
|
||||
/// - WGSL says that the most preferred candidate depends on the conversion
|
||||
/// rank for each argument, as determined by the types of the arguments
|
||||
/// being passed.
|
||||
///
|
||||
/// However, the overloads of every operation in Naga IR can be ranked
|
||||
/// even without knowing the argument types. So this method does not
|
||||
/// require the argument types as a parameter.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `self` is empty, or if no argument types have been supplied.
|
||||
///
|
||||
/// [ora]: https://gpuweb.github.io/gpuweb/wgsl/#overload-resolution-section
|
||||
fn most_preferred(&self) -> Rule;
|
||||
|
||||
/// Return a type rule for each of the overloads in `self`.
|
||||
fn overload_list(&self, gctx: &crate::proc::GlobalCtx<'_>) -> Vec<Rule>;
|
||||
|
||||
/// Return a list of the types allowed for argument `i`.
|
||||
fn allowed_args(&self, i: usize, gctx: &crate::proc::GlobalCtx<'_>) -> Vec<TypeResolution>;
|
||||
|
||||
/// Return an object that can be formatted with [`core::fmt::Debug`].
|
||||
fn for_debug(&self, types: &crate::UniqueArena<ir::Type>) -> impl fmt::Debug;
|
||||
}
|
89
naga/src/proc/overloads/one_bits_iter.rs
Normal file
89
naga/src/proc/overloads/one_bits_iter.rs
Normal file
@ -0,0 +1,89 @@
|
||||
//! An iterator over bitmasks.
|
||||
|
||||
/// An iterator that produces the set bits in the given `u64`.
|
||||
///
|
||||
/// `OneBitsIter(n)` is an [`Iterator`] that produces each of the set bits in
|
||||
/// `n`, as a bitmask, in order of increasing value. In other words, it produces
|
||||
/// the unique sequence of distinct powers of two that adds up to `n`.
|
||||
///
|
||||
/// For example, iterating over `OneBitsIter(21)` produces the values `1`, `4`,
|
||||
/// and `16`, in that order, because `21` is `0xb10101`.
|
||||
///
|
||||
/// When `n` is the bits of a bitmask, this iterates over the set bits in the
|
||||
/// bitmask, in order of increasing bit value. `bitflags` does define an `iter`
|
||||
/// method, but it's not well-specified or well-implemented.
|
||||
///
|
||||
/// The values produced are masks, not bit numbers. Use `u64::trailing_zeros` if
|
||||
/// you need bit numbers.
|
||||
pub struct OneBitsIter(u64);
|
||||
|
||||
impl OneBitsIter {
|
||||
pub const fn new(bits: u64) -> Self {
|
||||
Self(bits)
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for OneBitsIter {
|
||||
type Item = u64;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.0 == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Subtracting one from a value in binary clears the lowest `1` bit
|
||||
// (call it `B`), and sets all the bits below that.
|
||||
let mask = self.0 - 1;
|
||||
|
||||
// Complementing that means that we've instead *set* `B`, *cleared*
|
||||
// everything below it, and *complemented* everything above it.
|
||||
//
|
||||
// Masking with the original value clears everything above and below
|
||||
// `B`, leaving only `B` set. This is the value we produce.
|
||||
let item = self.0 & !mask;
|
||||
|
||||
// Now that we've produced this bit, remove it from `self.0`.
|
||||
self.0 &= mask;
|
||||
|
||||
Some(item)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty() {
|
||||
assert_eq!(OneBitsIter(0).next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all() {
|
||||
let mut obi = OneBitsIter(!0);
|
||||
for bit in 0..64 {
|
||||
assert_eq!(obi.next(), Some(1 << bit));
|
||||
}
|
||||
assert_eq!(obi.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first() {
|
||||
let mut obi = OneBitsIter(1);
|
||||
assert_eq!(obi.next(), Some(1));
|
||||
assert_eq!(obi.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn last() {
|
||||
let mut obi = OneBitsIter(1 << 63);
|
||||
assert_eq!(obi.next(), Some(1 << 63));
|
||||
assert_eq!(obi.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn in_order() {
|
||||
let mut obi = OneBitsIter(0b11011000001);
|
||||
assert_eq!(obi.next(), Some(1));
|
||||
assert_eq!(obi.next(), Some(64));
|
||||
assert_eq!(obi.next(), Some(128));
|
||||
assert_eq!(obi.next(), Some(512));
|
||||
assert_eq!(obi.next(), Some(1024));
|
||||
assert_eq!(obi.next(), None);
|
||||
}
|
541
naga/src/proc/overloads/regular.rs
Normal file
541
naga/src/proc/overloads/regular.rs
Normal file
@ -0,0 +1,541 @@
|
||||
/*! A representation for highly regular overload sets common in Naga IR.
|
||||
|
||||
Many Naga builtin functions' overload sets have a highly regular
|
||||
structure. For example, many arithmetic functions can be applied to
|
||||
any floating-point type, or any vector thereof. This module defines a
|
||||
handful of types for representing such simple overload sets that is
|
||||
simple and efficient.
|
||||
|
||||
*/
|
||||
|
||||
use crate::common::{DiagnosticDebug, ForDebugWithTypes};
|
||||
use crate::ir;
|
||||
use crate::proc::overloads::constructor_set::{ConstructorSet, ConstructorSize};
|
||||
use crate::proc::overloads::rule::{Conclusion, Rule};
|
||||
use crate::proc::overloads::scalar_set::ScalarSet;
|
||||
use crate::proc::overloads::OverloadSet;
|
||||
use crate::proc::{GlobalCtx, TypeResolution};
|
||||
use crate::UniqueArena;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt;
|
||||
|
||||
/// Overload sets represented as sets of scalars and constructors.
|
||||
///
|
||||
/// This type represents an [`OverloadSet`] using a bitset of scalar
|
||||
/// types and a bitset of type constructors that might be applied to
|
||||
/// those scalars. The overload set contains a rule for every possible
|
||||
/// combination of scalars and constructors, essentially the cartesian
|
||||
/// product of the two sets.
|
||||
///
|
||||
/// For example, if the arity is 2, set of scalars is { AbstractFloat,
|
||||
/// `f32` }, and the set of constructors is { `vec2`, `vec3` }, then
|
||||
/// that represents the set of overloads:
|
||||
///
|
||||
/// - (`vec2<AbstractFloat>`, `vec2<AbstractFloat>`) -> `vec2<AbstractFloat>`
|
||||
/// - (`vec2<f32>`, `vec2<f32>`) -> `vec2<f32>`
|
||||
/// - (`vec3<AbstractFloat>`, `vec3<AbstractFloat>`) -> `vec3<AbstractFloat>`
|
||||
/// - (`vec3<f32>`, `vec3<f32>`) -> `vec3<f32>`
|
||||
///
|
||||
/// The `conclude` value says how to determine the return type from
|
||||
/// the argument type.
|
||||
///
|
||||
/// Restrictions:
|
||||
///
|
||||
/// - All overloads must take the same number of arguments.
|
||||
///
|
||||
/// - For any given overload, all its arguments must have the same
|
||||
/// type.
|
||||
#[derive(Clone)]
|
||||
pub(in crate::proc::overloads) struct Regular {
|
||||
/// The number of arguments in the rules.
|
||||
pub arity: usize,
|
||||
|
||||
/// The set of type constructors to apply.
|
||||
pub constructors: ConstructorSet,
|
||||
|
||||
/// The set of scalars to apply them to.
|
||||
pub scalars: ScalarSet,
|
||||
|
||||
/// How to determine a member rule's return type given the type of
|
||||
/// its arguments.
|
||||
pub conclude: ConclusionRule,
|
||||
}
|
||||
|
||||
impl Regular {
|
||||
pub(in crate::proc::overloads) const EMPTY: Regular = Regular {
|
||||
arity: 0,
|
||||
constructors: ConstructorSet::empty(),
|
||||
scalars: ScalarSet::empty(),
|
||||
conclude: ConclusionRule::ArgumentType,
|
||||
};
|
||||
|
||||
/// Return an iterator over all the argument types allowed by `self`.
|
||||
///
|
||||
/// Return an iterator that produces, for each overload in `self`, the
|
||||
/// constructor and scalar of its argument types and return type.
|
||||
///
|
||||
/// A [`Regular`] value can only represent overload sets where, in
|
||||
/// each overload, all the arguments have the same type, and the
|
||||
/// return type is always going to be a determined by the argument
|
||||
/// types, so giving the constructor and scalar is sufficient to
|
||||
/// characterize the entire rule.
|
||||
fn members(&self) -> impl Iterator<Item = (ConstructorSize, ir::Scalar)> {
|
||||
let scalars = self.scalars;
|
||||
self.constructors.members().flat_map(move |constructor| {
|
||||
let size = constructor.size();
|
||||
// Technically, we don't need the "most general" `TypeInner` here,
|
||||
// but since `ScalarSet::members` only produces singletons anyway,
|
||||
// the effect is the same.
|
||||
scalars
|
||||
.members()
|
||||
.map(move |singleton| (size, singleton.most_general_scalar()))
|
||||
})
|
||||
}
|
||||
|
||||
fn rules(&self) -> impl Iterator<Item = Rule> {
|
||||
let arity = self.arity;
|
||||
let conclude = self.conclude;
|
||||
self.members()
|
||||
.map(move |(size, scalar)| make_rule(arity, size, scalar, conclude))
|
||||
}
|
||||
}
|
||||
|
||||
impl OverloadSet for Regular {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.constructors.is_empty() || self.scalars.is_empty()
|
||||
}
|
||||
|
||||
fn min_arguments(&self) -> usize {
|
||||
assert!(!self.is_empty());
|
||||
self.arity
|
||||
}
|
||||
|
||||
fn max_arguments(&self) -> usize {
|
||||
assert!(!self.is_empty());
|
||||
self.arity
|
||||
}
|
||||
|
||||
fn arg(&self, i: usize, ty: &ir::TypeInner, types: &UniqueArena<ir::Type>) -> Self {
|
||||
if i >= self.arity {
|
||||
return Self::EMPTY;
|
||||
}
|
||||
|
||||
let constructor = ConstructorSet::singleton(ty);
|
||||
|
||||
let scalars = match ty.scalar_for_conversions(types) {
|
||||
Some(ty_scalar) => ScalarSet::convertible_from(ty_scalar),
|
||||
None => ScalarSet::empty(),
|
||||
};
|
||||
|
||||
Self {
|
||||
arity: self.arity,
|
||||
|
||||
// Constrain all member rules' constructors to match `ty`'s.
|
||||
constructors: self.constructors & constructor,
|
||||
|
||||
// Constrain all member rules' arguments to be something
|
||||
// that `ty` can be converted to.
|
||||
scalars: self.scalars & scalars,
|
||||
|
||||
conclude: self.conclude,
|
||||
}
|
||||
}
|
||||
|
||||
fn concrete_only(self, _types: &UniqueArena<ir::Type>) -> Self {
|
||||
Self {
|
||||
scalars: self.scalars & ScalarSet::CONCRETE,
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
fn most_preferred(&self) -> Rule {
|
||||
assert!(!self.is_empty());
|
||||
|
||||
// If there is more than one constructor allowed, then we must
|
||||
// not have had any arguments supplied at all. In any case, we
|
||||
// don't have any unambiguously preferred candidate.
|
||||
assert!(self.constructors.is_singleton());
|
||||
|
||||
let size = self.constructors.size();
|
||||
let scalar = self.scalars.most_general_scalar();
|
||||
make_rule(self.arity, size, scalar, self.conclude)
|
||||
}
|
||||
|
||||
fn overload_list(&self, _gctx: &GlobalCtx<'_>) -> Vec<Rule> {
|
||||
self.rules().collect()
|
||||
}
|
||||
|
||||
fn allowed_args(&self, i: usize, _gctx: &GlobalCtx<'_>) -> Vec<TypeResolution> {
|
||||
if i >= self.arity {
|
||||
return Vec::new();
|
||||
}
|
||||
self.members()
|
||||
.map(|(size, scalar)| TypeResolution::Value(size.to_inner(scalar)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn for_debug(&self, types: &UniqueArena<ir::Type>) -> impl fmt::Debug {
|
||||
DiagnosticDebug((self, types))
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [`Regular`] member [`Rule`] for the given arity and type.
|
||||
///
|
||||
/// [`Regular`] can only represent rules where all the argument types and the
|
||||
/// return type are the same, so just knowing `arity` and `inner` is sufficient.
|
||||
///
|
||||
/// [`Rule`]: crate::proc::overloads::Rule
|
||||
fn make_rule(
|
||||
arity: usize,
|
||||
size: ConstructorSize,
|
||||
scalar: ir::Scalar,
|
||||
conclusion_rule: ConclusionRule,
|
||||
) -> Rule {
|
||||
let inner = size.to_inner(scalar);
|
||||
let arg = TypeResolution::Value(inner.clone());
|
||||
Rule {
|
||||
arguments: core::iter::repeat(arg.clone()).take(arity).collect(),
|
||||
conclusion: conclusion_rule.conclude(size, scalar),
|
||||
}
|
||||
}
|
||||
|
||||
/// Conclusion-computing rules.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[repr(u8)]
|
||||
pub(in crate::proc::overloads) enum ConclusionRule {
|
||||
ArgumentType,
|
||||
Scalar,
|
||||
Frexp,
|
||||
Modf,
|
||||
U32,
|
||||
Vec2F,
|
||||
Vec4F,
|
||||
Vec4I,
|
||||
Vec4U,
|
||||
}
|
||||
|
||||
impl ConclusionRule {
|
||||
fn conclude(self, size: ConstructorSize, scalar: ir::Scalar) -> Conclusion {
|
||||
match self {
|
||||
Self::ArgumentType => Conclusion::Value(size.to_inner(scalar)),
|
||||
Self::Scalar => Conclusion::Value(ir::TypeInner::Scalar(scalar)),
|
||||
Self::Frexp => Conclusion::for_frexp_modf(ir::MathFunction::Frexp, size, scalar),
|
||||
Self::Modf => Conclusion::for_frexp_modf(ir::MathFunction::Modf, size, scalar),
|
||||
Self::U32 => Conclusion::Value(ir::TypeInner::Scalar(ir::Scalar::U32)),
|
||||
Self::Vec2F => Conclusion::Value(ir::TypeInner::Vector {
|
||||
size: ir::VectorSize::Bi,
|
||||
scalar: ir::Scalar::F32,
|
||||
}),
|
||||
Self::Vec4F => Conclusion::Value(ir::TypeInner::Vector {
|
||||
size: ir::VectorSize::Quad,
|
||||
scalar: ir::Scalar::F32,
|
||||
}),
|
||||
Self::Vec4I => Conclusion::Value(ir::TypeInner::Vector {
|
||||
size: ir::VectorSize::Quad,
|
||||
scalar: ir::Scalar::I32,
|
||||
}),
|
||||
Self::Vec4U => Conclusion::Value(ir::TypeInner::Vector {
|
||||
size: ir::VectorSize::Quad,
|
||||
scalar: ir::Scalar::U32,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for DiagnosticDebug<(&Regular, &UniqueArena<ir::Type>)> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let (regular, types) = self.0;
|
||||
let rules: Vec<Rule> = regular.rules().collect();
|
||||
f.debug_struct("List")
|
||||
.field("rules", &rules.for_debug(types))
|
||||
.field("conclude", ®ular.conclude)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl ForDebugWithTypes for &Regular {}
|
||||
|
||||
impl fmt::Debug for DiagnosticDebug<(&[Rule], &UniqueArena<ir::Type>)> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let (rules, types) = self.0;
|
||||
f.debug_list()
|
||||
.entries(rules.iter().map(|rule| rule.for_debug(types)))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl ForDebugWithTypes for &[Rule] {}
|
||||
|
||||
/// Construct a [`Regular`] [`OverloadSet`].
|
||||
///
|
||||
/// Examples:
|
||||
///
|
||||
/// - `regular!(2, SCALAR|VECN of FLOAT)`: An overload set whose rules take two
|
||||
/// arguments of the same type: a floating-point scalar (possibly abstract) or
|
||||
/// a vector of such. The return type is the same as the argument type.
|
||||
///
|
||||
/// - `regular!(1, VECN of FLOAT -> Scalar)`: An overload set whose rules take
|
||||
/// one argument that is a vector of floats, and whose return type is the leaf
|
||||
/// scalar type of the argument type.
|
||||
///
|
||||
/// The constructor values (before the `<` angle brackets `>`) are
|
||||
/// constants from [`ConstructorSet`].
|
||||
///
|
||||
/// The scalar values (inside the `<` angle brackets `>`) are
|
||||
/// constants from [`ScalarSet`].
|
||||
///
|
||||
/// When a return type identifier is given, it is treated as a variant
|
||||
/// of the the [`ConclusionRule`] enum.
|
||||
macro_rules! regular {
|
||||
// regular!(ARITY, CONSTRUCTOR of SCALAR)
|
||||
( $arity:literal , $( $constr:ident )|* of $( $scalar:ident )|*) => {
|
||||
{
|
||||
use $crate::proc::overloads;
|
||||
use overloads::constructor_set::constructor_set;
|
||||
use overloads::regular::{Regular, ConclusionRule};
|
||||
use overloads::scalar_set::scalar_set;
|
||||
Regular {
|
||||
arity: $arity,
|
||||
constructors: constructor_set!( $( $constr )|* ),
|
||||
scalars: scalar_set!( $( $scalar )|* ),
|
||||
conclude: ConclusionRule::ArgumentType,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// regular!(ARITY, CONSTRUCTOR of SCALAR -> CONCLUSION_RULE)
|
||||
( $arity:literal , $( $constr:ident )|* of $( $scalar:ident )|* -> $conclude:ident) => {
|
||||
{
|
||||
use $crate::proc::overloads;
|
||||
use overloads::constructor_set::constructor_set;
|
||||
use overloads::regular::{Regular, ConclusionRule};
|
||||
use overloads::scalar_set::scalar_set;
|
||||
Regular {
|
||||
arity: $arity,
|
||||
constructors:constructor_set!( $( $constr )|* ),
|
||||
scalars: scalar_set!( $( $scalar )|* ),
|
||||
conclude: ConclusionRule::$conclude,
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(in crate::proc::overloads) use regular;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::ir;
|
||||
|
||||
const fn scalar(scalar: ir::Scalar) -> ir::TypeInner {
|
||||
ir::TypeInner::Scalar(scalar)
|
||||
}
|
||||
|
||||
const fn vec2(scalar: ir::Scalar) -> ir::TypeInner {
|
||||
ir::TypeInner::Vector {
|
||||
scalar,
|
||||
size: ir::VectorSize::Bi,
|
||||
}
|
||||
}
|
||||
|
||||
const fn vec3(scalar: ir::Scalar) -> ir::TypeInner {
|
||||
ir::TypeInner::Vector {
|
||||
scalar,
|
||||
size: ir::VectorSize::Tri,
|
||||
}
|
||||
}
|
||||
|
||||
/// Assert that `set` has a most preferred candidate whose type
|
||||
/// conclusion is `expected`.
|
||||
#[track_caller]
|
||||
fn check_return_type(set: &Regular, expected: &ir::TypeInner, arena: &UniqueArena<ir::Type>) {
|
||||
assert!(!set.is_empty());
|
||||
|
||||
let special_types = ir::SpecialTypes::default();
|
||||
|
||||
let preferred = set.most_preferred();
|
||||
let conclusion = preferred.conclusion;
|
||||
let resolution = conclusion
|
||||
.into_resolution(&special_types)
|
||||
.expect("special types should have been pre-registered");
|
||||
let inner = resolution.inner_with(arena);
|
||||
|
||||
assert!(
|
||||
inner.equivalent(expected, arena),
|
||||
"Expected {:?}, got {:?}",
|
||||
expected.for_debug(arena),
|
||||
inner.for_debug(arena),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unary_vec_or_scalar_numeric_scalar() {
|
||||
let arena = UniqueArena::default();
|
||||
|
||||
let builtin = regular!(1, SCALAR of NUMERIC);
|
||||
|
||||
let ok = builtin.arg(0, &scalar(ir::Scalar::U32), &arena);
|
||||
check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);
|
||||
|
||||
let err = builtin.arg(0, &scalar(ir::Scalar::BOOL), &arena);
|
||||
assert!(err.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unary_vec_or_scalar_numeric_vector() {
|
||||
let arena = UniqueArena::default();
|
||||
|
||||
let builtin = regular!(1, VECN|SCALAR of NUMERIC);
|
||||
|
||||
let ok = builtin.arg(0, &vec3(ir::Scalar::F64), &arena);
|
||||
check_return_type(&ok, &vec3(ir::Scalar::F64), &arena);
|
||||
|
||||
let err = builtin.arg(0, &vec3(ir::Scalar::BOOL), &arena);
|
||||
assert!(err.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unary_vec_or_scalar_numeric_matrix() {
|
||||
let arena = UniqueArena::default();
|
||||
|
||||
let builtin = regular!(1, VECN|SCALAR of NUMERIC);
|
||||
|
||||
let err = builtin.arg(
|
||||
0,
|
||||
&ir::TypeInner::Matrix {
|
||||
columns: ir::VectorSize::Tri,
|
||||
rows: ir::VectorSize::Tri,
|
||||
scalar: ir::Scalar::F32,
|
||||
},
|
||||
&arena,
|
||||
);
|
||||
assert!(err.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[rustfmt::skip]
|
||||
fn binary_vec_or_scalar_numeric_scalar() {
|
||||
let arena = UniqueArena::default();
|
||||
|
||||
let builtin = regular!(2, VECN|SCALAR of NUMERIC);
|
||||
|
||||
let ok = builtin
|
||||
.arg(0, &scalar(ir::Scalar::F32), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::F32), &arena);
|
||||
check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);
|
||||
|
||||
let ok = builtin
|
||||
.arg(0, &scalar(ir::Scalar::ABSTRACT_FLOAT), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::F32), &arena);
|
||||
check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);
|
||||
|
||||
let ok = builtin
|
||||
.arg(0, &scalar(ir::Scalar::F32), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::ABSTRACT_INT), &arena);
|
||||
check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);
|
||||
|
||||
let ok = builtin
|
||||
.arg(0, &scalar(ir::Scalar::U32), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::U32), &arena);
|
||||
check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);
|
||||
|
||||
let ok = builtin
|
||||
.arg(0, &scalar(ir::Scalar::U32), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::ABSTRACT_INT), &arena);
|
||||
check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);
|
||||
|
||||
let ok = builtin
|
||||
.arg(0, &scalar(ir::Scalar::ABSTRACT_INT), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::U32), &arena);
|
||||
check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);
|
||||
|
||||
// Not numeric.
|
||||
let err = builtin
|
||||
.arg(0, &scalar(ir::Scalar::BOOL), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::BOOL), &arena);
|
||||
assert!(err.is_empty());
|
||||
|
||||
// Different floating-point types.
|
||||
let err = builtin
|
||||
.arg(0, &scalar(ir::Scalar::F32), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::F64), &arena);
|
||||
assert!(err.is_empty());
|
||||
|
||||
// Different constructor.
|
||||
let err = builtin
|
||||
.arg(0, &scalar(ir::Scalar::F32), &arena)
|
||||
.arg(1, &vec2(ir::Scalar::F32), &arena);
|
||||
assert!(err.is_empty());
|
||||
|
||||
// Different vector size
|
||||
let err = builtin
|
||||
.arg(0, &vec2(ir::Scalar::F32), &arena)
|
||||
.arg(1, &vec3(ir::Scalar::F32), &arena);
|
||||
assert!(err.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[rustfmt::skip]
|
||||
fn binary_vec_or_scalar_numeric_vector() {
|
||||
let arena = UniqueArena::default();
|
||||
|
||||
let builtin = regular!(2, VECN|SCALAR of NUMERIC);
|
||||
|
||||
let ok = builtin
|
||||
.arg(0, &vec3(ir::Scalar::F32), &arena)
|
||||
.arg(1, &vec3(ir::Scalar::F32), &arena);
|
||||
check_return_type(&ok, &vec3(ir::Scalar::F32), &arena);
|
||||
|
||||
// Different vector sizes.
|
||||
let err = builtin
|
||||
.arg(0, &vec2(ir::Scalar::F32), &arena)
|
||||
.arg(1, &vec3(ir::Scalar::F32), &arena);
|
||||
assert!(err.is_empty());
|
||||
|
||||
// Different vector scalars.
|
||||
let err = builtin
|
||||
.arg(0, &vec3(ir::Scalar::F32), &arena)
|
||||
.arg(1, &vec3(ir::Scalar::F64), &arena);
|
||||
assert!(err.is_empty());
|
||||
|
||||
// Mix of vectors and scalars.
|
||||
let err = builtin
|
||||
.arg(0, &scalar(ir::Scalar::F32), &arena)
|
||||
.arg(1, &vec3(ir::Scalar::F32), &arena);
|
||||
assert!(err.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[rustfmt::skip]
|
||||
fn binary_vec_or_scalar_numeric_vector_abstract() {
|
||||
let arena = UniqueArena::default();
|
||||
|
||||
let builtin = regular!(2, VECN|SCALAR of NUMERIC);
|
||||
|
||||
let ok = builtin
|
||||
.arg(0, &vec2(ir::Scalar::ABSTRACT_INT), &arena)
|
||||
.arg(1, &vec2(ir::Scalar::U32), &arena);
|
||||
check_return_type(&ok, &vec2(ir::Scalar::U32), &arena);
|
||||
|
||||
let ok = builtin
|
||||
.arg(0, &vec3(ir::Scalar::ABSTRACT_INT), &arena)
|
||||
.arg(1, &vec3(ir::Scalar::F32), &arena);
|
||||
check_return_type(&ok, &vec3(ir::Scalar::F32), &arena);
|
||||
|
||||
let ok = builtin
|
||||
.arg(0, &scalar(ir::Scalar::ABSTRACT_FLOAT), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::F32), &arena);
|
||||
check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);
|
||||
|
||||
let err = builtin
|
||||
.arg(0, &scalar(ir::Scalar::ABSTRACT_FLOAT), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::U32), &arena);
|
||||
assert!(err.is_empty());
|
||||
|
||||
let err = builtin
|
||||
.arg(0, &scalar(ir::Scalar::I32), &arena)
|
||||
.arg(1, &scalar(ir::Scalar::U32), &arena);
|
||||
assert!(err.is_empty());
|
||||
}
|
||||
}
|
146
naga/src/proc/overloads/rule.rs
Normal file
146
naga/src/proc/overloads/rule.rs
Normal file
@ -0,0 +1,146 @@
|
||||
/*! Type rules.
|
||||
|
||||
An implementation of [`OverloadSet`] represents a set of type rules, each of
|
||||
which has a list of types for its arguments, and a conclusion about the
|
||||
type of the expression as a whole.
|
||||
|
||||
This module defines the [`Rule`] type, representing a type rule from an
|
||||
[`OverloadSet`], and the [`Conclusion`] type, a specialized enum for
|
||||
representing a type rule's conclusion.
|
||||
|
||||
[`OverloadSet`]: crate::proc::overloads::OverloadSet
|
||||
|
||||
*/
|
||||
|
||||
use crate::common::{DiagnosticDebug, ForDebugWithTypes};
|
||||
use crate::ir;
|
||||
use crate::proc::overloads::constructor_set::ConstructorSize;
|
||||
use crate::proc::TypeResolution;
|
||||
use crate::UniqueArena;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt;
|
||||
use core::result::Result;
|
||||
|
||||
/// A single type rule.
|
||||
#[derive(Clone)]
|
||||
pub struct Rule {
|
||||
pub arguments: Vec<TypeResolution>,
|
||||
pub conclusion: Conclusion,
|
||||
}
|
||||
|
||||
/// The result type of a [`Rule`].
|
||||
///
|
||||
/// A `Conclusion` value represents the return type of some operation
|
||||
/// in the builtin function database.
|
||||
///
|
||||
/// This is very similar to [`TypeInner`], except that it represents
|
||||
/// predeclared types using [`PredeclaredType`], so that overload
|
||||
/// resolution can delegate registering predeclared types to its users.
|
||||
///
|
||||
/// [`TypeInner`]: ir::TypeInner
|
||||
/// [`PredeclaredType`]: ir::PredeclaredType
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum Conclusion {
|
||||
/// A type that can be entirely characterized by a [`TypeInner`] value.
|
||||
///
|
||||
/// [`TypeInner`]: ir::TypeInner
|
||||
Value(ir::TypeInner),
|
||||
|
||||
/// A type that should be registered in the module's
|
||||
/// [`SpecialTypes::predeclared_types`] table.
|
||||
///
|
||||
/// This is used for operations like [`Frexp`] and [`Modf`].
|
||||
///
|
||||
/// [`SpecialTypes::predeclared_types`]: ir::SpecialTypes::predeclared_types
|
||||
/// [`Frexp`]: crate::ir::MathFunction::Frexp
|
||||
/// [`Modf`]: crate::ir::MathFunction::Modf
|
||||
Predeclared(ir::PredeclaredType),
|
||||
}
|
||||
|
||||
impl Conclusion {
|
||||
pub fn for_frexp_modf(
|
||||
function: ir::MathFunction,
|
||||
size: ConstructorSize,
|
||||
scalar: ir::Scalar,
|
||||
) -> Self {
|
||||
use ir::MathFunction as Mf;
|
||||
use ir::PredeclaredType as Pt;
|
||||
|
||||
let size = match size {
|
||||
ConstructorSize::Scalar => None,
|
||||
ConstructorSize::Vector(size) => Some(size),
|
||||
ConstructorSize::Matrix { .. } => {
|
||||
unreachable!("FrexpModf only supports scalars and vectors");
|
||||
}
|
||||
};
|
||||
|
||||
let predeclared = match function {
|
||||
Mf::Frexp => Pt::FrexpResult { size, scalar },
|
||||
Mf::Modf => Pt::ModfResult { size, scalar },
|
||||
_ => {
|
||||
unreachable!("FrexpModf only supports Frexp and Modf");
|
||||
}
|
||||
};
|
||||
|
||||
Conclusion::Predeclared(predeclared)
|
||||
}
|
||||
|
||||
pub fn into_resolution(
|
||||
self,
|
||||
special_types: &ir::SpecialTypes,
|
||||
) -> Result<TypeResolution, MissingSpecialType> {
|
||||
match self {
|
||||
Conclusion::Value(inner) => Ok(TypeResolution::Value(inner)),
|
||||
Conclusion::Predeclared(predeclared) => {
|
||||
let handle = *special_types
|
||||
.predeclared_types
|
||||
.get(&predeclared)
|
||||
.ok_or(MissingSpecialType)?;
|
||||
Ok(TypeResolution::Handle(handle))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Special type is not registered within the module")]
|
||||
pub struct MissingSpecialType;
|
||||
|
||||
impl ForDebugWithTypes for &Rule {}
|
||||
|
||||
impl fmt::Debug for DiagnosticDebug<(&Rule, &UniqueArena<ir::Type>)> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let (rule, arena) = self.0;
|
||||
f.write_str("(")?;
|
||||
for (i, argument) in rule.arguments.iter().enumerate() {
|
||||
if i > 0 {
|
||||
f.write_str(", ")?;
|
||||
}
|
||||
write!(f, "{:?}", argument.for_debug(arena))?;
|
||||
}
|
||||
write!(f, ") -> {:?}", rule.conclusion.for_debug(arena))
|
||||
}
|
||||
}
|
||||
|
||||
impl ForDebugWithTypes for &Conclusion {}
|
||||
|
||||
impl fmt::Debug for DiagnosticDebug<(&Conclusion, &UniqueArena<ir::Type>)> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let (conclusion, ctx) = self.0;
|
||||
|
||||
#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
|
||||
{
|
||||
use crate::common::wgsl::TypeContext;
|
||||
ctx.write_type_conclusion(conclusion, f)?;
|
||||
}
|
||||
|
||||
#[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
|
||||
{
|
||||
let _ = ctx;
|
||||
write!(f, "{conclusion:?}")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
141
naga/src/proc/overloads/scalar_set.rs
Normal file
141
naga/src/proc/overloads/scalar_set.rs
Normal file
@ -0,0 +1,141 @@
|
||||
//! A set of scalar types, represented as a bitset.
|
||||
|
||||
use crate::ir::Scalar;
|
||||
use crate::proc::overloads::one_bits_iter::OneBitsIter;
|
||||
|
||||
macro_rules! define_scalar_set {
|
||||
{ $( $scalar:ident, )* } => {
|
||||
/// An enum used to assign distinct bit numbers to [`ScalarSet`] elements.
|
||||
#[expect(non_camel_case_types, clippy::upper_case_acronyms)]
|
||||
#[repr(u32)]
|
||||
enum ScalarSetBits {
|
||||
$( $scalar, )*
|
||||
Count,
|
||||
}
|
||||
|
||||
/// A table mapping bit numbers to the [`Scalar`] values they represent.
|
||||
static SCALARS_FOR_BITS: [Scalar; ScalarSetBits::Count as usize] = [
|
||||
$(
|
||||
Scalar::$scalar,
|
||||
)*
|
||||
];
|
||||
|
||||
bitflags::bitflags! {
|
||||
/// A set of scalar types.
|
||||
///
|
||||
/// This represents a set of [`Scalar`] types.
|
||||
///
|
||||
/// The Naga IR conversion rules arrange scalar types into a
|
||||
/// lattice. The scalar types' bit values are chosen such that, if
|
||||
/// A is convertible to B, then A's bit value is less than B's.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub(crate) struct ScalarSet: u16 {
|
||||
$(
|
||||
const $scalar = 1 << (ScalarSetBits::$scalar as u32);
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarSet {
|
||||
/// Return the set of scalars containing only `scalar`.
|
||||
#[expect(dead_code)]
|
||||
pub const fn singleton(scalar: Scalar) -> Self {
|
||||
match scalar {
|
||||
$(
|
||||
Scalar::$scalar => Self::$scalar,
|
||||
)*
|
||||
_ => Self::empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
define_scalar_set! {
|
||||
// Scalar types must be listed here in an order such that, if A is
|
||||
// convertible to B, then A appears before B.
|
||||
//
|
||||
// In the concrete types, the 32-bit types *must* appear before
|
||||
// other sizes, since that is how we represent conversion rank.
|
||||
ABSTRACT_INT, ABSTRACT_FLOAT,
|
||||
I32, I64,
|
||||
U32, U64,
|
||||
F32, F16, F64,
|
||||
BOOL,
|
||||
}
|
||||
|
||||
impl ScalarSet {
|
||||
/// Return the set of scalars to which `scalar` can be automatically
|
||||
/// converted.
|
||||
pub fn convertible_from(scalar: Scalar) -> Self {
|
||||
use Scalar as Sc;
|
||||
match scalar {
|
||||
Sc::I32 => Self::I32,
|
||||
Sc::I64 => Self::I64,
|
||||
Sc::U32 => Self::U32,
|
||||
Sc::U64 => Self::U64,
|
||||
Sc::F16 => Self::F16,
|
||||
Sc::F32 => Self::F32,
|
||||
Sc::F64 => Self::F64,
|
||||
Sc::BOOL => Self::BOOL,
|
||||
Sc::ABSTRACT_INT => Self::INTEGER | Self::FLOAT,
|
||||
Sc::ABSTRACT_FLOAT => Self::FLOAT,
|
||||
_ => Self::empty(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the lowest-ranked member of `self` as a [`Scalar`].
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `self` is empty.
|
||||
pub fn most_general_scalar(self) -> Scalar {
|
||||
// If the set is empty, this returns the full bit-length of
|
||||
// `self.bits()`, an index which is out of bounds for
|
||||
// `SCALARS_FOR_BITS`.
|
||||
let lowest = self.bits().trailing_zeros();
|
||||
*SCALARS_FOR_BITS.get(lowest as usize).unwrap()
|
||||
}
|
||||
|
||||
/// Return an iterator over this set's members.
|
||||
///
|
||||
/// Members are produced as singleton, in order from most general to least.
|
||||
pub fn members(self) -> impl Iterator<Item = ScalarSet> {
|
||||
OneBitsIter::new(self.bits() as u64).map(|bit| Self::from_bits(bit as u16).unwrap())
|
||||
}
|
||||
|
||||
pub const FLOAT: Self = Self::ABSTRACT_FLOAT
|
||||
.union(Self::F16)
|
||||
.union(Self::F32)
|
||||
.union(Self::F64);
|
||||
|
||||
pub const INTEGER: Self = Self::ABSTRACT_INT
|
||||
.union(Self::I32)
|
||||
.union(Self::I64)
|
||||
.union(Self::U32)
|
||||
.union(Self::U64);
|
||||
|
||||
pub const NUMERIC: Self = Self::FLOAT.union(Self::INTEGER);
|
||||
pub const ABSTRACT: Self = Self::ABSTRACT_INT.union(Self::ABSTRACT_FLOAT);
|
||||
pub const CONCRETE: Self = Self::all().difference(Self::ABSTRACT);
|
||||
pub const CONCRETE_INTEGER: Self = Self::INTEGER.intersection(Self::CONCRETE);
|
||||
pub const CONCRETE_FLOAT: Self = Self::FLOAT.intersection(Self::CONCRETE);
|
||||
|
||||
/// Floating-point scalars, with the abstract floats omitted for
|
||||
/// #7405.
|
||||
pub const FLOAT_ABSTRACT_UNIMPLEMENTED: Self = Self::CONCRETE_FLOAT;
|
||||
}
|
||||
|
||||
macro_rules! scalar_set {
|
||||
( $( $scalar:ident )|* ) => {
|
||||
{
|
||||
use $crate::proc::overloads::scalar_set::ScalarSet;
|
||||
ScalarSet::empty()
|
||||
$(
|
||||
.union(ScalarSet::$scalar)
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(in crate::proc::overloads) use scalar_set;
|
113
naga/src/proc/overloads/utils.rs
Normal file
113
naga/src/proc/overloads/utils.rs
Normal file
@ -0,0 +1,113 @@
|
||||
//! Utility functions for constructing [`List`] overload sets.
|
||||
//!
|
||||
//! [`List`]: crate::proc::overloads::list::List
|
||||
|
||||
use crate::ir;
|
||||
use crate::proc::overloads::list::List;
|
||||
use crate::proc::overloads::rule::{Conclusion, Rule};
|
||||
use crate::proc::TypeResolution;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// Produce all vector sizes.
|
||||
pub fn vector_sizes() -> impl Iterator<Item = ir::VectorSize> + Clone {
|
||||
static SIZES: [ir::VectorSize; 3] = [
|
||||
ir::VectorSize::Bi,
|
||||
ir::VectorSize::Tri,
|
||||
ir::VectorSize::Quad,
|
||||
];
|
||||
|
||||
SIZES.iter().cloned()
|
||||
}
|
||||
|
||||
/// Produce all the floating-point [`ir::Scalar`]s.
|
||||
///
|
||||
/// Note that `F32` must appear before other sizes; this is how we
|
||||
/// represent conversion rank.
|
||||
pub fn float_scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
|
||||
[
|
||||
ir::Scalar::ABSTRACT_FLOAT,
|
||||
ir::Scalar::F32,
|
||||
ir::Scalar::F16,
|
||||
ir::Scalar::F64,
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
|
||||
/// Produce all the floating-point [`ir::Scalar`]s, but omit
|
||||
/// abstract types, for #7405.
|
||||
pub fn float_scalars_unimplemented_abstract() -> impl Iterator<Item = ir::Scalar> + Clone {
|
||||
[ir::Scalar::F32, ir::Scalar::F16, ir::Scalar::F64].into_iter()
|
||||
}
|
||||
|
||||
/// Produce all concrete integer [`ir::Scalar`]s.
|
||||
///
|
||||
/// Note that `I32` and `U32` must come first; this is how we
|
||||
/// represent conversion rank.
|
||||
pub fn concrete_int_scalars() -> impl Iterator<Item = ir::Scalar> {
|
||||
[
|
||||
ir::Scalar::I32,
|
||||
ir::Scalar::U32,
|
||||
ir::Scalar::I64,
|
||||
ir::Scalar::U64,
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
|
||||
/// Produce the scalar and vector [`ir::TypeInner`]s that have `s` as
|
||||
/// their scalar.
|
||||
pub fn scalar_or_vecn(scalar: ir::Scalar) -> impl Iterator<Item = ir::TypeInner> {
|
||||
[
|
||||
ir::TypeInner::Scalar(scalar),
|
||||
ir::TypeInner::Vector {
|
||||
size: ir::VectorSize::Bi,
|
||||
scalar,
|
||||
},
|
||||
ir::TypeInner::Vector {
|
||||
size: ir::VectorSize::Tri,
|
||||
scalar,
|
||||
},
|
||||
ir::TypeInner::Vector {
|
||||
size: ir::VectorSize::Quad,
|
||||
scalar,
|
||||
},
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
|
||||
/// Construct a [`Rule`] for an operation with the given
|
||||
/// argument types and return type.
|
||||
pub fn rule<const N: usize>(args: [ir::TypeInner; N], ret: ir::TypeInner) -> Rule {
|
||||
Rule {
|
||||
arguments: Vec::from_iter(args.into_iter().map(TypeResolution::Value)),
|
||||
conclusion: Conclusion::Value(ret),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [`List`] from the given rules.
|
||||
pub fn list(rules: impl Iterator<Item = Rule>) -> List {
|
||||
List::from_rules(rules.collect())
|
||||
}
|
||||
|
||||
/// Return the cartesian product of two iterators.
|
||||
pub fn pairs<T: Clone, U>(
|
||||
left: impl Iterator<Item = T>,
|
||||
right: impl Iterator<Item = U> + Clone,
|
||||
) -> impl Iterator<Item = (T, U)> {
|
||||
left.flat_map(move |t| right.clone().map(move |u| (t.clone(), u)))
|
||||
}
|
||||
|
||||
/// Return the cartesian product of three iterators.
|
||||
pub fn triples<T: Clone, U: Clone, V>(
|
||||
left: impl Iterator<Item = T>,
|
||||
middle: impl Iterator<Item = U> + Clone,
|
||||
right: impl Iterator<Item = V> + Clone,
|
||||
) -> impl Iterator<Item = (T, U, V)> {
|
||||
left.flat_map(move |t| {
|
||||
let right = right.clone();
|
||||
middle.clone().flat_map(move |u| {
|
||||
let t = t.clone();
|
||||
right.clone().map(move |v| (t.clone(), u.clone(), v))
|
||||
})
|
||||
})
|
||||
}
|
@ -3,6 +3,7 @@ use alloc::{format, string::String};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::arena::{Arena, Handle, UniqueArena};
|
||||
use crate::common::ForDebugWithTypes;
|
||||
|
||||
/// The result of computing an expression's type.
|
||||
///
|
||||
@ -191,6 +192,14 @@ pub enum ResolveError {
|
||||
FunctionArgumentNotFound(u32),
|
||||
#[error("Special type is not registered within the module")]
|
||||
MissingSpecialType,
|
||||
#[error("Call to builtin {0} has incorrect or ambiguous arguments")]
|
||||
BuiltinArgumentsInvalid(String),
|
||||
}
|
||||
|
||||
impl From<crate::proc::MissingSpecialType> for ResolveError {
|
||||
fn from(_unit_struct: crate::proc::MissingSpecialType) -> Self {
|
||||
ResolveError::MissingSpecialType
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ResolveContext<'a> {
|
||||
@ -635,210 +644,46 @@ impl<'a> ResolveContext<'a> {
|
||||
arg2: _,
|
||||
arg3: _,
|
||||
} => {
|
||||
use crate::MathFunction as Mf;
|
||||
use crate::proc::OverloadSet as _;
|
||||
|
||||
let mut overloads = fun.overloads();
|
||||
log::debug!(
|
||||
"initial overloads for {fun:?}, {:#?}",
|
||||
overloads.for_debug(types)
|
||||
);
|
||||
|
||||
// If any argument is not a constant expression, then no
|
||||
// overloads that accept abstract values should be considered.
|
||||
// `OverloadSet::concrete_only` is supposed to help impose this
|
||||
// restriction. However, no `MathFunction` accepts a mix of
|
||||
// abstract and concrete arguments, so we don't need to worry
|
||||
// about that here.
|
||||
|
||||
let res_arg = past(arg)?;
|
||||
match fun {
|
||||
Mf::Abs
|
||||
| Mf::Min
|
||||
| Mf::Max
|
||||
| Mf::Clamp
|
||||
| Mf::Saturate
|
||||
| Mf::Cos
|
||||
| Mf::Cosh
|
||||
| Mf::Sin
|
||||
| Mf::Sinh
|
||||
| Mf::Tan
|
||||
| Mf::Tanh
|
||||
| Mf::Acos
|
||||
| Mf::Asin
|
||||
| Mf::Atan
|
||||
| Mf::Atan2
|
||||
| Mf::Asinh
|
||||
| Mf::Acosh
|
||||
| Mf::Atanh
|
||||
| Mf::Radians
|
||||
| Mf::Degrees
|
||||
| Mf::Ceil
|
||||
| Mf::Floor
|
||||
| Mf::Round
|
||||
| Mf::Fract
|
||||
| Mf::Trunc
|
||||
| Mf::Ldexp
|
||||
| Mf::Exp
|
||||
| Mf::Exp2
|
||||
| Mf::Log
|
||||
| Mf::Log2
|
||||
| Mf::Pow
|
||||
| Mf::QuantizeToF16 => res_arg.clone(),
|
||||
Mf::Modf | Mf::Frexp => {
|
||||
let (size, scalar) = match res_arg.inner_with(types) {
|
||||
&Ti::Scalar(scalar) => (None, scalar),
|
||||
&Ti::Vector { scalar, size } => (Some(size), scalar),
|
||||
ref other => {
|
||||
return Err(ResolveError::IncompatibleOperands(format!(
|
||||
"{fun:?}({other:?}, _)"
|
||||
)))
|
||||
}
|
||||
};
|
||||
let result = self
|
||||
.special_types
|
||||
.predeclared_types
|
||||
.get(&if fun == Mf::Modf {
|
||||
crate::PredeclaredType::ModfResult { size, scalar }
|
||||
} else {
|
||||
crate::PredeclaredType::FrexpResult { size, scalar }
|
||||
})
|
||||
.ok_or(ResolveError::MissingSpecialType)?;
|
||||
TypeResolution::Handle(*result)
|
||||
}
|
||||
Mf::Dot => match *res_arg.inner_with(types) {
|
||||
Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)),
|
||||
ref other => {
|
||||
return Err(ResolveError::IncompatibleOperands(format!(
|
||||
"{fun:?}({other:?}, _)"
|
||||
)))
|
||||
}
|
||||
},
|
||||
Mf::Outer => {
|
||||
let arg1 = arg1.ok_or_else(|| {
|
||||
ResolveError::IncompatibleOperands(format!("{fun:?}(_, None)"))
|
||||
})?;
|
||||
match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) {
|
||||
(
|
||||
&Ti::Vector {
|
||||
size: columns,
|
||||
scalar,
|
||||
},
|
||||
&Ti::Vector { size: rows, .. },
|
||||
) => TypeResolution::Value(Ti::Matrix {
|
||||
columns,
|
||||
rows,
|
||||
scalar,
|
||||
}),
|
||||
(left, right) => {
|
||||
return Err(ResolveError::IncompatibleOperands(format!(
|
||||
"{fun:?}({left:?}, {right:?})"
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
Mf::Cross => res_arg.clone(),
|
||||
Mf::Distance | Mf::Length => match *res_arg.inner_with(types) {
|
||||
Ti::Scalar(scalar) | Ti::Vector { scalar, size: _ } => {
|
||||
TypeResolution::Value(Ti::Scalar(scalar))
|
||||
}
|
||||
ref other => {
|
||||
return Err(ResolveError::IncompatibleOperands(format!(
|
||||
"{fun:?}({other:?})"
|
||||
)))
|
||||
}
|
||||
},
|
||||
Mf::Normalize | Mf::FaceForward | Mf::Reflect | Mf::Refract => res_arg.clone(),
|
||||
// computational
|
||||
Mf::Sign
|
||||
| Mf::Fma
|
||||
| Mf::Mix
|
||||
| Mf::Step
|
||||
| Mf::SmoothStep
|
||||
| Mf::Sqrt
|
||||
| Mf::InverseSqrt => res_arg.clone(),
|
||||
Mf::Transpose => match *res_arg.inner_with(types) {
|
||||
Ti::Matrix {
|
||||
columns,
|
||||
rows,
|
||||
scalar,
|
||||
} => TypeResolution::Value(Ti::Matrix {
|
||||
columns: rows,
|
||||
rows: columns,
|
||||
scalar,
|
||||
}),
|
||||
ref other => {
|
||||
return Err(ResolveError::IncompatibleOperands(format!(
|
||||
"{fun:?}({other:?})"
|
||||
)))
|
||||
}
|
||||
},
|
||||
Mf::Inverse => match *res_arg.inner_with(types) {
|
||||
Ti::Matrix {
|
||||
columns,
|
||||
rows,
|
||||
scalar,
|
||||
} if columns == rows => TypeResolution::Value(Ti::Matrix {
|
||||
columns,
|
||||
rows,
|
||||
scalar,
|
||||
}),
|
||||
ref other => {
|
||||
return Err(ResolveError::IncompatibleOperands(format!(
|
||||
"{fun:?}({other:?})"
|
||||
)))
|
||||
}
|
||||
},
|
||||
Mf::Determinant => match *res_arg.inner_with(types) {
|
||||
Ti::Matrix { scalar, .. } => TypeResolution::Value(Ti::Scalar(scalar)),
|
||||
ref other => {
|
||||
return Err(ResolveError::IncompatibleOperands(format!(
|
||||
"{fun:?}({other:?})"
|
||||
)))
|
||||
}
|
||||
},
|
||||
// bits
|
||||
Mf::CountTrailingZeros
|
||||
| Mf::CountLeadingZeros
|
||||
| Mf::CountOneBits
|
||||
| Mf::ReverseBits
|
||||
| Mf::ExtractBits
|
||||
| Mf::InsertBits
|
||||
| Mf::FirstTrailingBit
|
||||
| Mf::FirstLeadingBit => match *res_arg.inner_with(types) {
|
||||
Ti::Scalar(
|
||||
scalar @ crate::Scalar {
|
||||
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
|
||||
..
|
||||
},
|
||||
) => TypeResolution::Value(Ti::Scalar(scalar)),
|
||||
Ti::Vector {
|
||||
size,
|
||||
scalar:
|
||||
scalar @ crate::Scalar {
|
||||
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
|
||||
..
|
||||
},
|
||||
} => TypeResolution::Value(Ti::Vector { size, scalar }),
|
||||
ref other => {
|
||||
return Err(ResolveError::IncompatibleOperands(format!(
|
||||
"{fun:?}({other:?})"
|
||||
)))
|
||||
}
|
||||
},
|
||||
// data packing
|
||||
Mf::Pack4x8snorm
|
||||
| Mf::Pack4x8unorm
|
||||
| Mf::Pack2x16snorm
|
||||
| Mf::Pack2x16unorm
|
||||
| Mf::Pack2x16float
|
||||
| Mf::Pack4xI8
|
||||
| Mf::Pack4xU8 => TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)),
|
||||
// data unpacking
|
||||
Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector {
|
||||
size: crate::VectorSize::Quad,
|
||||
scalar: crate::Scalar::F32,
|
||||
}),
|
||||
Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack2x16float => {
|
||||
TypeResolution::Value(Ti::Vector {
|
||||
size: crate::VectorSize::Bi,
|
||||
scalar: crate::Scalar::F32,
|
||||
})
|
||||
}
|
||||
Mf::Unpack4xI8 => TypeResolution::Value(Ti::Vector {
|
||||
size: crate::VectorSize::Quad,
|
||||
scalar: crate::Scalar::I32,
|
||||
}),
|
||||
Mf::Unpack4xU8 => TypeResolution::Value(Ti::Vector {
|
||||
size: crate::VectorSize::Quad,
|
||||
scalar: crate::Scalar::U32,
|
||||
}),
|
||||
overloads = overloads.arg(0, res_arg.inner_with(types), types);
|
||||
log::debug!(
|
||||
"overloads after arg 0 of type {:?}: {:#?}",
|
||||
res_arg.for_debug(types),
|
||||
overloads.for_debug(types)
|
||||
);
|
||||
|
||||
if let Some(arg1) = arg1 {
|
||||
let res_arg1 = past(arg1)?;
|
||||
overloads = overloads.arg(1, res_arg1.inner_with(types), types);
|
||||
log::debug!(
|
||||
"overloads after arg 1 of type {:?}: {:#?}",
|
||||
res_arg1.for_debug(types),
|
||||
overloads.for_debug(types)
|
||||
);
|
||||
}
|
||||
|
||||
if overloads.is_empty() {
|
||||
return Err(ResolveError::BuiltinArgumentsInvalid(format!("{fun:?}")));
|
||||
}
|
||||
|
||||
let rule = overloads.most_preferred();
|
||||
|
||||
rule.conclusion.into_resolution(self.special_types)?
|
||||
}
|
||||
crate::Expression::As {
|
||||
expr,
|
||||
|
@ -2,6 +2,7 @@ use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, T
|
||||
use crate::arena::UniqueArena;
|
||||
use crate::{
|
||||
arena::Handle,
|
||||
proc::OverloadSet as _,
|
||||
proc::{IndexableLengthError, ResolveError},
|
||||
};
|
||||
|
||||
@ -1016,659 +1017,59 @@ impl super::Validator {
|
||||
arg2,
|
||||
arg3,
|
||||
} => {
|
||||
use crate::MathFunction as Mf;
|
||||
let actuals: &[_] = match (arg1, arg2, arg3) {
|
||||
(None, None, None) => &[arg],
|
||||
(Some(arg1), None, None) => &[arg, arg1],
|
||||
(Some(arg1), Some(arg2), None) => &[arg, arg1, arg2],
|
||||
(Some(arg1), Some(arg2), Some(arg3)) => &[arg, arg1, arg2, arg3],
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
|
||||
let resolve = |arg| &resolver[arg];
|
||||
let arg_ty = resolve(arg);
|
||||
let arg1_ty = arg1.map(resolve);
|
||||
let arg2_ty = arg2.map(resolve);
|
||||
let arg3_ty = arg3.map(resolve);
|
||||
match fun {
|
||||
Mf::Abs => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
let good = match *arg_ty {
|
||||
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
|
||||
scalar.kind != Sk::Bool
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
|
||||
}
|
||||
let actual_types: &[_] = match *actuals {
|
||||
[arg0] => &[resolve(arg0)],
|
||||
[arg0, arg1] => &[resolve(arg0), resolve(arg1)],
|
||||
[arg0, arg1, arg2] => &[resolve(arg0), resolve(arg1), resolve(arg2)],
|
||||
[arg0, arg1, arg2, arg3] => {
|
||||
&[resolve(arg0), resolve(arg1), resolve(arg2), resolve(arg3)]
|
||||
}
|
||||
Mf::Min | Mf::Max => {
|
||||
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), None, None) => ty1,
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
let good = match *arg_ty {
|
||||
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
|
||||
scalar.kind != Sk::Bool
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Clamp => {
|
||||
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), Some(ty2), None) => (ty1, ty2),
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
let good = match *arg_ty {
|
||||
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
|
||||
scalar.kind != Sk::Bool
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
if arg2_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
2,
|
||||
arg2.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Saturate
|
||||
| Mf::Cos
|
||||
| Mf::Cosh
|
||||
| Mf::Sin
|
||||
| Mf::Sinh
|
||||
| Mf::Tan
|
||||
| Mf::Tanh
|
||||
| Mf::Acos
|
||||
| Mf::Asin
|
||||
| Mf::Atan
|
||||
| Mf::Asinh
|
||||
| Mf::Acosh
|
||||
| Mf::Atanh
|
||||
| Mf::Radians
|
||||
| Mf::Degrees
|
||||
| Mf::Ceil
|
||||
| Mf::Floor
|
||||
| Mf::Round
|
||||
| Mf::Fract
|
||||
| Mf::Trunc
|
||||
| Mf::Exp
|
||||
| Mf::Exp2
|
||||
| Mf::Log
|
||||
| Mf::Log2
|
||||
| Mf::Length
|
||||
| Mf::Sqrt
|
||||
| Mf::InverseSqrt => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
|
||||
if scalar.kind == Sk::Float => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::Sign => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Scalar(Sc {
|
||||
kind: Sk::Float | Sk::Sint,
|
||||
..
|
||||
})
|
||||
| Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float | Sk::Sint,
|
||||
..
|
||||
},
|
||||
..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
|
||||
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), None, None) => ty1,
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
match *arg_ty {
|
||||
Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
|
||||
if scalar.kind == Sk::Float => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Modf | Mf::Frexp => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
if !matches!(*arg_ty,
|
||||
Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
|
||||
if scalar.kind == Sk::Float)
|
||||
{
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, 1, arg));
|
||||
}
|
||||
}
|
||||
Mf::Ldexp => {
|
||||
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), None, None) => ty1,
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
let size0 = match *arg_ty {
|
||||
Ti::Scalar(Sc {
|
||||
kind: Sk::Float, ..
|
||||
}) => None,
|
||||
Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float, ..
|
||||
},
|
||||
size,
|
||||
} => Some(size),
|
||||
_ => {
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
|
||||
}
|
||||
};
|
||||
let good = match *arg1_ty {
|
||||
Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true,
|
||||
Ti::Vector {
|
||||
size,
|
||||
scalar: Sc { kind: Sk::Sint, .. },
|
||||
} if Some(size) == size0 => true,
|
||||
_ => false,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Dot => {
|
||||
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), None, None) => ty1,
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
match *arg_ty {
|
||||
Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float | Sk::Sint | Sk::Uint,
|
||||
..
|
||||
},
|
||||
..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Outer | Mf::Reflect => {
|
||||
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), None, None) => ty1,
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
match *arg_ty {
|
||||
Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float, ..
|
||||
},
|
||||
..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Cross => {
|
||||
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), None, None) => ty1,
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
match *arg_ty {
|
||||
Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float, ..
|
||||
},
|
||||
size: crate::VectorSize::Tri,
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Refract => {
|
||||
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), Some(ty2), None) => (ty1, ty2),
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
match *arg_ty {
|
||||
Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float, ..
|
||||
},
|
||||
..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
// Start with the set of all overloads available for `fun`.
|
||||
let mut overloads = fun.overloads();
|
||||
log::debug!(
|
||||
"initial overloads for {:?}: {:#?}",
|
||||
fun,
|
||||
overloads.for_debug(&module.types)
|
||||
);
|
||||
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
// If any argument is not a constant expression, then no
|
||||
// overloads that accept abstract values should be considered.
|
||||
// `OverloadSet::concrete_only` is supposed to help impose this
|
||||
// restriction. However, no `MathFunction` accepts a mix of
|
||||
// abstract and concrete arguments, so we don't need to worry
|
||||
// about that here.
|
||||
|
||||
match (arg_ty, arg2_ty) {
|
||||
(
|
||||
&Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
width: vector_width,
|
||||
..
|
||||
},
|
||||
..
|
||||
},
|
||||
&Ti::Scalar(Sc {
|
||||
width: scalar_width,
|
||||
kind: Sk::Float,
|
||||
}),
|
||||
) if vector_width == scalar_width => {}
|
||||
_ => {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
2,
|
||||
arg2.unwrap(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
Mf::Normalize => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float, ..
|
||||
},
|
||||
..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
|
||||
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), Some(ty2), None) => (ty1, ty2),
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
match *arg_ty {
|
||||
Ti::Scalar(Sc {
|
||||
kind: Sk::Float, ..
|
||||
})
|
||||
| Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float, ..
|
||||
},
|
||||
..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
if arg2_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
2,
|
||||
arg2.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Mf::Mix => {
|
||||
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), Some(ty2), None) => (ty1, ty2),
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
let arg_width = match *arg_ty {
|
||||
Ti::Scalar(Sc {
|
||||
kind: Sk::Float,
|
||||
width,
|
||||
})
|
||||
| Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float,
|
||||
width,
|
||||
},
|
||||
..
|
||||
} => width,
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
};
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
// the last argument can always be a scalar
|
||||
match *arg2_ty {
|
||||
Ti::Scalar(Sc {
|
||||
kind: Sk::Float,
|
||||
width,
|
||||
}) if width == arg_width => {}
|
||||
_ if arg2_ty == arg_ty => {}
|
||||
_ => {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
2,
|
||||
arg2.unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Mf::Inverse | Mf::Determinant => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
let good = match *arg_ty {
|
||||
Ti::Matrix { columns, rows, .. } => columns == rows,
|
||||
_ => false,
|
||||
};
|
||||
if !good {
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
|
||||
}
|
||||
}
|
||||
Mf::Transpose => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Matrix { .. } => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::QuantizeToF16 => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Scalar(Sc {
|
||||
kind: Sk::Float,
|
||||
width: 4,
|
||||
})
|
||||
| Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float,
|
||||
width: 4,
|
||||
},
|
||||
..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
// Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
|
||||
Mf::CountLeadingZeros
|
||||
| Mf::CountTrailingZeros
|
||||
| Mf::CountOneBits
|
||||
| Mf::ReverseBits
|
||||
| Mf::FirstLeadingBit
|
||||
| Mf::FirstTrailingBit => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
|
||||
Sk::Sint | Sk::Uint => {
|
||||
if scalar.width != 4 {
|
||||
return Err(ExpressionError::UnsupportedWidth(
|
||||
fun,
|
||||
scalar.kind,
|
||||
scalar.width,
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
},
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::InsertBits => {
|
||||
let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
match *arg_ty {
|
||||
Ti::Scalar(Sc {
|
||||
kind: Sk::Sint | Sk::Uint,
|
||||
..
|
||||
})
|
||||
| Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Sint | Sk::Uint,
|
||||
..
|
||||
},
|
||||
..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
if arg1_ty != arg_ty {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
1,
|
||||
arg1.unwrap(),
|
||||
));
|
||||
}
|
||||
match *arg2_ty {
|
||||
Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
|
||||
_ => {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
2,
|
||||
arg2.unwrap(),
|
||||
))
|
||||
}
|
||||
}
|
||||
match *arg3_ty {
|
||||
Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
|
||||
_ => {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
2,
|
||||
arg3.unwrap(),
|
||||
))
|
||||
}
|
||||
}
|
||||
// Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
|
||||
for &arg in [arg_ty, arg1_ty, arg2_ty, arg3_ty].iter() {
|
||||
match *arg {
|
||||
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
|
||||
if scalar.width != 4 {
|
||||
return Err(ExpressionError::UnsupportedWidth(
|
||||
fun,
|
||||
scalar.kind,
|
||||
scalar.width,
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Mf::ExtractBits => {
|
||||
let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
|
||||
(Some(ty1), Some(ty2), None) => (ty1, ty2),
|
||||
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
|
||||
};
|
||||
match *arg_ty {
|
||||
Ti::Scalar(Sc {
|
||||
kind: Sk::Sint | Sk::Uint,
|
||||
..
|
||||
})
|
||||
| Ti::Vector {
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Sint | Sk::Uint,
|
||||
..
|
||||
},
|
||||
..
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
match *arg1_ty {
|
||||
Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
|
||||
_ => {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
2,
|
||||
arg1.unwrap(),
|
||||
))
|
||||
}
|
||||
}
|
||||
match *arg2_ty {
|
||||
Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
|
||||
_ => {
|
||||
return Err(ExpressionError::InvalidArgumentType(
|
||||
fun,
|
||||
2,
|
||||
arg2.unwrap(),
|
||||
))
|
||||
}
|
||||
}
|
||||
// Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
|
||||
for &arg in [arg_ty, arg1_ty, arg2_ty].iter() {
|
||||
match *arg {
|
||||
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
|
||||
if scalar.width != 4 {
|
||||
return Err(ExpressionError::UnsupportedWidth(
|
||||
fun,
|
||||
scalar.kind,
|
||||
scalar.width,
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Vector {
|
||||
size: crate::VectorSize::Bi,
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float, ..
|
||||
},
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Vector {
|
||||
size: crate::VectorSize::Quad,
|
||||
scalar:
|
||||
Sc {
|
||||
kind: Sk::Float, ..
|
||||
},
|
||||
} => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
mf @ (Mf::Pack4xI8 | Mf::Pack4xU8) => {
|
||||
let scalar_kind = match mf {
|
||||
Mf::Pack4xI8 => Sk::Sint,
|
||||
Mf::Pack4xU8 => Sk::Uint,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Vector {
|
||||
size: crate::VectorSize::Quad,
|
||||
scalar: Sc { kind, .. },
|
||||
} if kind == scalar_kind => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::Unpack2x16float
|
||||
| Mf::Unpack2x16snorm
|
||||
| Mf::Unpack2x16unorm
|
||||
| Mf::Unpack4x8snorm
|
||||
| Mf::Unpack4x8unorm
|
||||
| Mf::Unpack4xI8
|
||||
| Mf::Unpack4xU8 => {
|
||||
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
match *arg_ty {
|
||||
Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() {
|
||||
// Remove overloads that cannot accept an `i`'th
|
||||
// argument arguments of type `ty`.
|
||||
overloads = overloads.arg(i, ty, &module.types);
|
||||
log::debug!(
|
||||
"overloads after arg {i}: {:#?}",
|
||||
overloads.for_debug(&module.types)
|
||||
);
|
||||
|
||||
if overloads.is_empty() {
|
||||
log::debug!("all overloads eliminated");
|
||||
return Err(ExpressionError::InvalidArgumentType(fun, i as u32, expr));
|
||||
}
|
||||
}
|
||||
|
||||
if actuals.len() < overloads.min_arguments() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
|
||||
ShaderStages::all()
|
||||
}
|
||||
E::As {
|
||||
|
1
naga/tests/in/wgsl/abstract-types-builtins.toml
Normal file
1
naga/tests/in/wgsl/abstract-types-builtins.toml
Normal file
@ -0,0 +1 @@
|
||||
targets = "SPIRV | METAL | GLSL | WGSL"
|
102
naga/tests/in/wgsl/abstract-types-builtins.wgsl
Normal file
102
naga/tests/in/wgsl/abstract-types-builtins.wgsl
Normal file
@ -0,0 +1,102 @@
|
||||
fn f() {
|
||||
// For calls that return abstract types, we can't write the type we
|
||||
// actually expect, but we can at least require an automatic
|
||||
// conversion.
|
||||
//
|
||||
// Error cases are covered in `wgsl_errors::more_inconsistent_type`.
|
||||
|
||||
// start
|
||||
var clamp_aiaiai: i32 = clamp(1, 1, 1);
|
||||
var clamp_aiaiaf: f32 = clamp(1, 1, 1.0);
|
||||
var clamp_aiaii: i32 = clamp(1, 1, 1i);
|
||||
var clamp_aiaif: f32 = clamp(1, 1, 1f);
|
||||
var clamp_aiafai: f32 = clamp(1, 1.0, 1);
|
||||
var clamp_aiafaf: f32 = clamp(1, 1.0, 1.0);
|
||||
//var clamp_aiafi: f32 = clamp(1, 1.0, 1i); error
|
||||
var clamp_aiaff: f32 = clamp(1, 1.0, 1f);
|
||||
var clamp_aiiai: i32 = clamp(1, 1i, 1);
|
||||
//var clamp_aiiaf: f32 = clamp(1, 1i, 1.0); error
|
||||
var clamp_aiii: i32 = clamp(1, 1i, 1i);
|
||||
//var clamp_aiif: f32 = clamp(1, 1i, 1f); error
|
||||
var clamp_aifai: f32 = clamp(1, 1f, 1);
|
||||
var clamp_aifaf: f32 = clamp(1, 1f, 1.0);
|
||||
//var clamp_aifi: f32 = clamp(1, 1f, 1i); error
|
||||
var clamp_aiff: f32 = clamp(1, 1f, 1f);
|
||||
var clamp_afaiai: f32 = clamp(1.0, 1, 1);
|
||||
var clamp_afaiaf: f32 = clamp(1.0, 1, 1.0);
|
||||
//var clamp_afaii: f32 = clamp(1.0, 1, 1i); error
|
||||
var clamp_afaif: f32 = clamp(1.0, 1, 1f);
|
||||
var clamp_afafai: f32 = clamp(1.0, 1.0, 1);
|
||||
var clamp_afafaf: f32 = clamp(1.0, 1.0, 1.0);
|
||||
//var clamp_afafi: f32 = clamp(1.0, 1.0, 1i); error
|
||||
var clamp_afaff: f32 = clamp(1.0, 1.0, 1f);
|
||||
//var clamp_afiai: f32 = clamp(1.0, 1i, 1); error
|
||||
//var clamp_afiaf: f32 = clamp(1.0, 1i, 1.0); error
|
||||
//var clamp_afii: f32 = clamp(1.0, 1i, 1i); error
|
||||
//var clamp_afif: f32 = clamp(1.0, 1i, 1f); error
|
||||
var clamp_affai: f32 = clamp(1.0, 1f, 1);
|
||||
var clamp_affaf: f32 = clamp(1.0, 1f, 1.0);
|
||||
//var clamp_affi: f32 = clamp(1.0, 1f, 1i); error
|
||||
var clamp_afff: f32 = clamp(1.0, 1f, 1f);
|
||||
var clamp_iaiai: i32 = clamp(1i, 1, 1);
|
||||
//var clamp_iaiaf: f32 = clamp(1i, 1, 1.0); error
|
||||
var clamp_iaii: i32 = clamp(1i, 1, 1i);
|
||||
//var clamp_iaif: f32 = clamp(1i, 1, 1f); error
|
||||
//var clamp_iafai: f32 = clamp(1i, 1.0, 1); error
|
||||
//var clamp_iafaf: f32 = clamp(1i, 1.0, 1.0); error
|
||||
//var clamp_iafi: f32 = clamp(1i, 1.0, 1i); error
|
||||
//var clamp_iaff: f32 = clamp(1i, 1.0, 1f); error
|
||||
var clamp_iiai: i32 = clamp(1i, 1i, 1);
|
||||
//var clamp_iiaf: f32 = clamp(1i, 1i, 1.0); error
|
||||
var clamp_iii: i32 = clamp(1i, 1i, 1i);
|
||||
//var clamp_iif: f32 = clamp(1i, 1i, 1f); error
|
||||
//var clamp_ifai: f32 = clamp(1i, 1f, 1); error
|
||||
//var clamp_ifaf: f32 = clamp(1i, 1f, 1.0); error
|
||||
//var clamp_ifi: f32 = clamp(1i, 1f, 1i); error
|
||||
//var clamp_iff: f32 = clamp(1i, 1f, 1f); error
|
||||
var clamp_faiai: f32 = clamp(1f, 1, 1);
|
||||
var clamp_faiaf: f32 = clamp(1f, 1, 1.0);
|
||||
//var clamp_faii: f32 = clamp(1f, 1, 1i); error
|
||||
var clamp_faif: f32 = clamp(1f, 1, 1f);
|
||||
var clamp_fafai: f32 = clamp(1f, 1.0, 1);
|
||||
var clamp_fafaf: f32 = clamp(1f, 1.0, 1.0);
|
||||
//var clamp_fafi: f32 = clamp(1f, 1.0, 1i); error
|
||||
var clamp_faff: f32 = clamp(1f, 1.0, 1f);
|
||||
//var clamp_fiai: f32 = clamp(1f, 1i, 1); error
|
||||
//var clamp_fiaf: f32 = clamp(1f, 1i, 1.0); error
|
||||
//var clamp_fii: f32 = clamp(1f, 1i, 1i); error
|
||||
//var clamp_fif: f32 = clamp(1f, 1i, 1f); error
|
||||
var clamp_ffai: f32 = clamp(1f, 1f, 1);
|
||||
var clamp_ffaf: f32 = clamp(1f, 1f, 1.0);
|
||||
//var clamp_ffi: f32 = clamp(1f, 1f, 1i); error
|
||||
var clamp_fff: f32 = clamp(1f, 1f, 1f);
|
||||
// end
|
||||
|
||||
|
||||
var min_aiai: i32 = min(1, 1);
|
||||
var min_aiaf: f32 = min(1, 1.0);
|
||||
var min_aii: i32 = min(1, 1i);
|
||||
var min_aif: f32 = min(1, 1f);
|
||||
var min_afai: f32 = min(1.0, 1);
|
||||
var min_afaf: f32 = min(1.0, 1.0);
|
||||
//var min_afi: f32 = min(1.0, 1i); error
|
||||
var min_aff: f32 = min(1.0, 1f);
|
||||
var min_iai: i32 = min(1i, 1);
|
||||
//var min_iaf: f32 = min(1i, 1.0); error
|
||||
var min_ii: i32 = min(1i, 1i);
|
||||
//var min_if: f32 = min(1i, 1f); error
|
||||
var min_fai: f32 = min(1f, 1);
|
||||
var min_faf: f32 = min(1f, 1.0);
|
||||
//var min_fi: f32 = min(1f, 1i); error
|
||||
var min_ff: f32 = min(1f, 1f);
|
||||
|
||||
var pow_aiai = pow(1, 1);
|
||||
var pow_aiaf = pow(1, 1.0);
|
||||
var pow_aif = pow(1, 1f);
|
||||
var pow_afai = pow(1.0, 1);
|
||||
var pow_afaf = pow(1.0, 1.0);
|
||||
var pow_aff = pow(1.0, 1f);
|
||||
var pow_fai = pow(1f, 1);
|
||||
var pow_faf = pow(1f, 1.0);
|
||||
var pow_ff = pow(1f, 1f);
|
||||
}
|
@ -266,13 +266,15 @@ fn builtin_cross_product_args() {
|
||||
use naga::{MathFunction, Module, Type, TypeInner, VectorSize};
|
||||
|
||||
// We want to ensure that the *only* problem with the code is the
|
||||
// arity of the vectors passed to `cross`. So validate two
|
||||
// versions of the module varying only in that aspect.
|
||||
// arity of the call, or the size of the vectors passed to
|
||||
// `cross`. So validate different versions of the module varying
|
||||
// only in those aspects.
|
||||
//
|
||||
// Looking at uses of the `wg_load` makes it easy to identify the
|
||||
// differences between the two variants.
|
||||
// Looking at uses of `size` and `arity` makes it easy to identify
|
||||
// the differences between the variants.
|
||||
fn variant(
|
||||
size: VectorSize,
|
||||
arity: usize,
|
||||
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
|
||||
let span = naga::Span::default();
|
||||
let mut module = Module::default();
|
||||
@ -311,9 +313,9 @@ fn builtin_cross_product_args() {
|
||||
Expression::Math {
|
||||
fun: MathFunction::Cross,
|
||||
arg: ex_zero,
|
||||
arg1: Some(ex_zero),
|
||||
arg2: None,
|
||||
arg3: None,
|
||||
arg1: (arity >= 2).then_some(ex_zero),
|
||||
arg2: (arity >= 3).then_some(ex_zero),
|
||||
arg3: (arity >= 4).then_some(ex_zero),
|
||||
},
|
||||
span,
|
||||
);
|
||||
@ -338,9 +340,14 @@ fn builtin_cross_product_args() {
|
||||
.validate(&module)
|
||||
}
|
||||
|
||||
assert!(variant(VectorSize::Bi).is_err());
|
||||
variant(VectorSize::Tri).expect("module should validate");
|
||||
assert!(variant(VectorSize::Quad).is_err());
|
||||
assert!(variant(VectorSize::Bi, 2).is_err());
|
||||
|
||||
assert!(variant(VectorSize::Tri, 1).is_err());
|
||||
variant(VectorSize::Tri, 2).expect("module should validate");
|
||||
assert!(variant(VectorSize::Tri, 3).is_err());
|
||||
assert!(variant(VectorSize::Tri, 4).is_err());
|
||||
|
||||
assert!(variant(VectorSize::Quad, 2).is_err());
|
||||
}
|
||||
|
||||
#[cfg(feature = "wgsl-in")]
|
||||
@ -758,3 +765,61 @@ fn bad_texture_dimensions_level() {
|
||||
assert!(validate("1i").is_ok());
|
||||
assert!(validate("1").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arity_check() {
|
||||
use ir::MathFunction as Mf;
|
||||
use naga::Span;
|
||||
use naga::{ir, valid};
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
|
||||
type Result = core::result::Result<naga::valid::ModuleInfo, naga::valid::ValidationError>;
|
||||
|
||||
fn validate(fun: ir::MathFunction, args: &[usize]) -> Result {
|
||||
let nowhere = Span::default();
|
||||
let mut module = ir::Module::default();
|
||||
let ty_f32 = module.types.insert(
|
||||
ir::Type {
|
||||
name: Some("f32".to_string()),
|
||||
inner: ir::TypeInner::Scalar(ir::Scalar::F32),
|
||||
},
|
||||
nowhere,
|
||||
);
|
||||
let mut f = ir::Function {
|
||||
result: Some(ir::FunctionResult {
|
||||
ty: ty_f32,
|
||||
binding: None,
|
||||
}),
|
||||
..ir::Function::default()
|
||||
};
|
||||
let ex_zero = f
|
||||
.expressions
|
||||
.append(ir::Expression::ZeroValue(ty_f32), nowhere);
|
||||
let ex_pow = f.expressions.append(
|
||||
dbg!(ir::Expression::Math {
|
||||
fun,
|
||||
arg: ex_zero,
|
||||
arg1: args.contains(&1).then_some(ex_zero),
|
||||
arg2: args.contains(&2).then_some(ex_zero),
|
||||
arg3: args.contains(&3).then_some(ex_zero),
|
||||
}),
|
||||
nowhere,
|
||||
);
|
||||
f.body = ir::Block::from_vec(vec![
|
||||
ir::Statement::Emit(naga::Range::new_from_bounds(ex_pow, ex_pow)),
|
||||
ir::Statement::Return {
|
||||
value: Some(ex_pow),
|
||||
},
|
||||
]);
|
||||
module.functions.append(f, nowhere);
|
||||
valid::Validator::new(Default::default(), valid::Capabilities::all())
|
||||
.validate(&module)
|
||||
.map_err(|err| err.into_inner()) // discard spans
|
||||
}
|
||||
|
||||
assert!(validate(Mf::Sin, &[]).is_ok());
|
||||
assert!(validate(Mf::Sin, &[1]).is_err());
|
||||
assert!(validate(Mf::Sin, &[3]).is_err());
|
||||
assert!(validate(Mf::Pow, &[1]).is_ok());
|
||||
assert!(validate(Mf::Pow, &[3]).is_err());
|
||||
}
|
||||
|
@ -179,6 +179,56 @@ fn bad_type_cast() {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cross_vec2() {
|
||||
check(
|
||||
r#"
|
||||
fn x() -> f32 {
|
||||
return cross(vec2(0., 1.), vec2(0., 1.));
|
||||
}
|
||||
"#,
|
||||
"\
|
||||
error: wrong type passed as argument #1 to `cross`
|
||||
┌─ wgsl:3:24
|
||||
│
|
||||
3 │ return cross(vec2(0., 1.), vec2(0., 1.));
|
||||
│ ^^^^^ ^^^^^^^^^^^^ argument #1 has type `vec2<{AbstractFloat}>`
|
||||
│
|
||||
= note: `cross` accepts the following types for argument #1:
|
||||
= note: allowed type: vec3<{AbstractFloat}>
|
||||
= note: allowed type: vec3<f32>
|
||||
= note: allowed type: vec3<f16>
|
||||
= note: allowed type: vec3<f64>
|
||||
|
||||
",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cross_vec4() {
|
||||
check(
|
||||
r#"
|
||||
fn x() -> f32 {
|
||||
return cross(vec4(0., 1., 2., 3.), vec4(0., 1., 2., 3.));
|
||||
}
|
||||
"#,
|
||||
"\
|
||||
error: wrong type passed as argument #1 to `cross`
|
||||
┌─ wgsl:3:24
|
||||
│
|
||||
3 │ return cross(vec4(0., 1., 2., 3.), vec4(0., 1., 2., 3.));
|
||||
│ ^^^^^ ^^^^^^^^^^^^^^^^^^^^ argument #1 has type `vec4<{AbstractFloat}>`
|
||||
│
|
||||
= note: `cross` accepts the following types for argument #1:
|
||||
= note: allowed type: vec3<{AbstractFloat}>
|
||||
= note: allowed type: vec3<f32>
|
||||
= note: allowed type: vec3<f16>
|
||||
= note: allowed type: vec3<f64>
|
||||
|
||||
",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn type_not_constructible() {
|
||||
check(
|
||||
@ -3136,3 +3186,141 @@ fn issue7165() {
|
||||
// rendering an error if the module contained spans.
|
||||
let _location = err.location("");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_argument_count() {
|
||||
check(
|
||||
"fn foo() -> f32 {
|
||||
return sin();
|
||||
}",
|
||||
r#"error: wrong number of arguments: expected 1, found 0
|
||||
┌─ wgsl:2:20
|
||||
│
|
||||
2 │ return sin();
|
||||
│ ^^^ wrong number of arguments
|
||||
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn too_many_arguments() {
|
||||
check(
|
||||
"fn foo() -> f32 {
|
||||
return sin(1.0, 2.0);
|
||||
}",
|
||||
r#"error: too many arguments passed to `sin`
|
||||
┌─ wgsl:2:20
|
||||
│
|
||||
2 │ return sin(1.0, 2.0);
|
||||
│ ^^^ ^^^ unexpected argument #2
|
||||
│
|
||||
= note: The `sin` function accepts at most 1 argument(s)
|
||||
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn too_many_arguments_2() {
|
||||
check(
|
||||
"fn foo() -> f32 {
|
||||
return distance(vec2<f32>(), 0i);
|
||||
}",
|
||||
r#"error: wrong type passed as argument #2 to `distance`
|
||||
┌─ wgsl:2:20
|
||||
│
|
||||
2 │ return distance(vec2<f32>(), 0i);
|
||||
│ ^^^^^^^^ ^^ argument #2 has type `i32`
|
||||
│
|
||||
= note: `distance` accepts the following types for argument #2:
|
||||
= note: allowed type: vec2<f32>
|
||||
= note: allowed type: vec2<f16>
|
||||
= note: allowed type: vec2<f64>
|
||||
= note: allowed type: vec3<f32>
|
||||
= note: allowed type: vec3<f16>
|
||||
= note: allowed type: vec3<f64>
|
||||
= note: allowed type: vec4<f32>
|
||||
= note: allowed type: vec4<f16>
|
||||
= note: allowed type: vec4<f64>
|
||||
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inconsistent_type() {
|
||||
check(
|
||||
"fn foo() -> f32 {
|
||||
return dot(vec4<f32>(), vec3<f32>());
|
||||
}",
|
||||
r#"error: inconsistent type passed as argument #2 to `dot`
|
||||
┌─ wgsl:2:20
|
||||
│
|
||||
2 │ return dot(vec4<f32>(), vec3<f32>());
|
||||
│ ^^^ ^^^^^^^^^^ ^^^^^^^^^^ argument #2 has type vec3<f32>
|
||||
│ │
|
||||
│ this argument has type vec4<f32>, which constrains subsequent arguments
|
||||
│
|
||||
= note: Because argument #1 has type vec4<f32>, only the following types
|
||||
= note: (or types that automatically convert to them) are accepted for argument #2:
|
||||
= note: allowed type: vec4<f32>
|
||||
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn more_inconsistent_type() {
|
||||
#[track_caller]
|
||||
fn variant(call: &str) {
|
||||
let input = format!(
|
||||
r#"
|
||||
fn f() {{ var x = {call}; }}
|
||||
"#
|
||||
);
|
||||
let result = naga::front::wgsl::parse_str(&input);
|
||||
let Err(ref err) = result else {
|
||||
panic!("expected ParseError, got {result:#?}");
|
||||
};
|
||||
if !err.message().contains("inconsistent type") {
|
||||
panic!("expected 'inconsistent type' error, got {result:#?}");
|
||||
}
|
||||
}
|
||||
|
||||
variant("min(1.0, 1i)");
|
||||
variant("min(1i, 1.0)");
|
||||
variant("min(1i, 1f)");
|
||||
variant("min(1f, 1i)");
|
||||
|
||||
variant("clamp(1, 1.0, 1i)");
|
||||
variant("clamp(1, 1i, 1.0)");
|
||||
variant("clamp(1, 1i, 1f)");
|
||||
variant("clamp(1, 1f, 1i)");
|
||||
variant("clamp(1.0, 1, 1i)");
|
||||
variant("clamp(1.0, 1.0, 1i)");
|
||||
variant("clamp(1.0, 1i, 1)");
|
||||
variant("clamp(1.0, 1i, 1.0)");
|
||||
variant("clamp(1.0, 1i, 1i)");
|
||||
variant("clamp(1.0, 1i, 1f)");
|
||||
variant("clamp(1.0, 1f, 1i)");
|
||||
variant("clamp(1i, 1, 1.0)");
|
||||
variant("clamp(1i, 1, 1f)");
|
||||
variant("clamp(1i, 1.0, 1)");
|
||||
variant("clamp(1i, 1.0, 1.0)");
|
||||
variant("clamp(1i, 1.0, 1i)");
|
||||
variant("clamp(1i, 1.0, 1f)");
|
||||
variant("clamp(1i, 1i, 1.0)");
|
||||
variant("clamp(1i, 1i, 1f)");
|
||||
variant("clamp(1i, 1f, 1)");
|
||||
variant("clamp(1i, 1f, 1.0)");
|
||||
variant("clamp(1i, 1f, 1i)");
|
||||
variant("clamp(1i, 1f, 1f)");
|
||||
variant("clamp(1f, 1, 1i)");
|
||||
variant("clamp(1f, 1.0, 1i)");
|
||||
variant("clamp(1f, 1i, 1)");
|
||||
variant("clamp(1f, 1i, 1.0)");
|
||||
variant("clamp(1f, 1i, 1i)");
|
||||
variant("clamp(1f, 1i, 1f)");
|
||||
variant("clamp(1f, 1f, 1i)");
|
||||
}
|
||||
|
@ -868,7 +868,13 @@
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Handle(1),
|
||||
ty: Value(Vector(
|
||||
size: Tri,
|
||||
scalar: (
|
||||
kind: Float,
|
||||
width: 4,
|
||||
),
|
||||
)),
|
||||
),
|
||||
(
|
||||
uniformity: (
|
||||
@ -1231,7 +1237,13 @@
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Handle(1),
|
||||
ty: Value(Vector(
|
||||
size: Tri,
|
||||
scalar: (
|
||||
kind: Float,
|
||||
width: 4,
|
||||
),
|
||||
)),
|
||||
),
|
||||
(
|
||||
uniformity: (
|
||||
@ -1252,7 +1264,10 @@
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Handle(0),
|
||||
ty: Value(Scalar((
|
||||
kind: Float,
|
||||
width: 4,
|
||||
))),
|
||||
),
|
||||
(
|
||||
uniformity: (
|
||||
@ -1261,7 +1276,10 @@
|
||||
),
|
||||
ref_count: 1,
|
||||
assignable_global: None,
|
||||
ty: Handle(0),
|
||||
ty: Value(Scalar((
|
||||
kind: Float,
|
||||
width: 4,
|
||||
))),
|
||||
),
|
||||
(
|
||||
uniformity: (
|
||||
|
66
naga/tests/out/msl/abstract-types-builtins.msl
Normal file
66
naga/tests/out/msl/abstract-types-builtins.msl
Normal file
@ -0,0 +1,66 @@
|
||||
// language: metal1.0
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using metal::uint;
|
||||
|
||||
|
||||
void f(
|
||||
) {
|
||||
int clamp_aiaiai = 1;
|
||||
float clamp_aiaiaf = 1.0;
|
||||
int clamp_aiaii = 1;
|
||||
float clamp_aiaif = 1.0;
|
||||
float clamp_aiafai = 1.0;
|
||||
float clamp_aiafaf = 1.0;
|
||||
float clamp_aiaff = 1.0;
|
||||
int clamp_aiiai = 1;
|
||||
int clamp_aiii = 1;
|
||||
float clamp_aifai = 1.0;
|
||||
float clamp_aifaf = 1.0;
|
||||
float clamp_aiff = 1.0;
|
||||
float clamp_afaiai = 1.0;
|
||||
float clamp_afaiaf = 1.0;
|
||||
float clamp_afaif = 1.0;
|
||||
float clamp_afafai = 1.0;
|
||||
float clamp_afafaf = 1.0;
|
||||
float clamp_afaff = 1.0;
|
||||
float clamp_affai = 1.0;
|
||||
float clamp_affaf = 1.0;
|
||||
float clamp_afff = 1.0;
|
||||
int clamp_iaiai = 1;
|
||||
int clamp_iaii = 1;
|
||||
int clamp_iiai = 1;
|
||||
int clamp_iii = 1;
|
||||
float clamp_faiai = 1.0;
|
||||
float clamp_faiaf = 1.0;
|
||||
float clamp_faif = 1.0;
|
||||
float clamp_fafai = 1.0;
|
||||
float clamp_fafaf = 1.0;
|
||||
float clamp_faff = 1.0;
|
||||
float clamp_ffai = 1.0;
|
||||
float clamp_ffaf = 1.0;
|
||||
float clamp_fff = 1.0;
|
||||
int min_aiai = 1;
|
||||
float min_aiaf = 1.0;
|
||||
int min_aii = 1;
|
||||
float min_aif = 1.0;
|
||||
float min_afai = 1.0;
|
||||
float min_afaf = 1.0;
|
||||
float min_aff = 1.0;
|
||||
int min_iai = 1;
|
||||
int min_ii = 1;
|
||||
float min_fai = 1.0;
|
||||
float min_faf = 1.0;
|
||||
float min_ff = 1.0;
|
||||
float pow_aiai = 1.0;
|
||||
float pow_aiaf = 1.0;
|
||||
float pow_aif = 1.0;
|
||||
float pow_afai = 1.0;
|
||||
float pow_afaf = 1.0;
|
||||
float pow_aff = 1.0;
|
||||
float pow_fai = 1.0;
|
||||
float pow_faf = 1.0;
|
||||
float pow_ff = 1.0;
|
||||
return;
|
||||
}
|
77
naga/tests/out/spv/abstract-types-builtins.spvasm
Normal file
77
naga/tests/out/spv/abstract-types-builtins.spvasm
Normal file
@ -0,0 +1,77 @@
|
||||
; SPIR-V
|
||||
; Version: 1.1
|
||||
; Generator: rspirv
|
||||
; Bound: 68
|
||||
OpCapability Shader
|
||||
OpCapability Linkage
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeInt 32 1
|
||||
%4 = OpTypeFloat 32
|
||||
%7 = OpTypeFunction %2
|
||||
%8 = OpConstant %3 1
|
||||
%9 = OpConstant %4 1.0
|
||||
%11 = OpTypePointer Function %3
|
||||
%13 = OpTypePointer Function %4
|
||||
%6 = OpFunction %2 None %7
|
||||
%5 = OpLabel
|
||||
%66 = OpVariable %13 Function %9
|
||||
%63 = OpVariable %13 Function %9
|
||||
%60 = OpVariable %13 Function %9
|
||||
%57 = OpVariable %13 Function %9
|
||||
%54 = OpVariable %11 Function %8
|
||||
%51 = OpVariable %13 Function %9
|
||||
%48 = OpVariable %11 Function %8
|
||||
%45 = OpVariable %13 Function %9
|
||||
%42 = OpVariable %13 Function %9
|
||||
%39 = OpVariable %13 Function %9
|
||||
%36 = OpVariable %11 Function %8
|
||||
%33 = OpVariable %11 Function %8
|
||||
%30 = OpVariable %13 Function %9
|
||||
%27 = OpVariable %13 Function %9
|
||||
%24 = OpVariable %13 Function %9
|
||||
%21 = OpVariable %13 Function %9
|
||||
%18 = OpVariable %13 Function %9
|
||||
%15 = OpVariable %13 Function %9
|
||||
%10 = OpVariable %11 Function %8
|
||||
%64 = OpVariable %13 Function %9
|
||||
%61 = OpVariable %13 Function %9
|
||||
%58 = OpVariable %13 Function %9
|
||||
%55 = OpVariable %13 Function %9
|
||||
%52 = OpVariable %13 Function %9
|
||||
%49 = OpVariable %13 Function %9
|
||||
%46 = OpVariable %11 Function %8
|
||||
%43 = OpVariable %13 Function %9
|
||||
%40 = OpVariable %13 Function %9
|
||||
%37 = OpVariable %13 Function %9
|
||||
%34 = OpVariable %11 Function %8
|
||||
%31 = OpVariable %13 Function %9
|
||||
%28 = OpVariable %13 Function %9
|
||||
%25 = OpVariable %13 Function %9
|
||||
%22 = OpVariable %13 Function %9
|
||||
%19 = OpVariable %11 Function %8
|
||||
%16 = OpVariable %13 Function %9
|
||||
%12 = OpVariable %13 Function %9
|
||||
%65 = OpVariable %13 Function %9
|
||||
%62 = OpVariable %13 Function %9
|
||||
%59 = OpVariable %13 Function %9
|
||||
%56 = OpVariable %13 Function %9
|
||||
%53 = OpVariable %11 Function %8
|
||||
%50 = OpVariable %13 Function %9
|
||||
%47 = OpVariable %13 Function %9
|
||||
%44 = OpVariable %13 Function %9
|
||||
%41 = OpVariable %13 Function %9
|
||||
%38 = OpVariable %13 Function %9
|
||||
%35 = OpVariable %11 Function %8
|
||||
%32 = OpVariable %13 Function %9
|
||||
%29 = OpVariable %13 Function %9
|
||||
%26 = OpVariable %13 Function %9
|
||||
%23 = OpVariable %13 Function %9
|
||||
%20 = OpVariable %11 Function %8
|
||||
%17 = OpVariable %13 Function %9
|
||||
%14 = OpVariable %11 Function %8
|
||||
OpBranch %67
|
||||
%67 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
@ -492,8 +492,9 @@ OpLine %3 93 25
|
||||
%200 = OpVectorShuffle %11 %199 %199 0 1 2
|
||||
%201 = OpFSub %11 %198 %200
|
||||
%202 = OpExtInst %11 %1 Normalize %201
|
||||
OpLine %3 94 23
|
||||
OpLine %3 94 32
|
||||
%203 = OpDot %4 %155 %202
|
||||
OpLine %3 94 23
|
||||
%204 = OpExtInst %4 %1 FMax %48 %203
|
||||
OpLine %3 96 9
|
||||
%205 = OpFMul %4 %196 %204
|
||||
@ -599,8 +600,9 @@ OpLine %3 112 25
|
||||
%276 = OpVectorShuffle %11 %275 %275 0 1 2
|
||||
%277 = OpFSub %11 %274 %276
|
||||
%278 = OpExtInst %11 %1 Normalize %277
|
||||
OpLine %3 113 23
|
||||
OpLine %3 113 32
|
||||
%279 = OpDot %4 %239 %278
|
||||
OpLine %3 113 23
|
||||
%280 = OpExtInst %4 %1 FMax %48 %279
|
||||
OpLine %3 114 9
|
||||
%281 = OpFMul %4 %272 %280
|
||||
|
60
naga/tests/out/wgsl/abstract-types-builtins.wgsl
Normal file
60
naga/tests/out/wgsl/abstract-types-builtins.wgsl
Normal file
@ -0,0 +1,60 @@
|
||||
fn f() {
|
||||
var clamp_aiaiai: i32 = 1i;
|
||||
var clamp_aiaiaf: f32 = 1f;
|
||||
var clamp_aiaii: i32 = 1i;
|
||||
var clamp_aiaif: f32 = 1f;
|
||||
var clamp_aiafai: f32 = 1f;
|
||||
var clamp_aiafaf: f32 = 1f;
|
||||
var clamp_aiaff: f32 = 1f;
|
||||
var clamp_aiiai: i32 = 1i;
|
||||
var clamp_aiii: i32 = 1i;
|
||||
var clamp_aifai: f32 = 1f;
|
||||
var clamp_aifaf: f32 = 1f;
|
||||
var clamp_aiff: f32 = 1f;
|
||||
var clamp_afaiai: f32 = 1f;
|
||||
var clamp_afaiaf: f32 = 1f;
|
||||
var clamp_afaif: f32 = 1f;
|
||||
var clamp_afafai: f32 = 1f;
|
||||
var clamp_afafaf: f32 = 1f;
|
||||
var clamp_afaff: f32 = 1f;
|
||||
var clamp_affai: f32 = 1f;
|
||||
var clamp_affaf: f32 = 1f;
|
||||
var clamp_afff: f32 = 1f;
|
||||
var clamp_iaiai: i32 = 1i;
|
||||
var clamp_iaii: i32 = 1i;
|
||||
var clamp_iiai: i32 = 1i;
|
||||
var clamp_iii: i32 = 1i;
|
||||
var clamp_faiai: f32 = 1f;
|
||||
var clamp_faiaf: f32 = 1f;
|
||||
var clamp_faif: f32 = 1f;
|
||||
var clamp_fafai: f32 = 1f;
|
||||
var clamp_fafaf: f32 = 1f;
|
||||
var clamp_faff: f32 = 1f;
|
||||
var clamp_ffai: f32 = 1f;
|
||||
var clamp_ffaf: f32 = 1f;
|
||||
var clamp_fff: f32 = 1f;
|
||||
var min_aiai: i32 = 1i;
|
||||
var min_aiaf: f32 = 1f;
|
||||
var min_aii: i32 = 1i;
|
||||
var min_aif: f32 = 1f;
|
||||
var min_afai: f32 = 1f;
|
||||
var min_afaf: f32 = 1f;
|
||||
var min_aff: f32 = 1f;
|
||||
var min_iai: i32 = 1i;
|
||||
var min_ii: i32 = 1i;
|
||||
var min_fai: f32 = 1f;
|
||||
var min_faf: f32 = 1f;
|
||||
var min_ff: f32 = 1f;
|
||||
var pow_aiai: f32 = 1f;
|
||||
var pow_aiaf: f32 = 1f;
|
||||
var pow_aif: f32 = 1f;
|
||||
var pow_afai: f32 = 1f;
|
||||
var pow_afaf: f32 = 1f;
|
||||
var pow_aff: f32 = 1f;
|
||||
var pow_fai: f32 = 1f;
|
||||
var pow_faf: f32 = 1f;
|
||||
var pow_ff: f32 = 1f;
|
||||
|
||||
return;
|
||||
}
|
||||
|
@ -6,6 +6,8 @@ extend-exclude = [
|
||||
# spirv-asm isn't real source code
|
||||
'*.spvasm',
|
||||
'docs/big-picture.xml',
|
||||
# This test has weird pattern-derived variable names.
|
||||
'naga/tests/in/wgsl/abstract-types-builtins.wgsl',
|
||||
]
|
||||
|
||||
# Corrections take the form of a key/value pair. The key is the incorrect word
|
||||
|
Loading…
Reference in New Issue
Block a user