use num-traits

Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com>
This commit is contained in:
sagudev 2024-09-06 10:51:27 +02:00
parent 52560ca8c7
commit 2aae42d5c3
3 changed files with 63 additions and 52 deletions

1
Cargo.lock generated
View File

@ -1893,6 +1893,7 @@ dependencies = [
"indexmap",
"itertools",
"log",
"num-traits",
"petgraph",
"pp-rs",
"ron",

View File

@ -81,7 +81,8 @@ serde = { version = "1.0.214", features = ["derive"], optional = true }
petgraph = { version = "0.6", optional = true }
pp-rs = { version = "0.2.1", optional = true }
hexf-parse = { version = "0.2.1", optional = true }
unicode-xid = { version = "0.2.6", optional = true }
unicode-xid = { version = "0.2.5", optional = true }
num-traits = "0.2"
[build-dependencies]
cfg_aliases.workspace = true

View File

@ -299,6 +299,7 @@ impl LiteralVector {
}
}
#[allow(dead_code)]
/// Creates [`LiteralVector`] from Array of [`Literal`]s
///
/// Panics if vector is empty
@ -1486,33 +1487,53 @@ impl<'a> ConstantEvaluator<'a> {
if e1.len() != e2.len() {
return Err(ConstantEvaluatorError::InvalidMathArg);
}
fn float_dot<F, const CAP: usize>(a: ArrayVec<F, CAP>, b: ArrayVec<F, CAP>) -> F
where
F: std::ops::Mul<F>,
F: num_traits::Float + std::iter::Sum,
{
a.iter().zip(b.iter()).map(|(&aa, &bb)| aa * bb).sum()
}
fn int_dot<P, const CAP: usize>(
a: ArrayVec<P, CAP>,
b: ArrayVec<P, CAP>,
) -> Result<P, ConstantEvaluatorError>
where
P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
{
a.iter()
.zip(b.iter())
.map(|(&aa, bb)| aa.checked_mul(bb))
.try_fold(P::zero(), |acc, x| {
if let Some(x) = x {
acc.checked_add(&x)
} else {
None
}
})
.ok_or(ConstantEvaluatorError::Overflow(
"in dot built-in".to_string(),
))
}
LiteralVector::from_literal(match (e1, e2) {
(LiteralVector::AbstractFloat(e1), LiteralVector::AbstractFloat(e2)) => {
Literal::AbstractFloat(
e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum(),
)
Literal::AbstractFloat(float_dot(e1, e2))
}
(LiteralVector::F32(e1), LiteralVector::F32(e2)) => {
Literal::F32(e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum())
Literal::F32(float_dot(e1, e2))
}
(LiteralVector::AbstractInt(e1), LiteralVector::AbstractInt(e2)) => {
Literal::AbstractInt(
e1.iter()
.zip(e2.iter())
.map(|(&e1, &e2)| e1.checked_mul(e2))
.try_fold(0_i64, |acc, x| {
if let Some(x) = x {
acc.checked_add(x)
} else {
None
}
})
.ok_or(ConstantEvaluatorError::Overflow(
"in dot built-in".to_string(),
))?,
)
Literal::AbstractInt(int_dot(e1, e2)?)
}
(LiteralVector::I32(e1), LiteralVector::I32(e2)) => {
Literal::I32(int_dot(e1, e2)?)
}
(LiteralVector::U32(e1), LiteralVector::U32(e2)) => {
Literal::U32(int_dot(e1, e2)?)
}
// TODO: more
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
})
.handle(self, span)
@ -1521,41 +1542,29 @@ impl<'a> ConstantEvaluator<'a> {
let e1 = LiteralVector::from_expr(arg, self, span, false)?;
let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?;
if e1.len() == 3 && e2.len() == 3 {
fn float_cross<F, const CAP: usize>(
a: ArrayVec<F, CAP>,
b: ArrayVec<F, CAP>,
) -> ArrayVec<F, CAP>
where
F: std::ops::Mul<F>,
F: num_traits::Float + std::iter::Sum,
{
[
a[1] * b[2] - a[2] * b[1],
a[2] * b[0] - a[0] * b[2],
a[0] * b[1] - a[1] * b[0],
]
.into_iter()
.collect()
}
match (e1, e2) {
(LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => {
LiteralVector::AbstractFloat(
[
a[1] * b[2] - a[2] * b[1],
a[2] * b[0] - a[0] * b[2],
a[0] * b[1] - a[1] * b[0],
]
.into_iter()
.collect(),
)
LiteralVector::AbstractFloat(float_cross(a, b))
}
(LiteralVector::AbstractInt(a), LiteralVector::AbstractInt(b)) => {
LiteralVector::AbstractInt(
[
a[1].checked_mul(b[2])
.zip(a[2].checked_mul(b[1]))
.and_then(|(a, b)| a.checked_sub(b)),
a[2].checked_mul(b[0])
.zip(a[0].checked_mul(b[2]))
.and_then(|(a, b)| a.checked_sub(b)),
a[0].checked_mul(b[1])
.zip(a[1].checked_mul(b[0]))
.and_then(|(a, b)| a.checked_sub(b)),
]
.into_iter()
.collect::<Option<_>>()
.ok_or(
ConstantEvaluatorError::Overflow(
"in cross built-in".to_string(),
),
)?,
)
(LiteralVector::F32(a), LiteralVector::F32(b)) => {
LiteralVector::F32(float_cross(a, b))
}
// TODO: more
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
}
.handle(self, span)