LiteralVector and some demo

Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com>
This commit is contained in:
sagudev 2024-09-04 18:16:52 +02:00
parent 79a6f2cd31
commit 52560ca8c7

View File

@ -254,6 +254,254 @@ gen_component_wise_extractor! {
],
}
/// Vector for each [`Literal`] type
///
/// This type ensures that all elements have same type
enum LiteralVector {
F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
impl LiteralVector {
#[allow(clippy::pattern_type_mismatch)]
const fn len(&self) -> usize {
match self {
LiteralVector::F64(v) => v.len(),
LiteralVector::F32(v) => v.len(),
LiteralVector::U32(v) => v.len(),
LiteralVector::I32(v) => v.len(),
LiteralVector::U64(v) => v.len(),
LiteralVector::I64(v) => v.len(),
LiteralVector::Bool(v) => v.len(),
LiteralVector::AbstractInt(v) => v.len(),
LiteralVector::AbstractFloat(v) => v.len(),
}
}
/// Creates [`LiteralVector`] of size 1 from single [`Literal`]
fn from_literal(literal: Literal) -> Self {
match literal {
Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))),
Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))),
Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))),
Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))),
Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))),
Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))),
Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))),
Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))),
Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))),
}
}
/// Creates [`LiteralVector`] from Array of [`Literal`]s
///
/// Panics if vector is empty
fn from_literal_vec(
components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
) -> Result<Self, ConstantEvaluatorError> {
let scalar = components[0].scalar();
Self::from_literal_vec_with_scalar_type(components, scalar)
}
/// Creates [`LiteralVector`] of type provided by scalar from Array of [`Literal`]s
///
/// Panics if vector is empty, returns error if types do not match
fn from_literal_vec_with_scalar_type(
components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
scalar: crate::Scalar,
) -> Result<Self, ConstantEvaluatorError> {
assert!(!components.is_empty());
Ok(match scalar {
crate::Scalar::I32 => Self::I32(
components
.iter()
.map(|l| match l {
&Literal::I32(v) => Ok(v),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, _>>()?,
),
crate::Scalar::U32 => Self::U32(
components
.iter()
.map(|l| match l {
&Literal::U32(v) => Ok(v),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, _>>()?,
),
crate::Scalar::I64 => Self::I64(
components
.iter()
.map(|l| match l {
&Literal::I64(v) => Ok(v),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, _>>()?,
),
crate::Scalar::U64 => Self::U64(
components
.iter()
.map(|l| match l {
&Literal::U64(v) => Ok(v),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, _>>()?,
),
crate::Scalar::F32 => Self::F32(
components
.iter()
.map(|l| match l {
&Literal::F32(v) => Ok(v),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, _>>()?,
),
crate::Scalar::F64 => Self::F64(
components
.iter()
.map(|l| match l {
&Literal::F64(v) => Ok(v),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, _>>()?,
),
crate::Scalar::BOOL => Self::Bool(
components
.iter()
.map(|l| match l {
&Literal::Bool(v) => Ok(v),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, _>>()?,
),
crate::Scalar::ABSTRACT_INT => Self::AbstractInt(
components
.iter()
.map(|l| match l {
&Literal::AbstractInt(v) => Ok(v),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, _>>()?,
),
crate::Scalar::ABSTRACT_FLOAT => Self::AbstractFloat(
components
.iter()
.map(|l| match l {
&Literal::AbstractFloat(v) => Ok(v),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, _>>()?,
),
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
})
}
fn from_expr(
expr: Handle<Expression>,
eval: &mut ConstantEvaluator<'_>,
span: Span,
allow_single: bool,
) -> Result<Self, ConstantEvaluatorError> {
let expr = eval
.eval_zero_value_and_splat(expr, span)
.map(|expr| &eval.expressions[expr])?;
match *expr {
Expression::Literal(literal) => {
if allow_single {
Ok(Self::from_literal(literal))
} else {
Err(ConstantEvaluatorError::InvalidMathArg)
}
}
Expression::Compose { ty, ref components } => match eval.types[ty].inner {
TypeInner::Vector { scalar, .. } => {
if components.len() > crate::VectorSize::MAX {
return Err(ConstantEvaluatorError::InvalidMathArg);
}
let components: ArrayVec<Literal, { crate::VectorSize::MAX }> =
crate::proc::flatten_compose(ty, components, eval.expressions, eval.types)
.map(|expr| match eval.expressions[expr] {
Expression::Literal(l) => Ok(l),
_ => Err(ConstantEvaluatorError::InvalidMathArg),
})
.collect::<Result<_, ConstantEvaluatorError>>()?;
Self::from_literal_vec_with_scalar_type(components, scalar)
}
_ => Err(ConstantEvaluatorError::InvalidMathArg),
},
_ => Err(ConstantEvaluatorError::InvalidMathArg),
}
}
/// Returns [`ArrayVec`] of [`Literals`]
fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
#[allow(clippy::pattern_type_mismatch)]
match self {
LiteralVector::F64(v) => v.iter().map(|e| (Literal::F64(*e))).collect(),
LiteralVector::F32(v) => v.iter().map(|e| (Literal::F32(*e))).collect(),
LiteralVector::U32(v) => v.iter().map(|e| (Literal::U32(*e))).collect(),
LiteralVector::I32(v) => v.iter().map(|e| (Literal::I32(*e))).collect(),
LiteralVector::U64(v) => v.iter().map(|e| (Literal::U64(*e))).collect(),
LiteralVector::I64(v) => v.iter().map(|e| (Literal::I64(*e))).collect(),
LiteralVector::Bool(v) => v.iter().map(|e| (Literal::Bool(*e))).collect(),
LiteralVector::AbstractInt(v) => v.iter().map(|e| (Literal::AbstractInt(*e))).collect(),
LiteralVector::AbstractFloat(v) => {
v.iter().map(|e| (Literal::AbstractFloat(*e))).collect()
}
}
}
fn to_expr(&self, eval: &mut ConstantEvaluator<'_>) -> Expression {
let lit_vec = self.to_literal_vec();
assert!(!lit_vec.is_empty());
if lit_vec.len() == 1 {
Expression::Literal(lit_vec[0])
} else {
Expression::Compose {
ty: eval.types.insert(
Type {
name: None,
inner: TypeInner::Vector {
size: match lit_vec.len() {
2 => crate::VectorSize::Bi,
3 => crate::VectorSize::Tri,
4 => crate::VectorSize::Quad,
_ => unreachable!(),
},
scalar: lit_vec[0].scalar(),
},
},
Span::UNDEFINED,
),
components: lit_vec
.iter()
.map(|&l| {
eval.expressions
.append(Expression::Literal(l), Span::UNDEFINED)
})
.collect(),
}
}
}
/// Puts self into eval's expressions arena and returns handle to it
fn handle(
&self,
eval: &mut ConstantEvaluator<'_>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let expr = self.to_expr(eval);
eval.register_evaluated_expr(expr, span)
}
}
#[derive(Debug)]
enum Behavior<'a> {
Wgsl(WgslRestrictions<'a>),
@ -917,9 +1165,10 @@ impl<'a> ConstantEvaluator<'a> {
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
"select built-in function".into(),
)),
Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
format!("{fun:?} built-in function"),
)),
Expression::Relational { fun, argument } => {
let arg = self.check_and_get(argument)?;
self.relational_op(fun, arg, span)
}
Expression::ArrayLength(expr) => match self.behavior {
Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
Behavior::Glsl(_) => {
@ -1230,6 +1479,90 @@ impl<'a> ConstantEvaluator<'a> {
})
}
// geometry
crate::MathFunction::Dot => {
let e1 = LiteralVector::from_expr(arg, self, span, false)?;
let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?;
if e1.len() != e2.len() {
return Err(ConstantEvaluatorError::InvalidMathArg);
}
LiteralVector::from_literal(match (e1, e2) {
(LiteralVector::AbstractFloat(e1), LiteralVector::AbstractFloat(e2)) => {
Literal::AbstractFloat(
e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum(),
)
}
(LiteralVector::F32(e1), LiteralVector::F32(e2)) => {
Literal::F32(e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum())
}
(LiteralVector::AbstractInt(e1), LiteralVector::AbstractInt(e2)) => {
Literal::AbstractInt(
e1.iter()
.zip(e2.iter())
.map(|(&e1, &e2)| e1.checked_mul(e2))
.try_fold(0_i64, |acc, x| {
if let Some(x) = x {
acc.checked_add(x)
} else {
None
}
})
.ok_or(ConstantEvaluatorError::Overflow(
"in dot built-in".to_string(),
))?,
)
}
// TODO: more
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
})
.handle(self, span)
}
crate::MathFunction::Cross => {
let e1 = LiteralVector::from_expr(arg, self, span, false)?;
let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?;
if e1.len() == 3 && e2.len() == 3 {
match (e1, e2) {
(LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => {
LiteralVector::AbstractFloat(
[
a[1] * b[2] - a[2] * b[1],
a[2] * b[0] - a[0] * b[2],
a[0] * b[1] - a[1] * b[0],
]
.into_iter()
.collect(),
)
}
(LiteralVector::AbstractInt(a), LiteralVector::AbstractInt(b)) => {
LiteralVector::AbstractInt(
[
a[1].checked_mul(b[2])
.zip(a[2].checked_mul(b[1]))
.and_then(|(a, b)| a.checked_sub(b)),
a[2].checked_mul(b[0])
.zip(a[0].checked_mul(b[2]))
.and_then(|(a, b)| a.checked_sub(b)),
a[0].checked_mul(b[1])
.zip(a[1].checked_mul(b[0]))
.and_then(|(a, b)| a.checked_sub(b)),
]
.into_iter()
.collect::<Option<_>>()
.ok_or(
ConstantEvaluatorError::Overflow(
"in cross built-in".to_string(),
),
)?,
)
}
// TODO: more
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
}
.handle(self, span)
} else {
Err(ConstantEvaluatorError::InvalidMathArg)
}
}
// computational
crate::MathFunction::Sign => {
component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
@ -2059,6 +2392,38 @@ impl<'a> ConstantEvaluator<'a> {
Ok(Expression::Compose { ty, components })
}
fn relational_op(
&mut self,
fun: crate::RelationalFunction,
arg: Handle<Expression>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let arg = LiteralVector::from_expr(arg, self, span, true)?;
let res = LiteralVector::Bool(match fun {
crate::RelationalFunction::IsNan => match arg {
LiteralVector::F64(f) => f.iter().map(|e| e.is_nan()).collect(),
LiteralVector::F32(f) => f.iter().map(|e| e.is_nan()).collect(),
LiteralVector::AbstractFloat(f) => f.iter().map(|e| e.is_nan()).collect(),
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
},
crate::RelationalFunction::IsInf => match arg {
LiteralVector::F64(f) => f.iter().map(|e| e.is_infinite()).collect(),
LiteralVector::F32(f) => f.iter().map(|e| e.is_infinite()).collect(),
LiteralVector::AbstractFloat(f) => f.iter().map(|e| e.is_infinite()).collect(),
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
},
crate::RelationalFunction::All => match arg {
LiteralVector::Bool(bools) => iter::once(bools.iter().all(|b| *b)).collect(),
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
},
crate::RelationalFunction::Any => match arg {
LiteralVector::Bool(bools) => iter::once(bools.iter().any(|b| *b)).collect(),
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
},
});
res.handle(self, span)
}
/// Deep copy `expr` from `expressions` into `self.expressions`.
///
/// Return the root of the new copy.