refactor(const_eval): add component_wise_float helper, reimpl. math_pow

This commit is contained in:
Erich Gubler 2023-12-14 10:37:43 -05:00
parent c2058487ca
commit 5c900f2568
4 changed files with 207 additions and 58 deletions

1
Cargo.lock generated
View File

@ -2055,6 +2055,7 @@ name = "naga"
version = "0.19.0"
dependencies = [
"arbitrary",
"arrayvec 0.7.4",
"bincode",
"bit-set",
"bitflags 2.4.1",

View File

@ -60,6 +60,7 @@ petgraph = { version = "0.6", optional = true }
pp-rs = { version = "0.2.1", optional = true }
hexf-parse = { version = "0.2.1", optional = true }
unicode-xid = { version = "0.2.3", optional = true }
arrayvec.workspace = true
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
criterion = { version = "0.5", features = [] }

View File

@ -458,6 +458,10 @@ pub enum VectorSize {
Quad = 4,
}
impl VectorSize {
const MAX: usize = Self::Quad as u8 as usize;
}
/// Primitive type for a scalar.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]

View File

@ -1,9 +1,209 @@
use std::iter;
use arrayvec::ArrayVec;
use crate::{
arena::{Arena, Handle, UniqueArena},
ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner,
UnaryOperator,
};
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items.
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
($($body:tt)*) => {
macro_rules! __with_dollar_sign { $($body)* }
__with_dollar_sign!($);
}
}
macro_rules! gen_component_wise_extractor {
(
$ident:ident -> $target:ident,
literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
) => {
/// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
enum $target<const N: usize> {
$(
#[doc = concat!(
"Maps to [`Literal::",
stringify!($mapping),
"`]",
)]
$mapping([$ty; N]),
)+
}
impl From<$target<1>> for Expression {
fn from(value: $target<1>) -> Self {
match value {
$(
$target::$mapping([value]) => {
Expression::Literal(Literal::$literal(value))
}
)+
}
}
}
#[doc = concat!(
"Attempts to evaluate multiple `exprs` as a combined [`",
stringify!($target),
"`] to pass to `handler`. ",
)]
/// If `exprs` are vectors of the same length, `handler` is called for each corresponding
/// component of each vector.
///
/// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
/// same length, a new vector expression is registered, composed of each component emitted
/// by `handler`.
fn $ident<const N: usize, const M: usize, F>(
eval: &mut ConstantEvaluator<'_>,
span: Span,
exprs: [Handle<Expression>; N],
mut handler: F,
) -> Result<Handle<Expression>, ConstantEvaluatorError>
where
$target<M>: Into<Expression>,
F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
{
assert!(N > 0);
let err = ConstantEvaluatorError::InvalidMathArg;
let mut exprs = exprs.into_iter();
macro_rules! sanitize {
($expr:expr) => {
eval.eval_zero_value_and_splat($expr, span)
.map(|expr| &eval.expressions[expr])
};
}
let new_expr = match sanitize!(exprs.next().unwrap())? {
$(
&Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
.chain(exprs.map(|expr| {
sanitize!(expr).and_then(|expr| match expr {
&Expression::Literal(Literal::$literal(x)) => Ok(x),
_ => Err(err.clone()),
})
}))
.collect::<Result<ArrayVec<_, N>, _>>()
.map(|a| a.into_inner().unwrap())
.map($target::$mapping)
.and_then(|comps| Ok(handler(comps)?.into())),
)+
&Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
&TypeInner::Vector { size: _, scalar } => match scalar.kind {
$(ScalarKind::$scalar_kind)|* => {
let first_ty = ty;
let mut component_groups =
ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
component_groups.push(crate::proc::flatten_compose(
first_ty,
components,
eval.expressions,
eval.types,
).collect());
component_groups.extend(
exprs
.map(|expr| {
sanitize!(expr).and_then(|expr| match expr {
&Expression::Compose { ty, ref components }
if &eval.types[ty].inner
== &eval.types[first_ty].inner =>
{
Ok(crate::proc::flatten_compose(
ty,
components,
eval.expressions,
eval.types,
).collect())
}
_ => Err(err.clone()),
})
})
.collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
)?,
);
let component_groups = component_groups.into_inner().unwrap();
let mut new_components =
ArrayVec::<_, { crate::VectorSize::MAX }>::new();
for idx in 0..N {
let group = component_groups
.iter()
.map(|cs| cs[idx])
.collect::<ArrayVec<_, N>>()
.into_inner()
.unwrap();
new_components.push($ident(
eval,
span,
group,
handler.clone(),
)?);
}
Ok(Expression::Compose {
ty: first_ty,
components: new_components.into_iter().collect(),
})
}
_ => return Err(err),
},
_ => return Err(err),
},
_ => return Err(err),
}?;
eval.register_evaluated_expr(new_expr, span)
}
with_dollar_sign! {
($d:tt) => {
#[allow(unused)]
#[doc = concat!(
"A convenience macro for using the same RHS for each [`",
stringify!($target),
"`] variant in a call to [`",
stringify!($ident),
"`].",
)]
macro_rules! $ident {
(
$eval:expr,
$span:expr,
[$d ($d expr:expr),+ $d (,)?],
|$d ($d arg:ident),+| $d tt:tt
) => {
$ident($eval, $span, [$d ($d expr),+], |args| match args {
$(
$target::$mapping([$d ($d arg),+]) => {
let res = $d tt;
Result::map(res, $target::$mapping)
},
)+
})
};
}
};
}
};
}
gen_component_wise_extractor! {
component_wise_float -> Float,
literals: [
AbstractFloat => Abstract: f64,
F32 => F32: f32,
],
scalar_kinds: [
Float,
AbstractFloat,
],
}
#[derive(Debug)]
enum Behavior {
Wgsl,
@ -606,64 +806,7 @@ impl<'a> ConstantEvaluator<'a> {
e2: Handle<Expression>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let e1 = self.eval_zero_value_and_splat(e1, span)?;
let e2 = self.eval_zero_value_and_splat(e2, span)?;
let expr = match (&self.expressions[e1], &self.expressions[e2]) {
(&Expression::Literal(Literal::F32(a)), &Expression::Literal(Literal::F32(b))) => {
Expression::Literal(Literal::F32(a.powf(b)))
}
(
&Expression::Compose {
components: ref src_components0,
ty: ty0,
},
&Expression::Compose {
components: ref src_components1,
ty: ty1,
},
) if ty0 == ty1
&& matches!(
self.types[ty0].inner,
crate::TypeInner::Vector {
scalar: crate::Scalar {
kind: ScalarKind::Float,
..
},
..
}
) =>
{
let mut components: Vec<_> = crate::proc::flatten_compose(
ty0,
src_components0,
self.expressions,
self.types,
)
.chain(crate::proc::flatten_compose(
ty1,
src_components1,
self.expressions,
self.types,
))
.collect();
let mid = components.len() / 2;
let (first, last) = components.split_at_mut(mid);
for (a, b) in first.iter_mut().zip(&*last) {
*a = self.math_pow(*a, *b, span)?;
}
components.truncate(mid);
Expression::Compose {
ty: ty0,
components,
}
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
};
self.register_evaluated_expr(expr, span)
component_wise_float!(self, span, [e1, e2], |e1, e2| { Ok([e1.powf(e2)]) })
}
fn math_clamp(