mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-22 06:44:14 +00:00
use num-traits
Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com>
This commit is contained in:
parent
52560ca8c7
commit
2aae42d5c3
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -1893,6 +1893,7 @@ dependencies = [
|
||||
"indexmap",
|
||||
"itertools",
|
||||
"log",
|
||||
"num-traits",
|
||||
"petgraph",
|
||||
"pp-rs",
|
||||
"ron",
|
||||
|
@ -81,7 +81,8 @@ serde = { version = "1.0.214", features = ["derive"], optional = true }
|
||||
petgraph = { version = "0.6", optional = true }
|
||||
pp-rs = { version = "0.2.1", optional = true }
|
||||
hexf-parse = { version = "0.2.1", optional = true }
|
||||
unicode-xid = { version = "0.2.6", optional = true }
|
||||
unicode-xid = { version = "0.2.5", optional = true }
|
||||
num-traits = "0.2"
|
||||
|
||||
[build-dependencies]
|
||||
cfg_aliases.workspace = true
|
||||
|
@ -299,6 +299,7 @@ impl LiteralVector {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
/// Creates [`LiteralVector`] from Array of [`Literal`]s
|
||||
///
|
||||
/// Panics if vector is empty
|
||||
@ -1486,33 +1487,53 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
if e1.len() != e2.len() {
|
||||
return Err(ConstantEvaluatorError::InvalidMathArg);
|
||||
}
|
||||
|
||||
fn float_dot<F, const CAP: usize>(a: ArrayVec<F, CAP>, b: ArrayVec<F, CAP>) -> F
|
||||
where
|
||||
F: std::ops::Mul<F>,
|
||||
F: num_traits::Float + std::iter::Sum,
|
||||
{
|
||||
a.iter().zip(b.iter()).map(|(&aa, &bb)| aa * bb).sum()
|
||||
}
|
||||
|
||||
fn int_dot<P, const CAP: usize>(
|
||||
a: ArrayVec<P, CAP>,
|
||||
b: ArrayVec<P, CAP>,
|
||||
) -> Result<P, ConstantEvaluatorError>
|
||||
where
|
||||
P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
|
||||
{
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(&aa, bb)| aa.checked_mul(bb))
|
||||
.try_fold(P::zero(), |acc, x| {
|
||||
if let Some(x) = x {
|
||||
acc.checked_add(&x)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or(ConstantEvaluatorError::Overflow(
|
||||
"in dot built-in".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
LiteralVector::from_literal(match (e1, e2) {
|
||||
(LiteralVector::AbstractFloat(e1), LiteralVector::AbstractFloat(e2)) => {
|
||||
Literal::AbstractFloat(
|
||||
e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum(),
|
||||
)
|
||||
Literal::AbstractFloat(float_dot(e1, e2))
|
||||
}
|
||||
(LiteralVector::F32(e1), LiteralVector::F32(e2)) => {
|
||||
Literal::F32(e1.iter().zip(e2.iter()).map(|(e1, e2)| e1 * e2).sum())
|
||||
Literal::F32(float_dot(e1, e2))
|
||||
}
|
||||
(LiteralVector::AbstractInt(e1), LiteralVector::AbstractInt(e2)) => {
|
||||
Literal::AbstractInt(
|
||||
e1.iter()
|
||||
.zip(e2.iter())
|
||||
.map(|(&e1, &e2)| e1.checked_mul(e2))
|
||||
.try_fold(0_i64, |acc, x| {
|
||||
if let Some(x) = x {
|
||||
acc.checked_add(x)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or(ConstantEvaluatorError::Overflow(
|
||||
"in dot built-in".to_string(),
|
||||
))?,
|
||||
)
|
||||
Literal::AbstractInt(int_dot(e1, e2)?)
|
||||
}
|
||||
(LiteralVector::I32(e1), LiteralVector::I32(e2)) => {
|
||||
Literal::I32(int_dot(e1, e2)?)
|
||||
}
|
||||
(LiteralVector::U32(e1), LiteralVector::U32(e2)) => {
|
||||
Literal::U32(int_dot(e1, e2)?)
|
||||
}
|
||||
// TODO: more
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
})
|
||||
.handle(self, span)
|
||||
@ -1521,41 +1542,29 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
let e1 = LiteralVector::from_expr(arg, self, span, false)?;
|
||||
let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?;
|
||||
if e1.len() == 3 && e2.len() == 3 {
|
||||
fn float_cross<F, const CAP: usize>(
|
||||
a: ArrayVec<F, CAP>,
|
||||
b: ArrayVec<F, CAP>,
|
||||
) -> ArrayVec<F, CAP>
|
||||
where
|
||||
F: std::ops::Mul<F>,
|
||||
F: num_traits::Float + std::iter::Sum,
|
||||
{
|
||||
[
|
||||
a[1] * b[2] - a[2] * b[1],
|
||||
a[2] * b[0] - a[0] * b[2],
|
||||
a[0] * b[1] - a[1] * b[0],
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
match (e1, e2) {
|
||||
(LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => {
|
||||
LiteralVector::AbstractFloat(
|
||||
[
|
||||
a[1] * b[2] - a[2] * b[1],
|
||||
a[2] * b[0] - a[0] * b[2],
|
||||
a[0] * b[1] - a[1] * b[0],
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)
|
||||
LiteralVector::AbstractFloat(float_cross(a, b))
|
||||
}
|
||||
(LiteralVector::AbstractInt(a), LiteralVector::AbstractInt(b)) => {
|
||||
LiteralVector::AbstractInt(
|
||||
[
|
||||
a[1].checked_mul(b[2])
|
||||
.zip(a[2].checked_mul(b[1]))
|
||||
.and_then(|(a, b)| a.checked_sub(b)),
|
||||
a[2].checked_mul(b[0])
|
||||
.zip(a[0].checked_mul(b[2]))
|
||||
.and_then(|(a, b)| a.checked_sub(b)),
|
||||
a[0].checked_mul(b[1])
|
||||
.zip(a[1].checked_mul(b[0]))
|
||||
.and_then(|(a, b)| a.checked_sub(b)),
|
||||
]
|
||||
.into_iter()
|
||||
.collect::<Option<_>>()
|
||||
.ok_or(
|
||||
ConstantEvaluatorError::Overflow(
|
||||
"in cross built-in".to_string(),
|
||||
),
|
||||
)?,
|
||||
)
|
||||
(LiteralVector::F32(a), LiteralVector::F32(b)) => {
|
||||
LiteralVector::F32(float_cross(a, b))
|
||||
}
|
||||
// TODO: more
|
||||
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
|
||||
}
|
||||
.handle(self, span)
|
||||
|
Loading…
Reference in New Issue
Block a user