mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-22 14:55:05 +00:00
LiteralVector
and some demo
Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com>
This commit is contained in:
parent
79a6f2cd31
commit
52560ca8c7
@ -254,6 +254,254 @@ gen_component_wise_extractor! {
|
||||
],
|
||||
}
|
||||
|
||||
/// Vector for each [`Literal`] type
|
||||
///
|
||||
/// This type ensures that all elements have same type
|
||||
enum LiteralVector {
|
||||
F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
|
||||
F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
|
||||
U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
|
||||
I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
|
||||
U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
|
||||
I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
|
||||
Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
|
||||
AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
|
||||
AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
|
||||
}
|
||||
|
||||
impl LiteralVector {
|
||||
#[allow(clippy::pattern_type_mismatch)]
|
||||
const fn len(&self) -> usize {
|
||||
match self {
|
||||
LiteralVector::F64(v) => v.len(),
|
||||
LiteralVector::F32(v) => v.len(),
|
||||
LiteralVector::U32(v) => v.len(),
|
||||
LiteralVector::I32(v) => v.len(),
|
||||
LiteralVector::U64(v) => v.len(),
|
||||
LiteralVector::I64(v) => v.len(),
|
||||
LiteralVector::Bool(v) => v.len(),
|
||||
LiteralVector::AbstractInt(v) => v.len(),
|
||||
LiteralVector::AbstractFloat(v) => v.len(),
|
||||
}
|
||||
}
|
||||
/// Creates [`LiteralVector`] of size 1 from single [`Literal`]
|
||||
fn from_literal(literal: Literal) -> Self {
|
||||
match literal {
|
||||
Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
|
||||
Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
|
||||
Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
|
||||
Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
|
||||
Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
|
||||
Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
|
||||
Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
|
||||
Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
|
||||
Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates [`LiteralVector`] from Array of [`Literal`]s
|
||||
///
|
||||
/// Panics if vector is empty
|
||||
fn from_literal_vec(
|
||||
components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
|
||||
) -> Result<Self, ConstantEvaluatorError> {
|
||||
let scalar = components[0].scalar();
|
||||
Self::from_literal_vec_with_scalar_type(components, scalar)
|
||||
}
|
||||
|
||||
/// Creates [`LiteralVector`] of type provided by scalar from Array of [`Literal`]s
|
||||
///
|
||||
/// Panics if vector is empty, returns error if types do not match
|
||||
fn from_literal_vec_with_scalar_type(
|
||||
components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
|
||||
scalar: crate::Scalar,
|
||||
) -> Result<Self, ConstantEvaluatorError> {
|
||||
assert!(!components.is_empty());
|
||||
Ok(match scalar {
|
||||
crate::Scalar::I32 => Self::I32(
|
||||
components
|
||||
.iter()
|
||||
.map(|l| match l {
|
||||
&Literal::I32(v) => Ok(v),
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
crate::Scalar::U32 => Self::U32(
|
||||
components
|
||||
.iter()
|
||||
.map(|l| match l {
|
||||
&Literal::U32(v) => Ok(v),
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
crate::Scalar::I64 => Self::I64(
|
||||
components
|
||||
.iter()
|
||||
.map(|l| match l {
|
||||
&Literal::I64(v) => Ok(v),
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
crate::Scalar::U64 => Self::U64(
|
||||
components
|
||||
.iter()
|
||||
.map(|l| match l {
|
||||
&Literal::U64(v) => Ok(v),
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
crate::Scalar::F32 => Self::F32(
|
||||
components
|
||||
.iter()
|
||||
.map(|l| match l {
|
||||
&Literal::F32(v) => Ok(v),
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
crate::Scalar::F64 => Self::F64(
|
||||
components
|
||||
.iter()
|
||||
.map(|l| match l {
|
||||
&Literal::F64(v) => Ok(v),
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
crate::Scalar::BOOL => Self::Bool(
|
||||
components
|
||||
.iter()
|
||||
.map(|l| match l {
|
||||
&Literal::Bool(v) => Ok(v),
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
crate::Scalar::ABSTRACT_INT => Self::AbstractInt(
|
||||
components
|
||||
.iter()
|
||||
.map(|l| match l {
|
||||
&Literal::AbstractInt(v) => Ok(v),
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
crate::Scalar::ABSTRACT_FLOAT => Self::AbstractFloat(
|
||||
components
|
||||
.iter()
|
||||
.map(|l| match l {
|
||||
&Literal::AbstractFloat(v) => Ok(v),
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
}
|
||||
|
||||
fn from_expr(
|
||||
expr: Handle<Expression>,
|
||||
eval: &mut ConstantEvaluator<'_>,
|
||||
span: Span,
|
||||
allow_single: bool,
|
||||
) -> Result<Self, ConstantEvaluatorError> {
|
||||
let expr = eval
|
||||
.eval_zero_value_and_splat(expr, span)
|
||||
.map(|expr| &eval.expressions[expr])?;
|
||||
match *expr {
|
||||
Expression::Literal(literal) => {
|
||||
if allow_single {
|
||||
Ok(Self::from_literal(literal))
|
||||
} else {
|
||||
Err(ConstantEvaluatorError::InvalidMathArg)
|
||||
}
|
||||
}
|
||||
Expression::Compose { ty, ref components } => match eval.types[ty].inner {
|
||||
TypeInner::Vector { scalar, .. } => {
|
||||
if components.len() > crate::VectorSize::MAX {
|
||||
return Err(ConstantEvaluatorError::InvalidMathArg);
|
||||
}
|
||||
let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
|
||||
crate::proc::flatten_compose(ty, components, eval.expressions, eval.types)
|
||||
.map(|expr| match eval.expressions[expr] {
|
||||
Expression::Literal(l) => Ok(l),
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.collect::<Result<_, ConstantEvaluatorError>>()?;
|
||||
Self::from_literal_vec_with_scalar_type(components, scalar)
|
||||
}
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
},
|
||||
_ => Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns [`ArrayVec`] of [`Literals`]
|
||||
fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
|
||||
#[allow(clippy::pattern_type_mismatch)]
|
||||
match self {
|
||||
LiteralVector::F64(v) => v.iter().map(|e| (Literal::F64(*e))).collect(),
|
||||
LiteralVector::F32(v) => v.iter().map(|e| (Literal::F32(*e))).collect(),
|
||||
LiteralVector::U32(v) => v.iter().map(|e| (Literal::U32(*e))).collect(),
|
||||
LiteralVector::I32(v) => v.iter().map(|e| (Literal::I32(*e))).collect(),
|
||||
LiteralVector::U64(v) => v.iter().map(|e| (Literal::U64(*e))).collect(),
|
||||
LiteralVector::I64(v) => v.iter().map(|e| (Literal::I64(*e))).collect(),
|
||||
LiteralVector::Bool(v) => v.iter().map(|e| (Literal::Bool(*e))).collect(),
|
||||
LiteralVector::AbstractInt(v) => v.iter().map(|e| (Literal::AbstractInt(*e))).collect(),
|
||||
LiteralVector::AbstractFloat(v) => {
|
||||
v.iter().map(|e| (Literal::AbstractFloat(*e))).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_expr(&self, eval: &mut ConstantEvaluator<'_>) -> Expression {
|
||||
let lit_vec = self.to_literal_vec();
|
||||
assert!(!lit_vec.is_empty());
|
||||
if lit_vec.len() == 1 {
|
||||
Expression::Literal(lit_vec[0])
|
||||
} else {
|
||||
Expression::Compose {
|
||||
ty: eval.types.insert(
|
||||
Type {
|
||||
name: None,
|
||||
inner: TypeInner::Vector {
|
||||
size: match lit_vec.len() {
|
||||
2 => crate::VectorSize::Bi,
|
||||
3 => crate::VectorSize::Tri,
|
||||
4 => crate::VectorSize::Quad,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
scalar: lit_vec[0].scalar(),
|
||||
},
|
||||
},
|
||||
Span::UNDEFINED,
|
||||
),
|
||||
components: lit_vec
|
||||
.iter()
|
||||
.map(|&l| {
|
||||
eval.expressions
|
||||
.append(Expression::Literal(l), Span::UNDEFINED)
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Puts self into eval's expressions arena and returns handle to it
|
||||
fn handle(
|
||||
&self,
|
||||
eval: &mut ConstantEvaluator<'_>,
|
||||
span: Span,
|
||||
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
|
||||
let expr = self.to_expr(eval);
|
||||
eval.register_evaluated_expr(expr, span)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Behavior<'a> {
|
||||
Wgsl(WgslRestrictions<'a>),
|
||||
@ -917,9 +1165,10 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
|
||||
"select built-in function".into(),
|
||||
)),
|
||||
Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
|
||||
format!("{fun:?} built-in function"),
|
||||
)),
|
||||
Expression::Relational { fun, argument } => {
|
||||
let arg = self.check_and_get(argument)?;
|
||||
self.relational_op(fun, arg, span)
|
||||
}
|
||||
Expression::ArrayLength(expr) => match self.behavior {
|
||||
Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
|
||||
Behavior::Glsl(_) => {
|
||||
@ -1230,6 +1479,90 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
})
|
||||
}
|
||||
|
||||
// geometry
|
||||
crate::MathFunction::Dot => {
|
||||
let e1 = LiteralVector::from_expr(arg, self, span, false)?;
|
||||
let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?;
|
||||
if e1.len() != e2.len() {
|
||||
return Err(ConstantEvaluatorError::InvalidMathArg);
|
||||
}
|
||||
LiteralVector::from_literal(match (e1, e2) {
|
||||
(LiteralVector::AbstractFloat(e1), LiteralVector::AbstractFloat(e2)) => {
|
||||
Literal::AbstractFloat(
|
||||
e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum(),
|
||||
)
|
||||
}
|
||||
(LiteralVector::F32(e1), LiteralVector::F32(e2)) => {
|
||||
Literal::F32(e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum())
|
||||
}
|
||||
(LiteralVector::AbstractInt(e1), LiteralVector::AbstractInt(e2)) => {
|
||||
Literal::AbstractInt(
|
||||
e1.iter()
|
||||
.zip(e2.iter())
|
||||
.map(|(&e1, &e2)| e1.checked_mul(e2))
|
||||
.try_fold(0_i64, |acc, x| {
|
||||
if let Some(x) = x {
|
||||
acc.checked_add(x)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or(ConstantEvaluatorError::Overflow(
|
||||
"in dot built-in".to_string(),
|
||||
))?,
|
||||
)
|
||||
}
|
||||
// TODO: more
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.handle(self, span)
|
||||
}
|
||||
crate::MathFunction::Cross => {
|
||||
let e1 = LiteralVector::from_expr(arg, self, span, false)?;
|
||||
let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?;
|
||||
if e1.len() == 3 && e2.len() == 3 {
|
||||
match (e1, e2) {
|
||||
(LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => {
|
||||
LiteralVector::AbstractFloat(
|
||||
[
|
||||
a[1] * b[2] - a[2] * b[1],
|
||||
a[2] * b[0] - a[0] * b[2],
|
||||
a[0] * b[1] - a[1] * b[0],
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
(LiteralVector::AbstractInt(a), LiteralVector::AbstractInt(b)) => {
|
||||
LiteralVector::AbstractInt(
|
||||
[
|
||||
a[1].checked_mul(b[2])
|
||||
.zip(a[2].checked_mul(b[1]))
|
||||
.and_then(|(a, b)| a.checked_sub(b)),
|
||||
a[2].checked_mul(b[0])
|
||||
.zip(a[0].checked_mul(b[2]))
|
||||
.and_then(|(a, b)| a.checked_sub(b)),
|
||||
a[0].checked_mul(b[1])
|
||||
.zip(a[1].checked_mul(b[0]))
|
||||
.and_then(|(a, b)| a.checked_sub(b)),
|
||||
]
|
||||
.into_iter()
|
||||
.collect::<Option<_>>()
|
||||
.ok_or(
|
||||
ConstantEvaluatorError::Overflow(
|
||||
"in cross built-in".to_string(),
|
||||
),
|
||||
)?,
|
||||
)
|
||||
}
|
||||
// TODO: more
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
}
|
||||
.handle(self, span)
|
||||
} else {
|
||||
Err(ConstantEvaluatorError::InvalidMathArg)
|
||||
}
|
||||
}
|
||||
// computational
|
||||
crate::MathFunction::Sign => {
|
||||
component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
|
||||
@ -2059,6 +2392,38 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Ok(Expression::Compose { ty, components })
|
||||
}
|
||||
|
||||
fn relational_op(
|
||||
&mut self,
|
||||
fun: crate::RelationalFunction,
|
||||
arg: Handle<Expression>,
|
||||
span: Span,
|
||||
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
|
||||
let arg = LiteralVector::from_expr(arg, self, span, true)?;
|
||||
let res = LiteralVector::Bool(match fun {
|
||||
crate::RelationalFunction::IsNan => match arg {
|
||||
LiteralVector::F64(f) => f.iter().map(|e| e.is_nan()).collect(),
|
||||
LiteralVector::F32(f) => f.iter().map(|e| e.is_nan()).collect(),
|
||||
LiteralVector::AbstractFloat(f) => f.iter().map(|e| e.is_nan()).collect(),
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
},
|
||||
crate::RelationalFunction::IsInf => match arg {
|
||||
LiteralVector::F64(f) => f.iter().map(|e| e.is_infinite()).collect(),
|
||||
LiteralVector::F32(f) => f.iter().map(|e| e.is_infinite()).collect(),
|
||||
LiteralVector::AbstractFloat(f) => f.iter().map(|e| e.is_infinite()).collect(),
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
},
|
||||
crate::RelationalFunction::All => match arg {
|
||||
LiteralVector::Bool(bools) => iter::once(bools.iter().all(|b| *b)).collect(),
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
},
|
||||
crate::RelationalFunction::Any => match arg {
|
||||
LiteralVector::Bool(bools) => iter::once(bools.iter().any(|b| *b)).collect(),
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
},
|
||||
});
|
||||
res.handle(self, span)
|
||||
}
|
||||
|
||||
/// Deep copy `expr` from `expressions` into `self.expressions`.
|
||||
///
|
||||
/// Return the root of the new copy.
|
||||
|
Loading…
Reference in New Issue
Block a user