wgsl: type inference for local variables

This commit is contained in:
Dzmitry Malyshau 2021-07-05 16:06:30 -04:00 committed by Dzmitry Malyshau
parent 65fbbf1101
commit beabd62d96
6 changed files with 99 additions and 42 deletions

View File

@ -492,8 +492,8 @@ impl<'function> Context<'function> {
Some(self.add_expression(Expression::Load { pointer }, body)),
meta,
));
},
_ => {},
}
_ => {}
}
}

View File

@ -1,18 +1,14 @@
use crate::{
proc::ensure_block_returns, Arena, BinaryOperator, Block, Constant, ConstantInner, EntryPoint, Expression, Function,
FunctionArgument, FunctionResult, Handle, ImageQuery, LocalVariable, MathFunction,
RelationalFunction, SampleLevel, ScalarKind, ScalarValue, Statement, StructMember, SwizzleComponent, Type,
TypeInner, VectorSize,
proc::ensure_block_returns, Arena, BinaryOperator, Block, Constant, ConstantInner, EntryPoint,
Expression, Function, FunctionArgument, FunctionResult, Handle, ImageQuery, LocalVariable,
MathFunction, RelationalFunction, SampleLevel, ScalarKind, ScalarValue, Statement,
StructMember, SwizzleComponent, Type, TypeInner, VectorSize,
};
use super::{ast::*, error::ErrorKind, SourceMetadata};
impl Program<'_> {
fn add_constant_value(
&mut self,
scalar_kind: ScalarKind,
value: u64,
) -> Handle<Constant> {
fn add_constant_value(&mut self, scalar_kind: ScalarKind, value: u64) -> Handle<Constant> {
let value = match scalar_kind {
ScalarKind::Uint => ScalarValue::Uint(value),
ScalarKind::Sint => ScalarValue::Sint(value as i64),
@ -23,10 +19,7 @@ impl Program<'_> {
self.module.constants.fetch_or_append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Scalar {
width: 4,
value,
},
inner: ConstantInner::Scalar { width: 4, value },
})
}
@ -49,13 +42,16 @@ impl Program<'_> {
let expr_type = self.resolve_type(ctx, args[0].0, args[0].1)?;
let vector_size = match *expr_type {
TypeInner::Vector{ size, .. } => Some(size),
TypeInner::Vector { size, .. } => Some(size),
_ => None,
};
// Special case: if casting from a bool, we need to use Select and not As.
match self.module.types[ty].inner.scalar_kind() {
Some(result_scalar_kind) if expr_type.scalar_kind() == Some(ScalarKind::Bool) && result_scalar_kind != ScalarKind::Bool => {
Some(result_scalar_kind)
if expr_type.scalar_kind() == Some(ScalarKind::Bool)
&& result_scalar_kind != ScalarKind::Bool =>
{
let c0 = self.add_constant_value(result_scalar_kind, 0u64);
let c1 = self.add_constant_value(result_scalar_kind, 1u64);
let mut reject = ctx.add_expression(Expression::Constant(c0), body);

View File

@ -11,6 +11,7 @@ use crate::{
arena::{Arena, Handle},
proc::{ResolveContext, ResolveError, TypeResolution},
};
use std::ops;
/// Helper class to emit expressions
#[allow(dead_code)]
@ -85,3 +86,10 @@ impl Typifier {
Ok(())
}
}
impl ops::Index<Handle<crate::Expression>> for Typifier {
type Output = TypeResolution;
fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
&self.resolutions[handle.index()]
}
}

View File

@ -124,7 +124,8 @@ pub enum Error<'a> {
ZeroSizeOrAlign(Span),
InconsistentBinding(Span),
UnknownLocalFunction(Span),
LetTypeMismatch(Span, Handle<crate::Type>),
InitializationTypeMismatch(Span, Handle<crate::Type>),
MissingType(Span),
Other,
}
@ -326,11 +327,16 @@ impl<'a> Error<'a> {
labels: vec![(span.clone(), "unknown local function".into())],
notes: vec![],
},
Error::LetTypeMismatch(ref name_span, ref expected_ty) => ParseError {
Error::InitializationTypeMismatch(ref name_span, ref expected_ty) => ParseError {
message: format!("the type of `{}` is expected to be {:?}", &source[name_span.clone()], expected_ty),
labels: vec![(name_span.clone(), format!("definition of `{}`", &source[name_span.clone()]).into())],
notes: vec![],
},
Error::MissingType(ref name_span) => ParseError {
message: format!("variable `{}` needs a type", &source[name_span.clone()]),
labels: vec![(name_span.clone(), format!("definition of `{}`", &source[name_span.clone()]).into())],
notes: vec![],
},
Error::Other => ParseError {
message: "other error".to_string(),
labels: vec![],
@ -2567,7 +2573,7 @@ impl Parser {
given_inner,
expr_inner
);
return Err(Error::LetTypeMismatch(name_span, ty));
return Err(Error::InitializationTypeMismatch(name_span, ty));
}
}
block.extend(emitter.finish(context.expressions));
@ -2583,24 +2589,69 @@ impl Parser {
Variable(Handle<crate::Expression>),
}
let (name, _name_span, ty, _access) =
self.parse_variable_ident_decl(lexer, context.types, context.constants)?;
let (name, name_span) = lexer.next_ident_with_span()?;
let given_ty = if lexer.skip(Token::Separator(':')) {
let (ty, _access) =
self.parse_type_decl(lexer, None, context.types, context.constants)?;
Some(ty)
} else {
None
};
let init = if lexer.skip(Token::Operation('=')) {
let (init, ty) = if lexer.skip(Token::Operation('=')) {
emitter.start(context.expressions);
let value = self.parse_general_expression(
lexer,
context.as_expression(block, &mut emitter),
)?;
block.extend(emitter.finish(context.expressions));
match context.expressions[value] {
// prepare the typifier, but work around mutable borrowing...
let _ = context
.as_expression(block, &mut emitter)
.resolve_type(value)?;
//TODO: share more of this code with `let` arm
let ty = match given_ty {
Some(ty) => {
let expr_inner = context.typifier.get(value, context.types);
let given_inner = &context.types[ty].inner;
if given_inner != expr_inner {
log::error!(
"Given type {:?} doesn't match expected {:?}",
given_inner,
expr_inner
);
return Err(Error::InitializationTypeMismatch(name_span, ty));
}
ty
}
None => {
// register the type, if needed
match context.typifier[value].clone() {
TypeResolution::Handle(ty) => ty,
TypeResolution::Value(inner) => context
.types
.fetch_or_append(crate::Type { name: None, inner }),
}
}
};
let init = match context.expressions[value] {
crate::Expression::Constant(handle) if is_uniform_control_flow => {
Init::Constant(handle)
}
_ => Init::Variable(value),
}
};
(init, ty)
} else {
Init::Empty
match given_ty {
Some(ty) => (Init::Empty, ty),
None => {
log::error!("Variable '{}' without an initializer needs a type", name);
return Err(Error::MissingType(name_span));
}
}
};
lexer.expect(Token::Separator(';'))?;
@ -3186,7 +3237,7 @@ impl Parser {
crate::ConstantInner::Composite { ty, components: _ } => ty == explicit_ty,
};
if !type_match {
return Err(Error::LetTypeMismatch(name_span, explicit_ty));
return Err(Error::InitializationTypeMismatch(name_span, explicit_ty));
}
//TODO: check `ty` against `const_handle`.
lexer.expect(Token::Separator(';'))?;

View File

@ -32,6 +32,8 @@ fn parse_type_inference() {
fn foo() {
let a = 2u;
let b: u32 = a;
var x = 3f32;
var y = vec2<f32>(1, 2);
}",
)
.unwrap();

View File

@ -146,7 +146,7 @@ fn bad_texture() {
7 return textureSample(a, sampler, vec2<f32>(0.0));
^ not an image
"#
"#,
);
}
@ -186,7 +186,7 @@ fn bad_texture_sample_type() {
3 [[group(0), binding(1)]] var texture : texture_2d<bool>;
^^^^ must be one of f32, i32 or u32
"#
"#,
);
}
@ -237,7 +237,7 @@ fn unknown_attribute() {
2 [[a]]
^ unknown attribute
"#
"#,
);
}
@ -253,7 +253,7 @@ fn unknown_built_in() {
2 fn x([[builtin(unknown_built_in)]] y: u32) {}
^^^^^^^^^^^^^^^^ unknown builtin
"#
"#,
);
}
@ -269,7 +269,7 @@ fn unknown_access() {
2 var<storage> x: [[access(unknown_access)]] array<u32>;
^^^^^^^^^^^^^^ unknown access
"#
"#,
);
}
@ -285,7 +285,7 @@ fn unknown_shader_stage() {
2 [[stage(geometry)]] fn main() {}
^^^^^^^^ unknown shader stage
"#
"#,
);
}
@ -303,7 +303,7 @@ fn unknown_ident() {
3 let a = b;
^ unknown identifier
"#
"#,
);
}
@ -321,7 +321,7 @@ fn unknown_scalar_type() {
= note: Valid scalar types are f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, bool
"#
"#,
);
}
@ -337,7 +337,7 @@ fn unknown_type() {
2 let a: Vec<f32>;
^^^ unknown type
"#
"#,
);
}
@ -353,7 +353,7 @@ fn unknown_storage_format() {
2 let storage: [[access(read)]] texture_storage_1d<rgba>;
^^^^ unknown storage format
"#
"#,
);
}
@ -369,7 +369,7 @@ fn unknown_conservative_depth() {
2 [[early_depth_test(abc)]] fn main() {}
^^^ unknown conservative depth
"#
"#,
);
}
@ -385,7 +385,7 @@ fn zero_array_stride() {
2 type zero = [[stride(0)]] array<f32>;
^ array stride must not be zero
"#
"#,
);
}
@ -403,7 +403,7 @@ fn struct_member_zero_size() {
3 [[size(0)]] data: array<f32>;
^ struct member size or alignment must not be 0
"#
"#,
);
}
@ -421,7 +421,7 @@ fn struct_member_zero_align() {
3 [[align(0)]] data: array<f32>;
^ struct member size or alignment must not be 0
"#
"#,
);
}
@ -437,7 +437,7 @@ fn inconsistent_binding() {
2 fn foo([[builtin(vertex_index), location(0)]] x: u32) {}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input/output binding is not consistent
"#
"#,
);
}