Add matrix access to ConstantSolver

This commit is contained in:
João Capucho 2021-02-01 21:36:30 +00:00 committed by Dzmitry Malyshau
parent fcbf2aa4c4
commit bb88aac937

View File

@ -1,6 +1,6 @@
use crate::{
arena::{Arena, Handle},
ArraySize, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, UnaryOperator,
Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, UnaryOperator,
};
#[derive(Debug)]
@ -134,36 +134,20 @@ impl<'a> ConstantSolver<'a> {
match self.constants[base].inner {
crate::ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
crate::ConstantInner::Composite { ty, ref components } => match self.types[ty].inner {
crate::TypeInner::Vector { size, .. } => {
if size as usize <= index {
Err(ConstantSolvingError::InvalidAccessIndex)
} else {
Ok(components[index])
}
crate::ConstantInner::Composite { ty, ref components } => {
match self.types[ty].inner {
crate::TypeInner::Vector { .. }
| crate::TypeInner::Matrix { .. }
| crate::TypeInner::Array { .. }
| crate::TypeInner::Struct { .. } => (),
_ => return Err(ConstantSolvingError::InvalidAccessBase),
}
crate::TypeInner::Matrix { .. } => todo!(),
crate::TypeInner::Array { size, .. } => match size {
ArraySize::Constant(constant) => {
let size = self.constant_index(constant)?;
if size <= index {
Err(ConstantSolvingError::InvalidAccessIndex)
} else {
Ok(components[index])
}
}
ArraySize::Dynamic => Err(ConstantSolvingError::ArrayLengthDynamic),
},
crate::TypeInner::Struct { ref members, .. } => {
if members.len() <= index {
Err(ConstantSolvingError::InvalidAccessIndex)
} else {
Ok(components[index])
}
}
_ => Err(ConstantSolvingError::InvalidAccessBase),
},
components
.get(index)
.copied()
.ok_or(ConstantSolvingError::InvalidAccessIndex)
}
}
}
@ -198,11 +182,11 @@ impl<'a> ConstantSolver<'a> {
ConstantInner::Scalar { ref mut value, .. } => {
let intial = value.clone();
match kind {
ScalarKind::Sint => *value = ScalarValue::Sint(inner_cast(intial)),
ScalarKind::Uint => *value = ScalarValue::Uint(inner_cast(intial)),
ScalarKind::Float => *value = ScalarValue::Float(inner_cast(intial)),
ScalarKind::Bool => *value = ScalarValue::Bool(inner_cast::<u64>(intial) != 0),
*value = match kind {
ScalarKind::Sint => ScalarValue::Sint(inner_cast(intial)),
ScalarKind::Uint => ScalarValue::Uint(inner_cast(intial)),
ScalarKind::Float => ScalarValue::Float(inner_cast(intial)),
ScalarKind::Bool => ScalarValue::Bool(inner_cast::<u64>(intial) != 0),
}
}
ConstantInner::Composite {
@ -432,4 +416,141 @@ mod tests {
},
);
}
#[test]
fn access() {
let mut types = Arena::new();
let mut expressions = Arena::new();
let mut constants = Arena::new();
let matrix_ty = types.append(Type {
name: None,
inner: TypeInner::Matrix {
columns: VectorSize::Bi,
rows: VectorSize::Tri,
width: 4,
},
});
let vec_ty = types.append(Type {
name: None,
inner: TypeInner::Vector {
size: VectorSize::Tri,
kind: ScalarKind::Float,
width: 4,
},
});
let mut vec1_components = Vec::with_capacity(3);
let mut vec2_components = Vec::with_capacity(3);
for i in 0..3 {
let h = constants.append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Scalar {
width: 4,
value: ScalarValue::Float(i as f64),
},
});
vec1_components.push(h)
}
for i in 3..6 {
let h = constants.append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Scalar {
width: 4,
value: ScalarValue::Float(i as f64),
},
});
vec2_components.push(h)
}
let vec1 = constants.append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Composite {
ty: vec_ty,
components: vec1_components,
},
});
let vec2 = constants.append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Composite {
ty: vec_ty,
components: vec2_components,
},
});
let h = constants.append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Composite {
ty: matrix_ty,
components: vec![vec1, vec2],
},
});
let base = expressions.append(Expression::Constant(h));
let root1 = expressions.append(Expression::AccessIndex { base, index: 1 });
let root2 = expressions.append(Expression::AccessIndex {
base: root1,
index: 2,
});
let mut solver = ConstantSolver {
types: &types,
expressions: &expressions,
constants: &mut constants,
};
let res1 = solver.solve(root1).unwrap();
let res2 = solver.solve(root2).unwrap();
let res1_inner = &constants[res1].inner;
match res1_inner {
ConstantInner::Composite { ty, components } => {
assert_eq!(*ty, vec_ty);
let mut components_iter = components.iter().copied();
assert_eq!(
constants[components_iter.next().unwrap()].inner,
ConstantInner::Scalar {
width: 4,
value: ScalarValue::Float(3.),
},
);
assert_eq!(
constants[components_iter.next().unwrap()].inner,
ConstantInner::Scalar {
width: 4,
value: ScalarValue::Float(4.),
},
);
assert_eq!(
constants[components_iter.next().unwrap()].inner,
ConstantInner::Scalar {
width: 4,
value: ScalarValue::Float(5.),
},
);
assert!(components_iter.next().is_none());
}
_ => panic!("Expected vector"),
}
assert_eq!(
constants[res2].inner,
ConstantInner::Scalar {
width: 4,
value: ScalarValue::Float(5.),
},
);
}
}