mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-26 08:44:08 +00:00
wgsl: type inference for local variables
This commit is contained in:
parent
65fbbf1101
commit
beabd62d96
@ -492,8 +492,8 @@ impl<'function> Context<'function> {
|
||||
Some(self.add_expression(Expression::Load { pointer }, body)),
|
||||
meta,
|
||||
));
|
||||
},
|
||||
_ => {},
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,18 +1,14 @@
|
||||
use crate::{
|
||||
proc::ensure_block_returns, Arena, BinaryOperator, Block, Constant, ConstantInner, EntryPoint, Expression, Function,
|
||||
FunctionArgument, FunctionResult, Handle, ImageQuery, LocalVariable, MathFunction,
|
||||
RelationalFunction, SampleLevel, ScalarKind, ScalarValue, Statement, StructMember, SwizzleComponent, Type,
|
||||
TypeInner, VectorSize,
|
||||
proc::ensure_block_returns, Arena, BinaryOperator, Block, Constant, ConstantInner, EntryPoint,
|
||||
Expression, Function, FunctionArgument, FunctionResult, Handle, ImageQuery, LocalVariable,
|
||||
MathFunction, RelationalFunction, SampleLevel, ScalarKind, ScalarValue, Statement,
|
||||
StructMember, SwizzleComponent, Type, TypeInner, VectorSize,
|
||||
};
|
||||
|
||||
use super::{ast::*, error::ErrorKind, SourceMetadata};
|
||||
|
||||
impl Program<'_> {
|
||||
fn add_constant_value(
|
||||
&mut self,
|
||||
scalar_kind: ScalarKind,
|
||||
value: u64,
|
||||
) -> Handle<Constant> {
|
||||
fn add_constant_value(&mut self, scalar_kind: ScalarKind, value: u64) -> Handle<Constant> {
|
||||
let value = match scalar_kind {
|
||||
ScalarKind::Uint => ScalarValue::Uint(value),
|
||||
ScalarKind::Sint => ScalarValue::Sint(value as i64),
|
||||
@ -23,10 +19,7 @@ impl Program<'_> {
|
||||
self.module.constants.fetch_or_append(Constant {
|
||||
name: None,
|
||||
specialization: None,
|
||||
inner: ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value,
|
||||
},
|
||||
inner: ConstantInner::Scalar { width: 4, value },
|
||||
})
|
||||
}
|
||||
|
||||
@ -49,13 +42,16 @@ impl Program<'_> {
|
||||
let expr_type = self.resolve_type(ctx, args[0].0, args[0].1)?;
|
||||
|
||||
let vector_size = match *expr_type {
|
||||
TypeInner::Vector{ size, .. } => Some(size),
|
||||
TypeInner::Vector { size, .. } => Some(size),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Special case: if casting from a bool, we need to use Select and not As.
|
||||
match self.module.types[ty].inner.scalar_kind() {
|
||||
Some(result_scalar_kind) if expr_type.scalar_kind() == Some(ScalarKind::Bool) && result_scalar_kind != ScalarKind::Bool => {
|
||||
Some(result_scalar_kind)
|
||||
if expr_type.scalar_kind() == Some(ScalarKind::Bool)
|
||||
&& result_scalar_kind != ScalarKind::Bool =>
|
||||
{
|
||||
let c0 = self.add_constant_value(result_scalar_kind, 0u64);
|
||||
let c1 = self.add_constant_value(result_scalar_kind, 1u64);
|
||||
let mut reject = ctx.add_expression(Expression::Constant(c0), body);
|
||||
|
@ -11,6 +11,7 @@ use crate::{
|
||||
arena::{Arena, Handle},
|
||||
proc::{ResolveContext, ResolveError, TypeResolution},
|
||||
};
|
||||
use std::ops;
|
||||
|
||||
/// Helper class to emit expressions
|
||||
#[allow(dead_code)]
|
||||
@ -85,3 +86,10 @@ impl Typifier {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ops::Index<Handle<crate::Expression>> for Typifier {
|
||||
type Output = TypeResolution;
|
||||
fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
|
||||
&self.resolutions[handle.index()]
|
||||
}
|
||||
}
|
||||
|
@ -124,7 +124,8 @@ pub enum Error<'a> {
|
||||
ZeroSizeOrAlign(Span),
|
||||
InconsistentBinding(Span),
|
||||
UnknownLocalFunction(Span),
|
||||
LetTypeMismatch(Span, Handle<crate::Type>),
|
||||
InitializationTypeMismatch(Span, Handle<crate::Type>),
|
||||
MissingType(Span),
|
||||
Other,
|
||||
}
|
||||
|
||||
@ -326,11 +327,16 @@ impl<'a> Error<'a> {
|
||||
labels: vec![(span.clone(), "unknown local function".into())],
|
||||
notes: vec![],
|
||||
},
|
||||
Error::LetTypeMismatch(ref name_span, ref expected_ty) => ParseError {
|
||||
Error::InitializationTypeMismatch(ref name_span, ref expected_ty) => ParseError {
|
||||
message: format!("the type of `{}` is expected to be {:?}", &source[name_span.clone()], expected_ty),
|
||||
labels: vec![(name_span.clone(), format!("definition of `{}`", &source[name_span.clone()]).into())],
|
||||
notes: vec![],
|
||||
},
|
||||
Error::MissingType(ref name_span) => ParseError {
|
||||
message: format!("variable `{}` needs a type", &source[name_span.clone()]),
|
||||
labels: vec![(name_span.clone(), format!("definition of `{}`", &source[name_span.clone()]).into())],
|
||||
notes: vec![],
|
||||
},
|
||||
Error::Other => ParseError {
|
||||
message: "other error".to_string(),
|
||||
labels: vec![],
|
||||
@ -2567,7 +2573,7 @@ impl Parser {
|
||||
given_inner,
|
||||
expr_inner
|
||||
);
|
||||
return Err(Error::LetTypeMismatch(name_span, ty));
|
||||
return Err(Error::InitializationTypeMismatch(name_span, ty));
|
||||
}
|
||||
}
|
||||
block.extend(emitter.finish(context.expressions));
|
||||
@ -2583,24 +2589,69 @@ impl Parser {
|
||||
Variable(Handle<crate::Expression>),
|
||||
}
|
||||
|
||||
let (name, _name_span, ty, _access) =
|
||||
self.parse_variable_ident_decl(lexer, context.types, context.constants)?;
|
||||
let (name, name_span) = lexer.next_ident_with_span()?;
|
||||
let given_ty = if lexer.skip(Token::Separator(':')) {
|
||||
let (ty, _access) =
|
||||
self.parse_type_decl(lexer, None, context.types, context.constants)?;
|
||||
Some(ty)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let init = if lexer.skip(Token::Operation('=')) {
|
||||
let (init, ty) = if lexer.skip(Token::Operation('=')) {
|
||||
emitter.start(context.expressions);
|
||||
let value = self.parse_general_expression(
|
||||
lexer,
|
||||
context.as_expression(block, &mut emitter),
|
||||
)?;
|
||||
block.extend(emitter.finish(context.expressions));
|
||||
match context.expressions[value] {
|
||||
|
||||
// prepare the typifier, but work around mutable borrowing...
|
||||
let _ = context
|
||||
.as_expression(block, &mut emitter)
|
||||
.resolve_type(value)?;
|
||||
|
||||
//TODO: share more of this code with `let` arm
|
||||
let ty = match given_ty {
|
||||
Some(ty) => {
|
||||
let expr_inner = context.typifier.get(value, context.types);
|
||||
let given_inner = &context.types[ty].inner;
|
||||
if given_inner != expr_inner {
|
||||
log::error!(
|
||||
"Given type {:?} doesn't match expected {:?}",
|
||||
given_inner,
|
||||
expr_inner
|
||||
);
|
||||
return Err(Error::InitializationTypeMismatch(name_span, ty));
|
||||
}
|
||||
ty
|
||||
}
|
||||
None => {
|
||||
// register the type, if needed
|
||||
match context.typifier[value].clone() {
|
||||
TypeResolution::Handle(ty) => ty,
|
||||
TypeResolution::Value(inner) => context
|
||||
.types
|
||||
.fetch_or_append(crate::Type { name: None, inner }),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let init = match context.expressions[value] {
|
||||
crate::Expression::Constant(handle) if is_uniform_control_flow => {
|
||||
Init::Constant(handle)
|
||||
}
|
||||
_ => Init::Variable(value),
|
||||
}
|
||||
};
|
||||
(init, ty)
|
||||
} else {
|
||||
Init::Empty
|
||||
match given_ty {
|
||||
Some(ty) => (Init::Empty, ty),
|
||||
None => {
|
||||
log::error!("Variable '{}' without an initializer needs a type", name);
|
||||
return Err(Error::MissingType(name_span));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
lexer.expect(Token::Separator(';'))?;
|
||||
@ -3186,7 +3237,7 @@ impl Parser {
|
||||
crate::ConstantInner::Composite { ty, components: _ } => ty == explicit_ty,
|
||||
};
|
||||
if !type_match {
|
||||
return Err(Error::LetTypeMismatch(name_span, explicit_ty));
|
||||
return Err(Error::InitializationTypeMismatch(name_span, explicit_ty));
|
||||
}
|
||||
//TODO: check `ty` against `const_handle`.
|
||||
lexer.expect(Token::Separator(';'))?;
|
||||
|
@ -32,6 +32,8 @@ fn parse_type_inference() {
|
||||
fn foo() {
|
||||
let a = 2u;
|
||||
let b: u32 = a;
|
||||
var x = 3f32;
|
||||
var y = vec2<f32>(1, 2);
|
||||
}",
|
||||
)
|
||||
.unwrap();
|
||||
|
@ -146,7 +146,7 @@ fn bad_texture() {
|
||||
7 │ return textureSample(a, sampler, vec2<f32>(0.0));
|
||||
│ ^ not an image
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -186,7 +186,7 @@ fn bad_texture_sample_type() {
|
||||
3 │ [[group(0), binding(1)]] var texture : texture_2d<bool>;
|
||||
│ ^^^^ must be one of f32, i32 or u32
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -237,7 +237,7 @@ fn unknown_attribute() {
|
||||
2 │ [[a]]
|
||||
│ ^ unknown attribute
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -253,7 +253,7 @@ fn unknown_built_in() {
|
||||
2 │ fn x([[builtin(unknown_built_in)]] y: u32) {}
|
||||
│ ^^^^^^^^^^^^^^^^ unknown builtin
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -269,7 +269,7 @@ fn unknown_access() {
|
||||
2 │ var<storage> x: [[access(unknown_access)]] array<u32>;
|
||||
│ ^^^^^^^^^^^^^^ unknown access
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -285,7 +285,7 @@ fn unknown_shader_stage() {
|
||||
2 │ [[stage(geometry)]] fn main() {}
|
||||
│ ^^^^^^^^ unknown shader stage
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -303,7 +303,7 @@ fn unknown_ident() {
|
||||
3 │ let a = b;
|
||||
│ ^ unknown identifier
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -321,7 +321,7 @@ fn unknown_scalar_type() {
|
||||
│
|
||||
= note: Valid scalar types are f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, bool
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -337,7 +337,7 @@ fn unknown_type() {
|
||||
2 │ let a: Vec<f32>;
|
||||
│ ^^^ unknown type
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -353,7 +353,7 @@ fn unknown_storage_format() {
|
||||
2 │ let storage: [[access(read)]] texture_storage_1d<rgba>;
|
||||
│ ^^^^ unknown storage format
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -369,7 +369,7 @@ fn unknown_conservative_depth() {
|
||||
2 │ [[early_depth_test(abc)]] fn main() {}
|
||||
│ ^^^ unknown conservative depth
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -385,7 +385,7 @@ fn zero_array_stride() {
|
||||
2 │ type zero = [[stride(0)]] array<f32>;
|
||||
│ ^ array stride must not be zero
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -403,7 +403,7 @@ fn struct_member_zero_size() {
|
||||
3 │ [[size(0)]] data: array<f32>;
|
||||
│ ^ struct member size or alignment must not be 0
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -421,7 +421,7 @@ fn struct_member_zero_align() {
|
||||
3 │ [[align(0)]] data: array<f32>;
|
||||
│ ^ struct member size or alignment must not be 0
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -437,7 +437,7 @@ fn inconsistent_binding() {
|
||||
2 │ fn foo([[builtin(vertex_index), location(0)]] x: u32) {}
|
||||
│ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input/output binding is not consistent
|
||||
|
||||
"#
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user