mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-25 08:13:27 +00:00
refactor(const_eval): add component_wise_float
helper, reimpl. math_pow
This commit is contained in:
parent
c2058487ca
commit
5c900f2568
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -2055,6 +2055,7 @@ name = "naga"
|
||||
version = "0.19.0"
|
||||
dependencies = [
|
||||
"arbitrary",
|
||||
"arrayvec 0.7.4",
|
||||
"bincode",
|
||||
"bit-set",
|
||||
"bitflags 2.4.1",
|
||||
|
@ -60,6 +60,7 @@ petgraph = { version = "0.6", optional = true }
|
||||
pp-rs = { version = "0.2.1", optional = true }
|
||||
hexf-parse = { version = "0.2.1", optional = true }
|
||||
unicode-xid = { version = "0.2.3", optional = true }
|
||||
arrayvec.workspace = true
|
||||
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
|
||||
criterion = { version = "0.5", features = [] }
|
||||
|
@ -458,6 +458,10 @@ pub enum VectorSize {
|
||||
Quad = 4,
|
||||
}
|
||||
|
||||
impl VectorSize {
|
||||
const MAX: usize = Self::Quad as u8 as usize;
|
||||
}
|
||||
|
||||
/// Primitive type for a scalar.
|
||||
#[repr(u8)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
|
@ -1,9 +1,209 @@
|
||||
use std::iter;
|
||||
|
||||
use arrayvec::ArrayVec;
|
||||
|
||||
use crate::{
|
||||
arena::{Arena, Handle, UniqueArena},
|
||||
ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner,
|
||||
UnaryOperator,
|
||||
};
|
||||
|
||||
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
|
||||
/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items.
|
||||
///
|
||||
/// Technique stolen directly from
|
||||
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
|
||||
macro_rules! with_dollar_sign {
|
||||
($($body:tt)*) => {
|
||||
macro_rules! __with_dollar_sign { $($body)* }
|
||||
__with_dollar_sign!($);
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! gen_component_wise_extractor {
|
||||
(
|
||||
$ident:ident -> $target:ident,
|
||||
literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
|
||||
scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
|
||||
) => {
|
||||
/// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
|
||||
enum $target<const N: usize> {
|
||||
$(
|
||||
#[doc = concat!(
|
||||
"Maps to [`Literal::",
|
||||
stringify!($mapping),
|
||||
"`]",
|
||||
)]
|
||||
$mapping([$ty; N]),
|
||||
)+
|
||||
}
|
||||
|
||||
impl From<$target<1>> for Expression {
|
||||
fn from(value: $target<1>) -> Self {
|
||||
match value {
|
||||
$(
|
||||
$target::$mapping([value]) => {
|
||||
Expression::Literal(Literal::$literal(value))
|
||||
}
|
||||
)+
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc = concat!(
|
||||
"Attempts to evaluate multiple `exprs` as a combined [`",
|
||||
stringify!($target),
|
||||
"`] to pass to `handler`. ",
|
||||
)]
|
||||
/// If `exprs` are vectors of the same length, `handler` is called for each corresponding
|
||||
/// component of each vector.
|
||||
///
|
||||
/// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
|
||||
/// same length, a new vector expression is registered, composed of each component emitted
|
||||
/// by `handler`.
|
||||
fn $ident<const N: usize, const M: usize, F>(
|
||||
eval: &mut ConstantEvaluator<'_>,
|
||||
span: Span,
|
||||
exprs: [Handle<Expression>; N],
|
||||
mut handler: F,
|
||||
) -> Result<Handle<Expression>, ConstantEvaluatorError>
|
||||
where
|
||||
$target<M>: Into<Expression>,
|
||||
F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
|
||||
{
|
||||
assert!(N > 0);
|
||||
let err = ConstantEvaluatorError::InvalidMathArg;
|
||||
let mut exprs = exprs.into_iter();
|
||||
|
||||
macro_rules! sanitize {
|
||||
($expr:expr) => {
|
||||
eval.eval_zero_value_and_splat($expr, span)
|
||||
.map(|expr| &eval.expressions[expr])
|
||||
};
|
||||
}
|
||||
|
||||
let new_expr = match sanitize!(exprs.next().unwrap())? {
|
||||
$(
|
||||
&Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
|
||||
.chain(exprs.map(|expr| {
|
||||
sanitize!(expr).and_then(|expr| match expr {
|
||||
&Expression::Literal(Literal::$literal(x)) => Ok(x),
|
||||
_ => Err(err.clone()),
|
||||
})
|
||||
}))
|
||||
.collect::<Result<ArrayVec<_, N>, _>>()
|
||||
.map(|a| a.into_inner().unwrap())
|
||||
.map($target::$mapping)
|
||||
.and_then(|comps| Ok(handler(comps)?.into())),
|
||||
)+
|
||||
&Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
|
||||
&TypeInner::Vector { size: _, scalar } => match scalar.kind {
|
||||
$(ScalarKind::$scalar_kind)|* => {
|
||||
let first_ty = ty;
|
||||
let mut component_groups =
|
||||
ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
|
||||
component_groups.push(crate::proc::flatten_compose(
|
||||
first_ty,
|
||||
components,
|
||||
eval.expressions,
|
||||
eval.types,
|
||||
).collect());
|
||||
component_groups.extend(
|
||||
exprs
|
||||
.map(|expr| {
|
||||
sanitize!(expr).and_then(|expr| match expr {
|
||||
&Expression::Compose { ty, ref components }
|
||||
if &eval.types[ty].inner
|
||||
== &eval.types[first_ty].inner =>
|
||||
{
|
||||
Ok(crate::proc::flatten_compose(
|
||||
ty,
|
||||
components,
|
||||
eval.expressions,
|
||||
eval.types,
|
||||
).collect())
|
||||
}
|
||||
_ => Err(err.clone()),
|
||||
})
|
||||
})
|
||||
.collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
|
||||
)?,
|
||||
);
|
||||
let component_groups = component_groups.into_inner().unwrap();
|
||||
let mut new_components =
|
||||
ArrayVec::<_, { crate::VectorSize::MAX }>::new();
|
||||
for idx in 0..N {
|
||||
let group = component_groups
|
||||
.iter()
|
||||
.map(|cs| cs[idx])
|
||||
.collect::<ArrayVec<_, N>>()
|
||||
.into_inner()
|
||||
.unwrap();
|
||||
new_components.push($ident(
|
||||
eval,
|
||||
span,
|
||||
group,
|
||||
handler.clone(),
|
||||
)?);
|
||||
}
|
||||
Ok(Expression::Compose {
|
||||
ty: first_ty,
|
||||
components: new_components.into_iter().collect(),
|
||||
})
|
||||
}
|
||||
_ => return Err(err),
|
||||
},
|
||||
_ => return Err(err),
|
||||
},
|
||||
_ => return Err(err),
|
||||
}?;
|
||||
eval.register_evaluated_expr(new_expr, span)
|
||||
}
|
||||
|
||||
with_dollar_sign! {
|
||||
($d:tt) => {
|
||||
#[allow(unused)]
|
||||
#[doc = concat!(
|
||||
"A convenience macro for using the same RHS for each [`",
|
||||
stringify!($target),
|
||||
"`] variant in a call to [`",
|
||||
stringify!($ident),
|
||||
"`].",
|
||||
)]
|
||||
macro_rules! $ident {
|
||||
(
|
||||
$eval:expr,
|
||||
$span:expr,
|
||||
[$d ($d expr:expr),+ $d (,)?],
|
||||
|$d ($d arg:ident),+| $d tt:tt
|
||||
) => {
|
||||
$ident($eval, $span, [$d ($d expr),+], |args| match args {
|
||||
$(
|
||||
$target::$mapping([$d ($d arg),+]) => {
|
||||
let res = $d tt;
|
||||
Result::map(res, $target::$mapping)
|
||||
},
|
||||
)+
|
||||
})
|
||||
};
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
gen_component_wise_extractor! {
|
||||
component_wise_float -> Float,
|
||||
literals: [
|
||||
AbstractFloat => Abstract: f64,
|
||||
F32 => F32: f32,
|
||||
],
|
||||
scalar_kinds: [
|
||||
Float,
|
||||
AbstractFloat,
|
||||
],
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Behavior {
|
||||
Wgsl,
|
||||
@ -606,64 +806,7 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
e2: Handle<Expression>,
|
||||
span: Span,
|
||||
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
|
||||
let e1 = self.eval_zero_value_and_splat(e1, span)?;
|
||||
let e2 = self.eval_zero_value_and_splat(e2, span)?;
|
||||
|
||||
let expr = match (&self.expressions[e1], &self.expressions[e2]) {
|
||||
(&Expression::Literal(Literal::F32(a)), &Expression::Literal(Literal::F32(b))) => {
|
||||
Expression::Literal(Literal::F32(a.powf(b)))
|
||||
}
|
||||
(
|
||||
&Expression::Compose {
|
||||
components: ref src_components0,
|
||||
ty: ty0,
|
||||
},
|
||||
&Expression::Compose {
|
||||
components: ref src_components1,
|
||||
ty: ty1,
|
||||
},
|
||||
) if ty0 == ty1
|
||||
&& matches!(
|
||||
self.types[ty0].inner,
|
||||
crate::TypeInner::Vector {
|
||||
scalar: crate::Scalar {
|
||||
kind: ScalarKind::Float,
|
||||
..
|
||||
},
|
||||
..
|
||||
}
|
||||
) =>
|
||||
{
|
||||
let mut components: Vec<_> = crate::proc::flatten_compose(
|
||||
ty0,
|
||||
src_components0,
|
||||
self.expressions,
|
||||
self.types,
|
||||
)
|
||||
.chain(crate::proc::flatten_compose(
|
||||
ty1,
|
||||
src_components1,
|
||||
self.expressions,
|
||||
self.types,
|
||||
))
|
||||
.collect();
|
||||
|
||||
let mid = components.len() / 2;
|
||||
let (first, last) = components.split_at_mut(mid);
|
||||
for (a, b) in first.iter_mut().zip(&*last) {
|
||||
*a = self.math_pow(*a, *b, span)?;
|
||||
}
|
||||
components.truncate(mid);
|
||||
|
||||
Expression::Compose {
|
||||
ty: ty0,
|
||||
components,
|
||||
}
|
||||
}
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
};
|
||||
|
||||
self.register_evaluated_expr(expr, span)
|
||||
component_wise_float!(self, span, [e1, e2], |e1, e2| { Ok([e1.powf(e2)]) })
|
||||
}
|
||||
|
||||
fn math_clamp(
|
||||
|
Loading…
Reference in New Issue
Block a user