From bb88aac937f2eed5ffe6b4ab2d137d32864392eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Capucho?= Date: Mon, 1 Feb 2021 21:36:30 +0000 Subject: [PATCH] Add matrix access to ConstantSolver --- src/proc/constants.rs | 189 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 155 insertions(+), 34 deletions(-) diff --git a/src/proc/constants.rs b/src/proc/constants.rs index 8097ed80c..d7b28b2ba 100644 --- a/src/proc/constants.rs +++ b/src/proc/constants.rs @@ -1,6 +1,6 @@ use crate::{ arena::{Arena, Handle}, - ArraySize, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, UnaryOperator, + Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, UnaryOperator, }; #[derive(Debug)] @@ -134,36 +134,20 @@ impl<'a> ConstantSolver<'a> { match self.constants[base].inner { crate::ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase), - crate::ConstantInner::Composite { ty, ref components } => match self.types[ty].inner { - crate::TypeInner::Vector { size, .. } => { - if size as usize <= index { - Err(ConstantSolvingError::InvalidAccessIndex) - } else { - Ok(components[index]) - } + crate::ConstantInner::Composite { ty, ref components } => { + match self.types[ty].inner { + crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } + | crate::TypeInner::Array { .. } + | crate::TypeInner::Struct { .. } => (), + _ => return Err(ConstantSolvingError::InvalidAccessBase), } - crate::TypeInner::Matrix { .. } => todo!(), - crate::TypeInner::Array { size, .. } => match size { - ArraySize::Constant(constant) => { - let size = self.constant_index(constant)?; - if size <= index { - Err(ConstantSolvingError::InvalidAccessIndex) - } else { - Ok(components[index]) - } - } - ArraySize::Dynamic => Err(ConstantSolvingError::ArrayLengthDynamic), - }, - crate::TypeInner::Struct { ref members, .. } => { - if members.len() <= index { - Err(ConstantSolvingError::InvalidAccessIndex) - } else { - Ok(components[index]) - } - } - _ => Err(ConstantSolvingError::InvalidAccessBase), - }, + components + .get(index) + .copied() + .ok_or(ConstantSolvingError::InvalidAccessIndex) + } } } @@ -198,11 +182,11 @@ impl<'a> ConstantSolver<'a> { ConstantInner::Scalar { ref mut value, .. } => { let intial = value.clone(); - match kind { - ScalarKind::Sint => *value = ScalarValue::Sint(inner_cast(intial)), - ScalarKind::Uint => *value = ScalarValue::Uint(inner_cast(intial)), - ScalarKind::Float => *value = ScalarValue::Float(inner_cast(intial)), - ScalarKind::Bool => *value = ScalarValue::Bool(inner_cast::(intial) != 0), + *value = match kind { + ScalarKind::Sint => ScalarValue::Sint(inner_cast(intial)), + ScalarKind::Uint => ScalarValue::Uint(inner_cast(intial)), + ScalarKind::Float => ScalarValue::Float(inner_cast(intial)), + ScalarKind::Bool => ScalarValue::Bool(inner_cast::(intial) != 0), } } ConstantInner::Composite { @@ -432,4 +416,141 @@ mod tests { }, ); } + + #[test] + fn access() { + let mut types = Arena::new(); + let mut expressions = Arena::new(); + let mut constants = Arena::new(); + + let matrix_ty = types.append(Type { + name: None, + inner: TypeInner::Matrix { + columns: VectorSize::Bi, + rows: VectorSize::Tri, + width: 4, + }, + }); + + let vec_ty = types.append(Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Tri, + kind: ScalarKind::Float, + width: 4, + }, + }); + + let mut vec1_components = Vec::with_capacity(3); + let mut vec2_components = Vec::with_capacity(3); + + for i in 0..3 { + let h = constants.append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value: ScalarValue::Float(i as f64), + }, + }); + + vec1_components.push(h) + } + + for i in 3..6 { + let h = constants.append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value: ScalarValue::Float(i as f64), + }, + }); + + vec2_components.push(h) + } + + let vec1 = constants.append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Composite { + ty: vec_ty, + components: vec1_components, + }, + }); + + let vec2 = constants.append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Composite { + ty: vec_ty, + components: vec2_components, + }, + }); + + let h = constants.append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Composite { + ty: matrix_ty, + components: vec![vec1, vec2], + }, + }); + + let base = expressions.append(Expression::Constant(h)); + let root1 = expressions.append(Expression::AccessIndex { base, index: 1 }); + let root2 = expressions.append(Expression::AccessIndex { + base: root1, + index: 2, + }); + + let mut solver = ConstantSolver { + types: &types, + expressions: &expressions, + constants: &mut constants, + }; + + let res1 = solver.solve(root1).unwrap(); + let res2 = solver.solve(root2).unwrap(); + + let res1_inner = &constants[res1].inner; + + match res1_inner { + ConstantInner::Composite { ty, components } => { + assert_eq!(*ty, vec_ty); + let mut components_iter = components.iter().copied(); + assert_eq!( + constants[components_iter.next().unwrap()].inner, + ConstantInner::Scalar { + width: 4, + value: ScalarValue::Float(3.), + }, + ); + assert_eq!( + constants[components_iter.next().unwrap()].inner, + ConstantInner::Scalar { + width: 4, + value: ScalarValue::Float(4.), + }, + ); + assert_eq!( + constants[components_iter.next().unwrap()].inner, + ConstantInner::Scalar { + width: 4, + value: ScalarValue::Float(5.), + }, + ); + assert!(components_iter.next().is_none()); + } + _ => panic!("Expected vector"), + } + + assert_eq!( + constants[res2].inner, + ConstantInner::Scalar { + width: 4, + value: ScalarValue::Float(5.), + }, + ); + } }