Fix matrix multiplication types

This commit is contained in:
Dzmitry Malyshau 2021-04-01 15:43:33 -04:00 committed by Dzmitry Malyshau
parent cbdfbed32e
commit 8ff27187d1
3 changed files with 47 additions and 40 deletions

View File

@ -369,43 +369,53 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::Divide | crate::BinaryOperator::Divide
| crate::BinaryOperator::Modulo => past(left).clone(), | crate::BinaryOperator::Modulo => past(left).clone(),
crate::BinaryOperator::Multiply => { crate::BinaryOperator::Multiply => {
let res_left = past(left); let (res_left, res_right) = (past(left), past(right));
let ty_left = res_left.inner_with(types); match (res_left.inner_with(types), res_right.inner_with(types)) {
let res_right = past(right); (
let ty_right = res_right.inner_with(types); &Ti::Matrix {
if ty_left == ty_right { columns: _,
res_left.clone() rows,
} else if let Ti::Scalar { .. } = *ty_left { width,
res_right.clone() },
} else if let Ti::Scalar { .. } = *ty_right { &Ti::Matrix { columns, .. },
res_left.clone() ) => TypeResolution::Value(Ti::Matrix {
} else if let Ti::Matrix { columns,
columns: _, rows,
rows, width,
width, }),
} = *ty_left (
{ &Ti::Matrix {
TypeResolution::Value(Ti::Vector { columns: _,
rows,
width,
},
&Ti::Vector { .. },
) => TypeResolution::Value(Ti::Vector {
size: rows, size: rows,
kind: crate::ScalarKind::Float, kind: crate::ScalarKind::Float,
width, width,
}) }),
} else if let Ti::Matrix { (
columns, &Ti::Vector { .. },
rows: _, &Ti::Matrix {
width, columns,
} = *ty_right rows: _,
{ width,
TypeResolution::Value(Ti::Vector { },
) => TypeResolution::Value(Ti::Vector {
size: columns, size: columns,
kind: crate::ScalarKind::Float, kind: crate::ScalarKind::Float,
width, width,
}) }),
} else { (&Ti::Scalar { .. }, _) => res_right.clone(),
return Err(ResolveError::IncompatibleOperands(format!( (_, &Ti::Scalar { .. }) => res_left.clone(),
"{:?} * {:?}", (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
ty_left, ty_right (tl, tr) => {
))); return Err(ResolveError::IncompatibleOperands(format!(
"{:?} * {:?}",
tl, tr
)))
}
} }
} }
crate::BinaryOperator::Equal crate::BinaryOperator::Equal

View File

@ -265,10 +265,7 @@ expression: output
), ),
ref_count: 1, ref_count: 1,
assignable_global: None, assignable_global: None,
ty: Value(Scalar( ty: Handle(1),
kind: Uint,
width: 4,
)),
), ),
( (
uniformity: ( uniformity: (
@ -293,10 +290,7 @@ expression: output
), ),
ref_count: 1, ref_count: 1,
assignable_global: None, assignable_global: None,
ty: Value(Scalar( ty: Handle(1),
kind: Uint,
width: 4,
)),
), ),
( (
uniformity: ( uniformity: (

View File

@ -2416,7 +2416,10 @@ expression: output
), ),
ref_count: 1, ref_count: 1,
assignable_global: None, assignable_global: None,
ty: Handle(1), ty: Value(Scalar(
kind: Float,
width: 4,
)),
), ),
( (
uniformity: ( uniformity: (