Use macro

Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com>
This commit is contained in:
sagudev 2024-11-11 07:36:43 +01:00
parent 318a5fd21c
commit 21acc9d7a4

View File

@ -254,6 +254,48 @@ gen_component_wise_extractor! {
], ],
} }
macro_rules! match_literal_vector {
($($x:expr),* => $( $( ($($mat:ident($arg:ident)),*) -> $ret:path )|+ => $body:expr ),*,_ => $body2:expr) => {
match $($x),* {
$(
$(
($( LiteralVector::$mat($arg) ),*) => $ret($body),
)*
)*
_ => $body2
}
};
($($x:expr),* => $( $( ($($mat:ident($arg:ident)),*) )|+ => $body:expr ),*,_ => $body2:expr) => {
match $($x),* {
$(
$(
($( LiteralVector::$mat($arg) ),*) => $body,
)*
)*
_ => $body2
}
};
($($x:expr),* => $( $( $mat:ident$arg:tt -> $ret:path )|+ => $body:expr ),*,_ => $body2:expr) => {
match $($x),* {
$(
$(
LiteralVector::$mat$arg => $ret($body),
)*
)*
_ => $body2
}
};
($($x:expr),* => $( $( $mat:ident$arg:tt )|+ => $body:expr ),*) => {
match $($x),* {
$(
$(
LiteralVector::$mat$arg => $body,
)*
)*
}
};
}
/// Vectors with a concrete element type. /// Vectors with a concrete element type.
#[derive(Debug)] #[derive(Debug)]
enum LiteralVector { enum LiteralVector {
@ -270,17 +312,17 @@ enum LiteralVector {
impl LiteralVector { impl LiteralVector {
const fn len(&self) -> usize { const fn len(&self) -> usize {
match *self { match_literal_vector!(*self =>
LiteralVector::F64(ref v) => v.len(), F64(ref v)
LiteralVector::F32(ref v) => v.len(), | F32(ref v)
LiteralVector::U32(ref v) => v.len(), | U32(ref v)
LiteralVector::I32(ref v) => v.len(), | I32(ref v)
LiteralVector::U64(ref v) => v.len(), | U64(ref v)
LiteralVector::I64(ref v) => v.len(), | I64(ref v)
LiteralVector::Bool(ref v) => v.len(), | Bool(ref v)
LiteralVector::AbstractInt(ref v) => v.len(), | AbstractInt(ref v)
LiteralVector::AbstractFloat(ref v) => v.len(), | AbstractFloat(ref v) => v.len()
} )
} }
/// Creates [`LiteralVector`] of size 1 from single [`Literal`] /// Creates [`LiteralVector`] of size 1 from single [`Literal`]
@ -1506,23 +1548,15 @@ impl<'a> ConstantEvaluator<'a> {
)) ))
} }
LiteralVector::from_literal(match (e1, e2) { LiteralVector::from_literal(match_literal_vector! {(e1, e2) =>
(LiteralVector::AbstractFloat(e1), LiteralVector::AbstractFloat(e2)) => { (AbstractFloat(e1), AbstractFloat(e2)) -> Literal::AbstractFloat
Literal::AbstractFloat(float_dot(e1, e2)) | (F32(e1), F32(e2)) -> Literal::F32
} => float_dot(e1, e2),
(LiteralVector::F32(e1), LiteralVector::F32(e2)) => { (AbstractInt(e1), AbstractInt(e2)) -> Literal::AbstractInt
Literal::F32(float_dot(e1, e2)) | (I32(e1), I32(e2)) -> Literal::I32
} | (U32(e1), U32(e2)) -> Literal::U32
(LiteralVector::AbstractInt(e1), LiteralVector::AbstractInt(e2)) => { => int_dot(e1, e2)?,
Literal::AbstractInt(int_dot(e1, e2)?) _ => return Err(ConstantEvaluatorError::InvalidMathArg)
}
(LiteralVector::I32(e1), LiteralVector::I32(e2)) => {
Literal::I32(int_dot(e1, e2)?)
}
(LiteralVector::U32(e1), LiteralVector::U32(e2)) => {
Literal::U32(int_dot(e1, e2)?)
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
}) })
.handle(self, span) .handle(self, span)
} }
@ -1546,14 +1580,11 @@ impl<'a> ConstantEvaluator<'a> {
.into_iter() .into_iter()
.collect() .collect()
} }
match (e1, e2) { match_literal_vector! {(e1, e2) =>
(LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => { (AbstractFloat(e1), AbstractFloat(e2)) -> LiteralVector::AbstractFloat
LiteralVector::AbstractFloat(float_cross(a, b)) | (F32(e1), F32(e2)) -> LiteralVector::F32
} => float_cross(e1, e2),
(LiteralVector::F32(a), LiteralVector::F32(b)) => { _ => return Err(ConstantEvaluatorError::InvalidMathArg)
LiteralVector::F32(float_cross(a, b))
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
} }
.handle(self, span) .handle(self, span)
} else { } else {
@ -1571,10 +1602,11 @@ impl<'a> ConstantEvaluator<'a> {
e.iter().map(|&ei| ei * ei).sum::<F>().sqrt() e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
} }
LiteralVector::from_literal(match e1 { LiteralVector::from_literal(match_literal_vector! {e1 =>
LiteralVector::AbstractFloat(a) => Literal::AbstractFloat(float_length(a)), AbstractFloat(e1) -> Literal::AbstractFloat
LiteralVector::F32(a) => Literal::F32(float_length(a)), | F32(e1) -> Literal::F32
_ => return Err(ConstantEvaluatorError::InvalidMathArg), => float_length(e1),
_ => return Err(ConstantEvaluatorError::InvalidMathArg)
}) })
.handle(self, span) .handle(self, span)
} }
@ -1600,14 +1632,11 @@ impl<'a> ConstantEvaluator<'a> {
.sum::<F>() .sum::<F>()
.sqrt() .sqrt()
} }
LiteralVector::from_literal(match (e1, e2) { LiteralVector::from_literal(match_literal_vector! {(e1, e2) =>
(LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => { (AbstractFloat(e1), AbstractFloat(e2)) -> Literal::AbstractFloat
Literal::AbstractFloat(float_distance(a, b)) | (F32(e1), F32(e2)) -> Literal::F32
} => float_distance(e1, e2),
(LiteralVector::F32(a), LiteralVector::F32(b)) => { _ => return Err(ConstantEvaluatorError::InvalidMathArg)
Literal::F32(float_distance(a, b))
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
}) })
.handle(self, span) .handle(self, span)
} }