Fix matrix multiplication types

This commit is contained in:
Dzmitry Malyshau 2021-04-01 15:43:33 -04:00 committed by Dzmitry Malyshau
parent cbdfbed32e
commit 8ff27187d1
3 changed files with 47 additions and 40 deletions

View File

@ -369,43 +369,53 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::Divide
| crate::BinaryOperator::Modulo => past(left).clone(),
crate::BinaryOperator::Multiply => {
let res_left = past(left);
let ty_left = res_left.inner_with(types);
let res_right = past(right);
let ty_right = res_right.inner_with(types);
if ty_left == ty_right {
res_left.clone()
} else if let Ti::Scalar { .. } = *ty_left {
res_right.clone()
} else if let Ti::Scalar { .. } = *ty_right {
res_left.clone()
} else if let Ti::Matrix {
columns: _,
rows,
width,
} = *ty_left
{
TypeResolution::Value(Ti::Vector {
let (res_left, res_right) = (past(left), past(right));
match (res_left.inner_with(types), res_right.inner_with(types)) {
(
&Ti::Matrix {
columns: _,
rows,
width,
},
&Ti::Matrix { columns, .. },
) => TypeResolution::Value(Ti::Matrix {
columns,
rows,
width,
}),
(
&Ti::Matrix {
columns: _,
rows,
width,
},
&Ti::Vector { .. },
) => TypeResolution::Value(Ti::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
})
} else if let Ti::Matrix {
columns,
rows: _,
width,
} = *ty_right
{
TypeResolution::Value(Ti::Vector {
}),
(
&Ti::Vector { .. },
&Ti::Matrix {
columns,
rows: _,
width,
},
) => TypeResolution::Value(Ti::Vector {
size: columns,
kind: crate::ScalarKind::Float,
width,
})
} else {
return Err(ResolveError::IncompatibleOperands(format!(
"{:?} * {:?}",
ty_left, ty_right
)));
}),
(&Ti::Scalar { .. }, _) => res_right.clone(),
(_, &Ti::Scalar { .. }) => res_left.clone(),
(&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
(tl, tr) => {
return Err(ResolveError::IncompatibleOperands(format!(
"{:?} * {:?}",
tl, tr
)))
}
}
}
crate::BinaryOperator::Equal

View File

@ -265,10 +265,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
ty: Handle(1),
),
(
uniformity: (
@ -293,10 +290,7 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
ty: Handle(1),
),
(
uniformity: (

View File

@ -2416,7 +2416,10 @@ expression: output
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
ty: Value(Scalar(
kind: Float,
width: 4,
)),
),
(
uniformity: (