Use macro

Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com>
This commit is contained in:
sagudev 2024-11-11 07:36:43 +01:00
parent 318a5fd21c
commit 21acc9d7a4

View File

@ -254,6 +254,48 @@ gen_component_wise_extractor! {
],
}
macro_rules! match_literal_vector {
($($x:expr),* => $( $( ($($mat:ident($arg:ident)),*) -> $ret:path )|+ => $body:expr ),*,_ => $body2:expr) => {
match $($x),* {
$(
$(
($( LiteralVector::$mat($arg) ),*) => $ret($body),
)*
)*
_ => $body2
}
};
($($x:expr),* => $( $( ($($mat:ident($arg:ident)),*) )|+ => $body:expr ),*,_ => $body2:expr) => {
match $($x),* {
$(
$(
($( LiteralVector::$mat($arg) ),*) => $body,
)*
)*
_ => $body2
}
};
($($x:expr),* => $( $( $mat:ident$arg:tt -> $ret:path )|+ => $body:expr ),*,_ => $body2:expr) => {
match $($x),* {
$(
$(
LiteralVector::$mat$arg => $ret($body),
)*
)*
_ => $body2
}
};
($($x:expr),* => $( $( $mat:ident$arg:tt )|+ => $body:expr ),*) => {
match $($x),* {
$(
$(
LiteralVector::$mat$arg => $body,
)*
)*
}
};
}
/// Vectors with a concrete element type.
#[derive(Debug)]
enum LiteralVector {
@ -270,17 +312,17 @@ enum LiteralVector {
impl LiteralVector {
const fn len(&self) -> usize {
match *self {
LiteralVector::F64(ref v) => v.len(),
LiteralVector::F32(ref v) => v.len(),
LiteralVector::U32(ref v) => v.len(),
LiteralVector::I32(ref v) => v.len(),
LiteralVector::U64(ref v) => v.len(),
LiteralVector::I64(ref v) => v.len(),
LiteralVector::Bool(ref v) => v.len(),
LiteralVector::AbstractInt(ref v) => v.len(),
LiteralVector::AbstractFloat(ref v) => v.len(),
}
match_literal_vector!(*self =>
F64(ref v)
| F32(ref v)
| U32(ref v)
| I32(ref v)
| U64(ref v)
| I64(ref v)
| Bool(ref v)
| AbstractInt(ref v)
| AbstractFloat(ref v) => v.len()
)
}
/// Creates [`LiteralVector`] of size 1 from single [`Literal`]
@ -1506,23 +1548,15 @@ impl<'a> ConstantEvaluator<'a> {
))
}
LiteralVector::from_literal(match (e1, e2) {
(LiteralVector::AbstractFloat(e1), LiteralVector::AbstractFloat(e2)) => {
Literal::AbstractFloat(float_dot(e1, e2))
}
(LiteralVector::F32(e1), LiteralVector::F32(e2)) => {
Literal::F32(float_dot(e1, e2))
}
(LiteralVector::AbstractInt(e1), LiteralVector::AbstractInt(e2)) => {
Literal::AbstractInt(int_dot(e1, e2)?)
}
(LiteralVector::I32(e1), LiteralVector::I32(e2)) => {
Literal::I32(int_dot(e1, e2)?)
}
(LiteralVector::U32(e1), LiteralVector::U32(e2)) => {
Literal::U32(int_dot(e1, e2)?)
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
LiteralVector::from_literal(match_literal_vector! {(e1, e2) =>
(AbstractFloat(e1), AbstractFloat(e2)) -> Literal::AbstractFloat
| (F32(e1), F32(e2)) -> Literal::F32
=> float_dot(e1, e2),
(AbstractInt(e1), AbstractInt(e2)) -> Literal::AbstractInt
| (I32(e1), I32(e2)) -> Literal::I32
| (U32(e1), U32(e2)) -> Literal::U32
=> int_dot(e1, e2)?,
_ => return Err(ConstantEvaluatorError::InvalidMathArg)
})
.handle(self, span)
}
@ -1546,14 +1580,11 @@ impl<'a> ConstantEvaluator<'a> {
.into_iter()
.collect()
}
match (e1, e2) {
(LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => {
LiteralVector::AbstractFloat(float_cross(a, b))
}
(LiteralVector::F32(a), LiteralVector::F32(b)) => {
LiteralVector::F32(float_cross(a, b))
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
match_literal_vector! {(e1, e2) =>
(AbstractFloat(e1), AbstractFloat(e2)) -> LiteralVector::AbstractFloat
| (F32(e1), F32(e2)) -> LiteralVector::F32
=> float_cross(e1, e2),
_ => return Err(ConstantEvaluatorError::InvalidMathArg)
}
.handle(self, span)
} else {
@ -1571,10 +1602,11 @@ impl<'a> ConstantEvaluator<'a> {
e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
}
LiteralVector::from_literal(match e1 {
LiteralVector::AbstractFloat(a) => Literal::AbstractFloat(float_length(a)),
LiteralVector::F32(a) => Literal::F32(float_length(a)),
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
LiteralVector::from_literal(match_literal_vector! {e1 =>
AbstractFloat(e1) -> Literal::AbstractFloat
| F32(e1) -> Literal::F32
=> float_length(e1),
_ => return Err(ConstantEvaluatorError::InvalidMathArg)
})
.handle(self, span)
}
@ -1600,14 +1632,11 @@ impl<'a> ConstantEvaluator<'a> {
.sum::<F>()
.sqrt()
}
LiteralVector::from_literal(match (e1, e2) {
(LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => {
Literal::AbstractFloat(float_distance(a, b))
}
(LiteralVector::F32(a), LiteralVector::F32(b)) => {
Literal::F32(float_distance(a, b))
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
LiteralVector::from_literal(match_literal_vector! {(e1, e2) =>
(AbstractFloat(e1), AbstractFloat(e2)) -> Literal::AbstractFloat
| (F32(e1), F32(e2)) -> Literal::F32
=> float_distance(e1, e2),
_ => return Err(ConstantEvaluatorError::InvalidMathArg)
})
.handle(self, span)
}