diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index d4d12fca9..34353dc18 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -254,6 +254,48 @@ gen_component_wise_extractor! { ], } +macro_rules! match_literal_vector { + ($($x:expr),* => $( $( ($($mat:ident($arg:ident)),*) -> $ret:path )|+ => $body:expr ),*,_ => $body2:expr) => { + match $($x),* { + $( + $( + ($( LiteralVector::$mat($arg) ),*) => $ret($body), + )* + )* + _ => $body2 + } + }; + ($($x:expr),* => $( $( ($($mat:ident($arg:ident)),*) )|+ => $body:expr ),*,_ => $body2:expr) => { + match $($x),* { + $( + $( + ($( LiteralVector::$mat($arg) ),*) => $body, + )* + )* + _ => $body2 + } + }; + ($($x:expr),* => $( $( $mat:ident$arg:tt -> $ret:path )|+ => $body:expr ),*,_ => $body2:expr) => { + match $($x),* { + $( + $( + LiteralVector::$mat$arg => $ret($body), + )* + )* + _ => $body2 + } + }; + ($($x:expr),* => $( $( $mat:ident$arg:tt )|+ => $body:expr ),*) => { + match $($x),* { + $( + $( + LiteralVector::$mat$arg => $body, + )* + )* + } + }; +} + /// Vectors with a concrete element type. #[derive(Debug)] enum LiteralVector { @@ -270,17 +312,17 @@ enum LiteralVector { impl LiteralVector { const fn len(&self) -> usize { - match *self { - LiteralVector::F64(ref v) => v.len(), - LiteralVector::F32(ref v) => v.len(), - LiteralVector::U32(ref v) => v.len(), - LiteralVector::I32(ref v) => v.len(), - LiteralVector::U64(ref v) => v.len(), - LiteralVector::I64(ref v) => v.len(), - LiteralVector::Bool(ref v) => v.len(), - LiteralVector::AbstractInt(ref v) => v.len(), - LiteralVector::AbstractFloat(ref v) => v.len(), - } + match_literal_vector!(*self => + F64(ref v) + | F32(ref v) + | U32(ref v) + | I32(ref v) + | U64(ref v) + | I64(ref v) + | Bool(ref v) + | AbstractInt(ref v) + | AbstractFloat(ref v) => v.len() + ) } /// Creates [`LiteralVector`] of size 1 from single [`Literal`] @@ -1506,23 +1548,15 @@ impl<'a> ConstantEvaluator<'a> { )) } - LiteralVector::from_literal(match (e1, e2) { - (LiteralVector::AbstractFloat(e1), LiteralVector::AbstractFloat(e2)) => { - Literal::AbstractFloat(float_dot(e1, e2)) - } - (LiteralVector::F32(e1), LiteralVector::F32(e2)) => { - Literal::F32(float_dot(e1, e2)) - } - (LiteralVector::AbstractInt(e1), LiteralVector::AbstractInt(e2)) => { - Literal::AbstractInt(int_dot(e1, e2)?) - } - (LiteralVector::I32(e1), LiteralVector::I32(e2)) => { - Literal::I32(int_dot(e1, e2)?) - } - (LiteralVector::U32(e1), LiteralVector::U32(e2)) => { - Literal::U32(int_dot(e1, e2)?) - } - _ => return Err(ConstantEvaluatorError::InvalidMathArg), + LiteralVector::from_literal(match_literal_vector! {(e1, e2) => + (AbstractFloat(e1), AbstractFloat(e2)) -> Literal::AbstractFloat + | (F32(e1), F32(e2)) -> Literal::F32 + => float_dot(e1, e2), + (AbstractInt(e1), AbstractInt(e2)) -> Literal::AbstractInt + | (I32(e1), I32(e2)) -> Literal::I32 + | (U32(e1), U32(e2)) -> Literal::U32 + => int_dot(e1, e2)?, + _ => return Err(ConstantEvaluatorError::InvalidMathArg) }) .handle(self, span) } @@ -1546,14 +1580,11 @@ impl<'a> ConstantEvaluator<'a> { .into_iter() .collect() } - match (e1, e2) { - (LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => { - LiteralVector::AbstractFloat(float_cross(a, b)) - } - (LiteralVector::F32(a), LiteralVector::F32(b)) => { - LiteralVector::F32(float_cross(a, b)) - } - _ => return Err(ConstantEvaluatorError::InvalidMathArg), + match_literal_vector! {(e1, e2) => + (AbstractFloat(e1), AbstractFloat(e2)) -> LiteralVector::AbstractFloat + | (F32(e1), F32(e2)) -> LiteralVector::F32 + => float_cross(e1, e2), + _ => return Err(ConstantEvaluatorError::InvalidMathArg) } .handle(self, span) } else { @@ -1571,10 +1602,11 @@ impl<'a> ConstantEvaluator<'a> { e.iter().map(|&ei| ei * ei).sum::().sqrt() } - LiteralVector::from_literal(match e1 { - LiteralVector::AbstractFloat(a) => Literal::AbstractFloat(float_length(a)), - LiteralVector::F32(a) => Literal::F32(float_length(a)), - _ => return Err(ConstantEvaluatorError::InvalidMathArg), + LiteralVector::from_literal(match_literal_vector! {e1 => + AbstractFloat(e1) -> Literal::AbstractFloat + | F32(e1) -> Literal::F32 + => float_length(e1), + _ => return Err(ConstantEvaluatorError::InvalidMathArg) }) .handle(self, span) } @@ -1600,14 +1632,11 @@ impl<'a> ConstantEvaluator<'a> { .sum::() .sqrt() } - LiteralVector::from_literal(match (e1, e2) { - (LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => { - Literal::AbstractFloat(float_distance(a, b)) - } - (LiteralVector::F32(a), LiteralVector::F32(b)) => { - Literal::F32(float_distance(a, b)) - } - _ => return Err(ConstantEvaluatorError::InvalidMathArg), + LiteralVector::from_literal(match_literal_vector! {(e1, e2) => + (AbstractFloat(e1), AbstractFloat(e2)) -> Literal::AbstractFloat + | (F32(e1), F32(e2)) -> Literal::F32 + => float_distance(e1, e2), + _ => return Err(ConstantEvaluatorError::InvalidMathArg) }) .handle(self, span) }