diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 4c7f8b325..5467b52bc 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -47,7 +47,7 @@ pub use features::Features; use crate::{ back::{self, Baked}, - proc::{self, NameKey}, + proc::{self, ExpressionKindTracker, NameKey}, valid, Handle, ShaderStage, TypeInner, }; use features::FeaturesManager; @@ -1584,6 +1584,7 @@ impl<'a, W: Write> Writer<'a, W> { info, expressions: &func.expressions, named_expressions: &func.named_expressions, + expr_kind_tracker: ExpressionKindTracker::from_arena(&func.expressions), }; self.named_expressions.clear(); diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 85d943e85..e33fc79f2 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -8,7 +8,7 @@ use super::{ }; use crate::{ back::{self, Baked}, - proc::{self, NameKey}, + proc::{self, ExpressionKindTracker, NameKey}, valid, Handle, Module, Scalar, ScalarKind, ShaderStage, TypeInner, }; use std::{fmt, mem}; @@ -346,6 +346,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { info, expressions: &function.expressions, named_expressions: &function.named_expressions, + expr_kind_tracker: ExpressionKindTracker::from_arena(&function.expressions), }; let name = self.names[&NameKey::Function(handle)].clone(); @@ -386,6 +387,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { info, expressions: &ep.function.expressions, named_expressions: &ep.function.named_expressions, + expr_kind_tracker: ExpressionKindTracker::from_arena(&ep.function.expressions), }; self.write_wrapped_functions(module, &ctx)?; diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 352adc37e..58c7fa02c 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -3,6 +3,8 @@ Backend functions that export shader [`Module`](super::Module)s into binary and */ #![allow(dead_code)] // can be dead if none of the enabled backends need it +use crate::proc::ExpressionKindTracker; + #[cfg(dot_out)] pub mod dot; #[cfg(glsl_out)] @@ -118,6 +120,8 @@ pub struct FunctionCtx<'a> { pub expressions: &'a crate::Arena, /// Map of expressions that have associated variable names pub named_expressions: &'a crate::NamedExpressions, + /// For constness checks + pub expr_kind_tracker: ExpressionKindTracker, } impl FunctionCtx<'_> { diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 77ee530be..5f82862f7 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -289,6 +289,7 @@ fn process_function( &mut local_expression_kind_tracker, &mut emitter, &mut block, + false, ); for (old_h, mut expr, span) in expressions.drain() { diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index e5a5e5f64..acbd532ed 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1,7 +1,7 @@ use super::Error; use crate::{ back::{self, Baked}, - proc::{self, NameKey}, + proc::{self, ExpressionKindTracker, NameKey}, valid, Handle, Module, ShaderStage, TypeInner, }; use std::fmt::Write; @@ -166,6 +166,7 @@ impl Writer { info: fun_info, expressions: &function.expressions, named_expressions: &function.named_expressions, + expr_kind_tracker: ExpressionKindTracker::from_arena(&function.expressions), }; // Write the function @@ -193,6 +194,7 @@ impl Writer { info: info.get_entry_point(index), expressions: &ep.function.expressions, named_expressions: &ep.function.named_expressions, + expr_kind_tracker: ExpressionKindTracker::from_arena(&ep.function.expressions), }; self.write_function(module, &ep.function, &func_ctx)?; @@ -1115,8 +1117,14 @@ impl Writer { func_ctx: &back::FunctionCtx, name: &str, ) -> BackendResult { + // Some functions are marked as const, but are not yet implemented as constant expression + let quantifier = if func_ctx.expr_kind_tracker.is_impl_const(handle) { + "const" + } else { + "let" + }; // Write variable name - write!(self.out, "let {name}")?; + write!(self.out, "{quantifier} {name}")?; if self.flags.contains(WriterFlags::EXPLICIT_TYPES) { write!(self.out, ": ")?; let ty = &func_ctx.info[handle].ty; diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index ceed9c399..100d91c61 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -98,7 +98,7 @@ impl<'source> GlobalContext<'source, '_, '_> { types: self.types, module: self.module, const_typifier: self.const_typifier, - expr_type: ExpressionContextType::Constant, + expr_type: ExpressionContextType::Constant(None), global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -160,7 +160,8 @@ pub struct StatementContext<'source, 'temp, 'out> { /// /// [`LocalVariable`]: crate::Expression::LocalVariable /// [`FunctionArgument`]: crate::Expression::FunctionArgument - local_table: &'temp mut FastHashMap, Typed>>, + local_table: + &'temp mut FastHashMap, Declared>>>, const_typifier: &'temp mut Typifier, typifier: &'temp mut Typifier, @@ -184,6 +185,32 @@ pub struct StatementContext<'source, 'temp, 'out> { } impl<'a, 'temp> StatementContext<'a, 'temp, '_> { + fn as_const<'t>( + &'t mut self, + block: &'t mut crate::Block, + emitter: &'t mut Emitter, + ) -> ExpressionContext<'a, 't, '_> + where + 'temp: 't, + { + ExpressionContext { + globals: self.globals, + types: self.types, + ast_expressions: self.ast_expressions, + const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, + module: self.module, + expr_type: ExpressionContextType::Constant(Some(LocalExpressionContext { + local_table: self.local_table, + function: self.function, + block, + emitter, + typifier: self.typifier, + local_expression_kind_tracker: self.local_expression_kind_tracker, + })), + } + } + fn as_expression<'t>( &'t mut self, block: &'t mut crate::Block, @@ -199,7 +226,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { const_typifier: self.const_typifier, global_expression_kind_tracker: self.global_expression_kind_tracker, module: self.module, - expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext { + expr_type: ExpressionContextType::Runtime(LocalExpressionContext { local_table: self.local_table, function: self.function, block, @@ -235,12 +262,12 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { } } -pub struct RuntimeExpressionContext<'temp, 'out> { +pub struct LocalExpressionContext<'temp, 'out> { /// A map from [`ast::Local`] handles to the Naga expressions we've built for them. /// /// This is always [`StatementContext::local_table`] for the /// enclosing statement; see that documentation for details. - local_table: &'temp FastHashMap, Typed>>, + local_table: &'temp FastHashMap, Declared>>>, function: &'out mut crate::Function, block: &'temp mut crate::Block, @@ -262,15 +289,15 @@ pub enum ExpressionContextType<'temp, 'out> { /// The given [`RuntimeExpressionContext`] holds information about local /// variables, arguments, and other definitions available only to runtime /// expressions, not constant or override expressions. - Runtime(RuntimeExpressionContext<'temp, 'out>), + Runtime(LocalExpressionContext<'temp, 'out>), /// We are lowering to a constant expression, to be included in the module's /// constant expression arena. /// - /// Everything constant expressions are allowed to refer to is - /// available in the [`ExpressionContext`], so this variant - /// carries no further information. - Constant, + /// Everything global constant expressions are allowed to refer to is + /// available in the [`ExpressionContext`], but local constant expressions can + /// also refer to other + Constant(Option>), /// We are lowering to an override expression, to be included in the module's /// constant expression arena. @@ -352,7 +379,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ast_expressions: self.ast_expressions, const_typifier: self.const_typifier, module: self.module, - expr_type: ExpressionContextType::Constant, + expr_type: ExpressionContextType::Constant(None), global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -376,8 +403,19 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { rctx.local_expression_kind_tracker, rctx.emitter, rctx.block, + false, ), - ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module( + ExpressionContextType::Constant(Some(ref mut rctx)) => { + ConstantEvaluator::for_wgsl_function( + self.module, + &mut rctx.function.expressions, + rctx.local_expression_kind_tracker, + rctx.emitter, + rctx.block, + true, + ) + } + ExpressionContextType::Constant(None) => ConstantEvaluator::for_wgsl_module( self.module, self.global_expression_kind_tracker, false, @@ -412,15 +450,27 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .eval_expr_to_u32_from(handle, &ctx.function.expressions) .ok() } - ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(), + ExpressionContextType::Constant(Some(ref ctx)) => { + assert!(ctx.local_expression_kind_tracker.is_const(handle)); + self.module + .to_ctx() + .eval_expr_to_u32_from(handle, &ctx.function.expressions) + .ok() + } + ExpressionContextType::Constant(None) => { + self.module.to_ctx().eval_expr_to_u32(handle).ok() + } ExpressionContextType::Override => None, } } fn get_expression_span(&self, handle: Handle) -> Span { match self.expr_type { - ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle), - ExpressionContextType::Constant | ExpressionContextType::Override => { + ExpressionContextType::Runtime(ref ctx) + | ExpressionContextType::Constant(Some(ref ctx)) => { + ctx.function.expressions.get_span(handle) + } + ExpressionContextType::Constant(None) | ExpressionContextType::Override => { self.module.global_expressions.get_span(handle) } } @@ -428,20 +478,35 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { fn typifier(&self) -> &Typifier { match self.expr_type { - ExpressionContextType::Runtime(ref ctx) => ctx.typifier, - ExpressionContextType::Constant | ExpressionContextType::Override => { + ExpressionContextType::Runtime(ref ctx) + | ExpressionContextType::Constant(Some(ref ctx)) => ctx.typifier, + ExpressionContextType::Constant(None) | ExpressionContextType::Override => { self.const_typifier } } } + fn local( + &mut self, + local: &Handle, + span: Span, + ) -> Result>, Error<'source>> { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => Ok(ctx.local_table[local].runtime()), + ExpressionContextType::Constant(Some(ref ctx)) => ctx.local_table[local] + .const_time() + .ok_or(Error::UnexpectedOperationInConstContext(span)), + _ => Err(Error::UnexpectedOperationInConstContext(span)), + } + } + fn runtime_expression_ctx( &mut self, span: Span, - ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> { + ) -> Result<&mut LocalExpressionContext<'temp, 'out>, Error<'source>> { match self.expr_type { ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx), - ExpressionContextType::Constant | ExpressionContextType::Override => { + ExpressionContextType::Constant(_) | ExpressionContextType::Override => { Err(Error::UnexpectedOperationInConstContext(span)) } } @@ -480,7 +545,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { } // This means a `gather` operation appeared in a constant expression. // This error refers to the `gather` itself, not its "component" argument. - ExpressionContextType::Constant | ExpressionContextType::Override => { + ExpressionContextType::Constant(_) | ExpressionContextType::Override => { Err(Error::UnexpectedOperationInConstContext(gather_span)) } } @@ -505,8 +570,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { // except that this lets the borrow checker see that it's okay // to also borrow self.module.types mutably below. let typifier = match self.expr_type { - ExpressionContextType::Runtime(ref ctx) => ctx.typifier, - ExpressionContextType::Constant | ExpressionContextType::Override => { + ExpressionContextType::Runtime(ref ctx) + | ExpressionContextType::Constant(Some(ref ctx)) => ctx.typifier, + ExpressionContextType::Constant(None) | ExpressionContextType::Override => { &*self.const_typifier } }; @@ -542,7 +608,8 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { let typifier; let expressions; match self.expr_type { - ExpressionContextType::Runtime(ref mut ctx) => { + ExpressionContextType::Runtime(ref mut ctx) + | ExpressionContextType::Constant(Some(ref mut ctx)) => { resolve_ctx = ResolveContext::with_locals( self.module, &ctx.function.local_variables, @@ -551,7 +618,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { typifier = &mut *ctx.typifier; expressions = &ctx.function.expressions; } - ExpressionContextType::Constant | ExpressionContextType::Override => { + ExpressionContextType::Constant(None) | ExpressionContextType::Override => { resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]); typifier = self.const_typifier; expressions = &self.module.global_expressions; @@ -643,18 +710,20 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { span: Span, ) -> Result, Error<'source>> { match self.expr_type { - ExpressionContextType::Runtime(ref mut rctx) => { + ExpressionContextType::Runtime(ref mut rctx) + | ExpressionContextType::Constant(Some(ref mut rctx)) => { rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); } - ExpressionContextType::Constant | ExpressionContextType::Override => {} + ExpressionContextType::Constant(None) | ExpressionContextType::Override => {} } let result = self.append_expression(expression, span); match self.expr_type { - ExpressionContextType::Runtime(ref mut rctx) => { + ExpressionContextType::Runtime(ref mut rctx) + | ExpressionContextType::Constant(Some(ref mut rctx)) => { rctx.emitter.start(&rctx.function.expressions); } - ExpressionContextType::Constant | ExpressionContextType::Override => {} + ExpressionContextType::Constant(None) | ExpressionContextType::Override => {} } result } @@ -718,6 +787,30 @@ impl<'source> ArgumentContext<'_, 'source> { } } +#[derive(Debug, Copy, Clone)] +enum Declared { + /// Value declared as const + Const(T), + + /// Value declared as non-const + Runtime(T), +} + +impl Declared { + fn runtime(self) -> T { + match self { + Declared::Const(t) | Declared::Runtime(t) => t, + } + } + + fn const_time(self) -> Option { + match self { + Declared::Const(t) => Some(t), + Declared::Runtime(_) => None, + } + } +} + /// WGSL type annotations on expressions, types, values, etc. /// /// Naga and WGSL types are very close, but Naga lacks WGSL's `ref` types, which @@ -1120,7 +1213,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let ty = self.resolve_ast_type(arg.ty, ctx)?; let expr = expressions .append(crate::Expression::FunctionArgument(i as u32), arg.name.span); - local_table.insert(arg.handle, Typed::Plain(expr)); + local_table.insert(arg.handle, Declared::Runtime(Typed::Plain(expr))); named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span)); local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime); @@ -1268,7 +1361,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } block.extend(emitter.finish(&ctx.function.expressions)); - ctx.local_table.insert(l.handle, Typed::Plain(value)); + ctx.local_table + .insert(l.handle, Declared::Runtime(Typed::Plain(value))); ctx.named_expressions .insert(value, (l.name.name.to_string(), l.name.span)); @@ -1350,7 +1444,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Span::UNDEFINED, )?; block.extend(emitter.finish(&ctx.function.expressions)); - ctx.local_table.insert(v.handle, Typed::Reference(handle)); + ctx.local_table + .insert(v.handle, Declared::Runtime(Typed::Reference(handle))); match initializer { Some(initializer) => crate::Statement::Store { @@ -1360,6 +1455,41 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { None => return Ok(()), } } + ast::LocalDecl::Const(ref c) => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let ectx = &mut ctx.as_const(block, &mut emitter); + + let mut init = self.expression_for_abstract(c.init, ectx)?; + + if let Some(explicit_ty) = c.ty { + let explicit_ty = + self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?; + let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty); + init = ectx + .try_automatic_conversions(init, &explicit_ty_res, c.name.span) + .map_err(|error| match error { + Error::AutoConversion(error) => Error::InitializationTypeMismatch { + name: c.name.span, + expected: error.dest_type, + got: error.source_type, + }, + other => other, + })?; + } else { + init = ectx.concretize(init)?; + ectx.register_type(init)?; + } + + block.extend(emitter.finish(&ctx.function.expressions)); + ctx.local_table + .insert(c.handle, Declared::Const(Typed::Plain(init))); + ctx.named_expressions + .insert(init, (c.name.name.to_string(), c.name.span)); + + return Ok(()); + } }, ast::StatementKind::If { condition, @@ -1658,8 +1788,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(Typed::Plain(handle)); } ast::Expression::Ident(ast::IdentExpr::Local(local)) => { - let rctx = ctx.runtime_expression_ctx(span)?; - return Ok(rctx.local_table[&local]); + return ctx.local(&local, span); } ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => { let global = ctx diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 7df5c8a1c..76c9b04af 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -109,7 +109,7 @@ pub struct EntryPoint<'a> { } #[cfg(doc)] -use crate::front::wgsl::lower::{RuntimeExpressionContext, StatementContext}; +use crate::front::wgsl::lower::{LocalExpressionContext, StatementContext}; #[derive(Debug)] pub struct Function<'a> { @@ -460,10 +460,19 @@ pub struct Let<'a> { pub handle: Handle, } +#[derive(Debug)] +pub struct LocalConst<'a> { + pub name: Ident<'a>, + pub ty: Option>>, + pub init: Handle>, + pub handle: Handle, +} + #[derive(Debug)] pub enum LocalDecl<'a> { Var(LocalVariable<'a>), Let(Let<'a>), + Const(LocalConst<'a>), } #[derive(Debug)] diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index c9114d685..4994f9cd3 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -1688,6 +1688,28 @@ impl Parser { handle, })) } + "const" => { + let _ = lexer.next(); + let name = lexer.next_ident()?; + + let given_ty = if lexer.skip(Token::Separator(':')) { + let ty = self.type_decl(lexer, ctx)?; + Some(ty) + } else { + None + }; + lexer.expect(Token::Operation('='))?; + let expr_id = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + + let handle = ctx.declare_local(name)?; + ast::StatementKind::LocalDecl(ast::LocalDecl::Const(ast::LocalConst { + name, + ty: given_ty, + init: expr_id, + handle, + })) + } "var" => { let _ = lexer.next(); diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index deaa9c93c..a79944a3f 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -317,7 +317,7 @@ pub struct ConstantEvaluator<'a> { #[derive(Debug)] enum WgslRestrictions<'a> { /// - const-expressions will be evaluated and inserted in the arena - Const, + Const(Option>), /// - const-expressions will be evaluated and inserted in the arena /// - override-expressions will be inserted in the arena Override, @@ -347,6 +347,8 @@ struct FunctionLocalData<'a> { #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] pub enum ExpressionKind { + /// If const is also implemented as const + ImplConst, Const, Override, Runtime, @@ -372,14 +374,23 @@ impl ExpressionKindTracker { pub fn insert(&mut self, value: Handle, expr_type: ExpressionKind) { self.inner.insert(value, expr_type); } + pub fn is_const(&self, h: Handle) -> bool { - matches!(self.type_of(h), ExpressionKind::Const) + matches!( + self.type_of(h), + ExpressionKind::Const | ExpressionKind::ImplConst + ) + } + + /// Returns `true` if naga can also evaluate expression as const + pub fn is_impl_const(&self, h: Handle) -> bool { + matches!(self.type_of(h), ExpressionKind::ImplConst) } pub fn is_const_or_override(&self, h: Handle) -> bool { matches!( self.type_of(h), - ExpressionKind::Const | ExpressionKind::Override + ExpressionKind::Const | ExpressionKind::Override | ExpressionKind::ImplConst ) } @@ -400,13 +411,14 @@ impl ExpressionKindTracker { } fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind { + use crate::MathFunction as Mf; match *expr { Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { - ExpressionKind::Const + ExpressionKind::ImplConst } Expression::Override(_) => ExpressionKind::Override, Expression::Compose { ref components, .. } => { - let mut expr_type = ExpressionKind::Const; + let mut expr_type = ExpressionKind::ImplConst; for component in components { expr_type = expr_type.max(self.type_of(*component)) } @@ -417,13 +429,16 @@ impl ExpressionKindTracker { Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)), Expression::Swizzle { vector, .. } => self.type_of(vector), Expression::Unary { expr, .. } => self.type_of(expr), - Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)), + Expression::Binary { left, right, .. } => self + .type_of(left) + .max(self.type_of(right)) + .max(ExpressionKind::Const), Expression::Math { + fun, arg, arg1, arg2, arg3, - .. } => self .type_of(arg) .max( @@ -437,8 +452,34 @@ impl ExpressionKindTracker { .max( arg3.map(|arg| self.type_of(arg)) .unwrap_or(ExpressionKind::Const), + ) + .max( + if matches!( + fun, + Mf::Dot + | Mf::Outer + | Mf::Cross + | Mf::Distance + | Mf::Length + | Mf::Normalize + | Mf::FaceForward + | Mf::Reflect + | Mf::Refract + | Mf::Ldexp + | Mf::Modf + | Mf::Mix + | Mf::Frexp + ) { + ExpressionKind::Const + } else { + ExpressionKind::ImplConst + }, ), - Expression::As { expr, .. } => self.type_of(expr), + Expression::As { convert, expr, .. } => self.type_of(expr).max(if convert.is_some() { + ExpressionKind::ImplConst + } else { + ExpressionKind::Const + }), Expression::Select { condition, accept, @@ -446,7 +487,8 @@ impl ExpressionKindTracker { } => self .type_of(condition) .max(self.type_of(accept)) - .max(self.type_of(reject)), + .max(self.type_of(reject)) + .max(ExpressionKind::Const), Expression::Relational { argument, .. } => self.type_of(argument), Expression::ArrayLength(expr) => self.type_of(expr), _ => ExpressionKind::Runtime, @@ -556,7 +598,7 @@ impl<'a> ConstantEvaluator<'a> { Behavior::Wgsl(if in_override_ctx { WgslRestrictions::Override } else { - WgslRestrictions::Const + WgslRestrictions::Const(None) }), module, global_expression_kind_tracker, @@ -603,13 +645,19 @@ impl<'a> ConstantEvaluator<'a> { local_expression_kind_tracker: &'a mut ExpressionKindTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, + is_const: bool, ) -> Self { + let local_data = FunctionLocalData { + global_expressions: &module.global_expressions, + emitter, + block, + }; Self { - behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData { - global_expressions: &module.global_expressions, - emitter, - block, - })), + behavior: Behavior::Wgsl(if is_const { + WgslRestrictions::Const(Some(local_data)) + } else { + WgslRestrictions::Runtime(local_data) + }), types: &mut module.types, constants: &module.constants, overrides: &module.overrides, @@ -718,6 +766,7 @@ impl<'a> ConstantEvaluator<'a> { span: Span, ) -> Result, ConstantEvaluatorError> { match self.expression_kind_tracker.type_of_with_expr(&expr) { + ExpressionKind::ImplConst => self.try_eval_and_append_impl(&expr, span), ExpressionKind::Const => { let eval_result = self.try_eval_and_append_impl(&expr, span); // We should be able to evaluate `Const` expressions at this @@ -740,7 +789,7 @@ impl<'a> ConstantEvaluator<'a> { Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => { Ok(self.append_expr(expr, span, ExpressionKind::Override)) } - Behavior::Wgsl(WgslRestrictions::Const) => { + Behavior::Wgsl(WgslRestrictions::Const(_)) => { Err(ConstantEvaluatorError::OverrideExpr) } Behavior::Glsl(_) => { @@ -761,14 +810,17 @@ impl<'a> ConstantEvaluator<'a> { const fn is_global_arena(&self) -> bool { matches!( self.behavior, - Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override) + Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override) | Behavior::Glsl(GlslRestrictions::Const) ) } const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> { match self.behavior { - Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data)) + Behavior::Wgsl( + WgslRestrictions::Runtime(ref function_local_data) + | WgslRestrictions::Const(Some(ref function_local_data)), + ) | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => { Some(function_local_data) } @@ -2057,7 +2109,10 @@ impl<'a> ConstantEvaluator<'a> { expr_type: ExpressionKind, ) -> Handle { let h = match self.behavior { - Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data)) + Behavior::Wgsl( + WgslRestrictions::Runtime(ref mut function_local_data) + | WgslRestrictions::Const(Some(ref mut function_local_data)), + ) | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => { let is_running = function_local_data.emitter.is_running(); let needs_pre_emit = expr.needs_pre_emit(); @@ -2480,7 +2535,7 @@ mod tests { let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl(WgslRestrictions::Const), + behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), types: &mut types, constants: &constants, overrides: &overrides, @@ -2566,7 +2621,7 @@ mod tests { let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl(WgslRestrictions::Const), + behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), types: &mut types, constants: &constants, overrides: &overrides, @@ -2684,7 +2739,7 @@ mod tests { let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl(WgslRestrictions::Const), + behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), types: &mut types, constants: &constants, overrides: &overrides, @@ -2777,7 +2832,7 @@ mod tests { let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl(WgslRestrictions::Const), + behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), types: &mut types, constants: &constants, overrides: &overrides, @@ -2859,7 +2914,7 @@ mod tests { let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl(WgslRestrictions::Const), + behavior: Behavior::Wgsl(WgslRestrictions::Const(None)), types: &mut types, constants: &constants, overrides: &overrides, diff --git a/naga/tests/in/local-const.param.ron b/naga/tests/in/local-const.param.ron new file mode 100644 index 000000000..dd626a0f3 --- /dev/null +++ b/naga/tests/in/local-const.param.ron @@ -0,0 +1 @@ +() \ No newline at end of file diff --git a/naga/tests/in/local-const.wgsl b/naga/tests/in/local-const.wgsl new file mode 100644 index 000000000..18c932e1e --- /dev/null +++ b/naga/tests/in/local-const.wgsl @@ -0,0 +1,26 @@ +const ga = 4; // AbstractInt with a value of 4. +const gb : i32 = 4; // i32 with a value of 4. +const gc : u32 = 4; // u32 with a value of 4. +const gd : f32 = 4; // f32 with a value of 4. +const ge = vec3(ga, ga, ga); // vec3 of AbstractInt with a value of (4, 4, 4). +const gf = 2.0; // AbstractFloat with a value of 2. + +fn const_in_fn() { + const a = 4; // AbstractInt with a value of 4. + const b: i32 = 4; // i32 with a value of 4. + const c: u32 = 4; // u32 with a value of 4. + const d: f32 = 4; // f32 with a value of 4. + const e = vec3(a, a, a); // vec3 of AbstractInt with a value of (4, 4, 4). + const f = 2.0; // AbstractFloat with a value of 2. + // TODO: Make it per spec, currently not possible + // because naga does not support automatic conversions + // of Abstract types + + // Check that we can access global constants + const ag = ga; + const bg = gb; + const cg = gc; + const dg = gd; + const eg = ge; + const fg = gf; +} diff --git a/naga/tests/out/ir/local-const.compact.ron b/naga/tests/out/ir/local-const.compact.ron new file mode 100644 index 000000000..a9b9f32af --- /dev/null +++ b/naga/tests/out/ir/local-const.compact.ron @@ -0,0 +1,139 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Sint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Sint, + width: 4, + ), + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: {}, + ), + constants: [ + ( + name: Some("ga"), + ty: 0, + init: 0, + ), + ( + name: Some("gb"), + ty: 0, + init: 1, + ), + ( + name: Some("gc"), + ty: 1, + init: 2, + ), + ( + name: Some("gd"), + ty: 2, + init: 3, + ), + ( + name: Some("ge"), + ty: 3, + init: 4, + ), + ( + name: Some("gf"), + ty: 2, + init: 5, + ), + ], + overrides: [], + global_variables: [], + global_expressions: [ + Literal(I32(4)), + Literal(I32(4)), + Literal(U32(4)), + Literal(F32(4.0)), + Compose( + ty: 3, + components: [ + 0, + 0, + 0, + ], + ), + Literal(F32(2.0)), + ], + functions: [ + ( + name: Some("const_in_fn"), + arguments: [], + result: None, + local_variables: [], + expressions: [ + Literal(I32(4)), + Literal(I32(4)), + Literal(U32(4)), + Literal(F32(4.0)), + Compose( + ty: 3, + components: [ + 0, + 0, + 0, + ], + ), + Literal(F32(2.0)), + Constant(0), + Constant(1), + Constant(2), + Constant(3), + Constant(4), + Constant(5), + ], + named_expressions: { + 0: "a", + 1: "b", + 2: "c", + 3: "d", + 4: "e", + 5: "f", + 6: "ag", + 7: "bg", + 8: "cg", + 9: "dg", + 10: "eg", + 11: "fg", + }, + body: [ + Emit(( + start: 4, + end: 5, + )), + ], + ), + ], + entry_points: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/local-const.ron b/naga/tests/out/ir/local-const.ron new file mode 100644 index 000000000..a9b9f32af --- /dev/null +++ b/naga/tests/out/ir/local-const.ron @@ -0,0 +1,139 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Sint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Sint, + width: 4, + ), + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: {}, + ), + constants: [ + ( + name: Some("ga"), + ty: 0, + init: 0, + ), + ( + name: Some("gb"), + ty: 0, + init: 1, + ), + ( + name: Some("gc"), + ty: 1, + init: 2, + ), + ( + name: Some("gd"), + ty: 2, + init: 3, + ), + ( + name: Some("ge"), + ty: 3, + init: 4, + ), + ( + name: Some("gf"), + ty: 2, + init: 5, + ), + ], + overrides: [], + global_variables: [], + global_expressions: [ + Literal(I32(4)), + Literal(I32(4)), + Literal(U32(4)), + Literal(F32(4.0)), + Compose( + ty: 3, + components: [ + 0, + 0, + 0, + ], + ), + Literal(F32(2.0)), + ], + functions: [ + ( + name: Some("const_in_fn"), + arguments: [], + result: None, + local_variables: [], + expressions: [ + Literal(I32(4)), + Literal(I32(4)), + Literal(U32(4)), + Literal(F32(4.0)), + Compose( + ty: 3, + components: [ + 0, + 0, + 0, + ], + ), + Literal(F32(2.0)), + Constant(0), + Constant(1), + Constant(2), + Constant(3), + Constant(4), + Constant(5), + ], + named_expressions: { + 0: "a", + 1: "b", + 2: "c", + 3: "d", + 4: "e", + 5: "f", + 6: "ag", + 7: "bg", + 8: "cg", + 9: "dg", + 10: "eg", + 11: "fg", + }, + body: [ + Emit(( + start: 4, + end: 5, + )), + ], + ), + ], + entry_points: [], +) \ No newline at end of file diff --git a/naga/tests/out/wgsl/binding-arrays.wgsl b/naga/tests/out/wgsl/binding-arrays.wgsl index 86bcfc1bf..5bed8ef00 100644 --- a/naga/tests/out/wgsl/binding-arrays.wgsl +++ b/naga/tests/out/wgsl/binding-arrays.wgsl @@ -34,8 +34,8 @@ fn main(fragment_in: FragmentIn) -> @location(0) vec4 { let uniform_index = uni.index; let non_uniform_index = fragment_in.index; - let uv = vec2(0f); - let pix = vec2(0i); + const uv = vec2(0f); + const pix = vec2(0i); let _e21 = textureDimensions(texture_array_unbounded[0]); let _e22 = u2_; u2_ = (_e22 + _e21); diff --git a/naga/tests/out/wgsl/constructors.wgsl b/naga/tests/out/wgsl/constructors.wgsl index 0e5eec734..6d9d7e2f5 100644 --- a/naga/tests/out/wgsl/constructors.wgsl +++ b/naga/tests/out/wgsl/constructors.wgsl @@ -21,11 +21,11 @@ fn main() { var foo: Foo; foo = Foo(vec4(1f), 1i); - let m0_ = mat2x2(vec2(1f, 0f), vec2(0f, 1f)); - let m1_ = mat4x4(vec4(1f, 0f, 0f, 0f), vec4(0f, 1f, 0f, 0f), vec4(0f, 0f, 1f, 0f), vec4(0f, 0f, 0f, 1f)); - let cit0_ = vec2(0u); - let cit1_ = mat2x2(vec2(0f), vec2(0f)); - let cit2_ = array(0i, 1i, 2i, 3i); - let ic4_ = vec2(0u, 0u); - let ic5_ = mat2x3(vec3(0f, 0f, 0f), vec3(0f, 0f, 0f)); + const m0_ = mat2x2(vec2(1f, 0f), vec2(0f, 1f)); + const m1_ = mat4x4(vec4(1f, 0f, 0f, 0f), vec4(0f, 1f, 0f, 0f), vec4(0f, 0f, 1f, 0f), vec4(0f, 0f, 0f, 1f)); + const cit0_ = vec2(0u); + const cit1_ = mat2x2(vec2(0f), vec2(0f)); + const cit2_ = array(0i, 1i, 2i, 3i); + const ic4_ = vec2(0u, 0u); + const ic5_ = mat2x3(vec3(0f, 0f, 0f), vec3(0f, 0f, 0f)); } diff --git a/naga/tests/out/wgsl/expressions.frag.wgsl b/naga/tests/out/wgsl/expressions.frag.wgsl index 0ba5962ab..ec53847d5 100644 --- a/naga/tests/out/wgsl/expressions.frag.wgsl +++ b/naga/tests/out/wgsl/expressions.frag.wgsl @@ -268,12 +268,12 @@ fn testUnaryOpMat(a_16: mat3x3) { let _e3 = a_17; v_8 = -(_e3); let _e5 = a_17; - let _e7 = vec3(1f); + const _e7 = vec3(1f); let _e9 = (_e5 - mat3x3(_e7, _e7, _e7)); a_17 = _e9; v_8 = _e9; let _e10 = a_17; - let _e12 = vec3(1f); + const _e12 = vec3(1f); a_17 = (_e10 - mat3x3(_e12, _e12, _e12)); v_8 = _e10; return; diff --git a/naga/tests/out/wgsl/functions.wgsl b/naga/tests/out/wgsl/functions.wgsl index 79f000ce2..db7b81b14 100644 --- a/naga/tests/out/wgsl/functions.wgsl +++ b/naga/tests/out/wgsl/functions.wgsl @@ -1,16 +1,16 @@ fn test_fma() -> vec2 { - let a = vec2(2f, 2f); - let b = vec2(0.5f, 0.5f); - let c = vec2(0.5f, 0.5f); + const a = vec2(2f, 2f); + const b = vec2(0.5f, 0.5f); + const c = vec2(0.5f, 0.5f); return fma(a, b, c); } fn test_integer_dot_product() -> i32 { - let a_2_ = vec2(1i); - let b_2_ = vec2(1i); + const a_2_ = vec2(1i); + const b_2_ = vec2(1i); let c_2_ = dot(a_2_, b_2_); - let a_3_ = vec3(1u); - let b_3_ = vec3(1u); + const a_3_ = vec3(1u); + const b_3_ = vec3(1u); let c_3_ = dot(a_3_, b_3_); let c_4_ = dot(vec4(4i), vec4(2i)); return c_4_; diff --git a/naga/tests/out/wgsl/image.wgsl b/naga/tests/out/wgsl/image.wgsl index 008b4c20c..a680e70ab 100644 --- a/naga/tests/out/wgsl/image.wgsl +++ b/naga/tests/out/wgsl/image.wgsl @@ -110,8 +110,8 @@ fn levels_queries() -> @builtin(position) vec4 { fn texture_sample() -> @location(0) vec4 { var a: vec4; - let tc = vec2(0.5f); - let tc3_ = vec3(0.5f); + const tc = vec2(0.5f); + const tc3_ = vec3(0.5f); let _e9 = textureSample(image_1d, sampler_reg, tc.x); let _e10 = a; a = (_e10 + _e9); @@ -186,8 +186,8 @@ fn texture_sample() -> @location(0) vec4 { fn texture_sample_comparison() -> @location(0) f32 { var a_1: f32; - let tc_1 = vec2(0.5f); - let tc3_1 = vec3(0.5f); + const tc_1 = vec2(0.5f); + const tc3_1 = vec3(0.5f); let _e8 = textureSampleCompare(image_2d_depth, sampler_cmp, tc_1, 0.5f); let _e9 = a_1; a_1 = (_e9 + _e8); @@ -218,7 +218,7 @@ fn texture_sample_comparison() -> @location(0) f32 { @fragment fn gather() -> @location(0) vec4 { - let tc_2 = vec2(0.5f); + const tc_2 = vec2(0.5f); let s2d = textureGather(1, image_2d, sampler_reg, tc_2); let s2d_offset = textureGather(3, image_2d, sampler_reg, tc_2, vec2(3i, 1i)); let s2d_depth = textureGatherCompare(image_2d_depth, sampler_cmp, tc_2, 0.5f); @@ -231,7 +231,7 @@ fn gather() -> @location(0) vec4 { @fragment fn depth_no_comparison() -> @location(0) vec4 { - let tc_3 = vec2(0.5f); + const tc_3 = vec2(0.5f); let s2d_1 = textureSample(image_2d_depth, sampler_reg, tc_3); let s2d_gather = textureGather(image_2d_depth, sampler_reg, tc_3); return (vec4(s2d_1) + s2d_gather); diff --git a/naga/tests/out/wgsl/local-const.wgsl b/naga/tests/out/wgsl/local-const.wgsl new file mode 100644 index 000000000..587f5a8e5 --- /dev/null +++ b/naga/tests/out/wgsl/local-const.wgsl @@ -0,0 +1,11 @@ +const ga: i32 = 4i; +const gb: i32 = 4i; +const gc: u32 = 4u; +const gd: f32 = 4f; +const ge: vec3 = vec3(4i, 4i, 4i); +const gf: f32 = 2f; + +fn const_in_fn() { + const e = vec3(4i, 4i, 4i); +} + diff --git a/naga/tests/out/wgsl/math-functions.wgsl b/naga/tests/out/wgsl/math-functions.wgsl index 2271bb9cb..732f7acdc 100644 --- a/naga/tests/out/wgsl/math-functions.wgsl +++ b/naga/tests/out/wgsl/math-functions.wgsl @@ -1,25 +1,25 @@ @fragment fn main() { - let v = vec4(0f); + const v = vec4(0f); let a = degrees(1f); let b = radians(1f); let c = degrees(v); let d = radians(v); let e = saturate(v); let g = refract(v, v, 1f); - let sign_b = vec4(-1i, -1i, -1i, -1i); - let sign_d = vec4(-1f, -1f, -1f, -1f); + const sign_b = vec4(-1i, -1i, -1i, -1i); + const sign_d = vec4(-1f, -1f, -1f, -1f); let const_dot = dot(vec2(), vec2()); - let flb_b = vec2(-1i, -1i); - let flb_c = vec2(0u, 0u); - let ftb_c = vec2(0i, 0i); - let ftb_d = vec2(0u, 0u); - let ctz_e = vec2(32u, 32u); - let ctz_f = vec2(32i, 32i); - let ctz_g = vec2(0u, 0u); - let ctz_h = vec2(0i, 0i); - let clz_c = vec2(0i, 0i); - let clz_d = vec2(31u, 31u); + const flb_b = vec2(-1i, -1i); + const flb_c = vec2(0u, 0u); + const ftb_c = vec2(0i, 0i); + const ftb_d = vec2(0u, 0u); + const ctz_e = vec2(32u, 32u); + const ctz_f = vec2(32i, 32i); + const ctz_g = vec2(0u, 0u); + const ctz_h = vec2(0i, 0i); + const clz_c = vec2(0i, 0i); + const clz_d = vec2(31u, 31u); let lde_a = ldexp(1f, 2i); let lde_b = ldexp(vec2(1f, 2f), vec2(3i, 4i)); let modf_a = modf(1.5f); diff --git a/naga/tests/out/wgsl/operators.wgsl b/naga/tests/out/wgsl/operators.wgsl index dbf39556d..2194a01df 100644 --- a/naga/tests/out/wgsl/operators.wgsl +++ b/naga/tests/out/wgsl/operators.wgsl @@ -11,7 +11,7 @@ fn builtins() -> vec4 { let m2_ = mix(v_f32_zero, v_f32_one, 0.1f); let b1_ = bitcast(1i); let b2_ = bitcast>(v_i32_one); - let v_i32_zero = vec4(0i, 0i, 0i, 0i); + const v_i32_zero = vec4(0i, 0i, 0i, 0i); return (((((vec4((vec4(s1_) + v_i32_zero)) + s2_) + m1_) + m2_) + vec4(b1_)) + b2_); } @@ -40,8 +40,8 @@ fn bool_cast(x: vec3) -> vec3 { } fn logical() { - let neg0_ = !(true); - let neg1_ = !(vec2(true)); + const neg0_ = !(true); + const neg1_ = !(vec2(true)); let or = (true || false); let and = (true && false); let bitwise_or0_ = (true | false); @@ -51,9 +51,9 @@ fn logical() { } fn arithmetic() { - let neg0_1 = -(1f); - let neg1_1 = -(vec2(1i)); - let neg2_ = -(vec2(1f)); + const neg0_1 = -(1f); + const neg1_1 = -(vec2(1i)); + const neg2_ = -(vec2(1f)); let add0_ = (2i + 1i); let add1_ = (2u + 1u); let add2_ = (2f + 1f); @@ -126,10 +126,10 @@ fn arithmetic() { } fn bit() { - let flip0_ = ~(1i); - let flip1_ = ~(1u); - let flip2_ = ~(vec2(1i)); - let flip3_ = ~(vec3(1u)); + const flip0_ = ~(1i); + const flip1_ = ~(1u); + const flip2_ = ~(vec2(1i)); + const flip3_ = ~(vec3(1u)); let or0_ = (2i | 1i); let or1_ = (2u | 1u); let or2_ = (vec2(2i) | vec2(1i)); @@ -230,14 +230,14 @@ fn assignment() { } fn negation_avoids_prefix_decrement() { - let p0_ = -(1i); - let p1_ = -(-(1i)); - let p2_ = -(-(1i)); - let p3_ = -(-(1i)); - let p4_ = -(-(-(1i))); - let p5_ = -(-(-(-(1i)))); - let p6_ = -(-(-(-(-(1i))))); - let p7_ = -(-(-(-(-(1i))))); + const p0_ = -(1i); + const p1_ = -(-(1i)); + const p2_ = -(-(1i)); + const p3_ = -(-(1i)); + const p4_ = -(-(-(1i))); + const p5_ = -(-(-(-(1i)))); + const p6_ = -(-(-(-(-(1i))))); + const p7_ = -(-(-(-(-(1i))))); } @compute @workgroup_size(1, 1, 1) diff --git a/naga/tests/out/wgsl/prepostfix.frag.wgsl b/naga/tests/out/wgsl/prepostfix.frag.wgsl index d2c59a0dd..15916303b 100644 --- a/naga/tests/out/wgsl/prepostfix.frag.wgsl +++ b/naga/tests/out/wgsl/prepostfix.frag.wgsl @@ -21,11 +21,11 @@ fn main_1() { vec = _e21; vec_target = _e21; let _e32 = mat; - let _e34 = vec3(1f); + const _e34 = vec3(1f); mat = (_e32 + mat4x3(_e34, _e34, _e34, _e34)); mat_target = _e32; let _e37 = mat; - let _e39 = vec3(1f); + const _e39 = vec3(1f); let _e41 = (_e37 - mat4x3(_e39, _e39, _e39, _e39)); mat = _e41; mat_target = _e41; diff --git a/naga/tests/out/wgsl/shadow.wgsl b/naga/tests/out/wgsl/shadow.wgsl index e9d5bbf1b..8b198d2ed 100644 --- a/naga/tests/out/wgsl/shadow.wgsl +++ b/naga/tests/out/wgsl/shadow.wgsl @@ -40,7 +40,7 @@ fn fetch_shadow(light_id: u32, homogeneous_coords: vec4) -> f32 { if (homogeneous_coords.w <= 0f) { return 1f; } - let flip_correction = vec2(0.5f, -0.5f); + const flip_correction = vec2(0.5f, -0.5f); let proj_correction = (1f / homogeneous_coords.w); let light_local = (((homogeneous_coords.xy * flip_correction) * proj_correction) + vec2(0.5f, 0.5f)); let _e24 = textureSampleCompareLevel(t_shadow, sampler_shadow, light_local, i32(light_id), (homogeneous_coords.z * proj_correction)); diff --git a/naga/tests/out/wgsl/type-alias.wgsl b/naga/tests/out/wgsl/type-alias.wgsl index fe3cf7903..13bfcba82 100644 --- a/naga/tests/out/wgsl/type-alias.wgsl +++ b/naga/tests/out/wgsl/type-alias.wgsl @@ -1,10 +1,10 @@ fn main() { - let a = vec3(0f, 0f, 0f); - let c = vec3(0f); - let b = vec3(vec2(0f), 0f); - let d = vec3(vec2(0f), 0f); - let e = vec3(d); - let f = mat2x2(vec2(1f, 2f), vec2(3f, 4f)); - let g = mat3x3(a, a, a); + const a = vec3(0f, 0f, 0f); + const c = vec3(0f); + const b = vec3(vec2(0f), 0f); + const d = vec3(vec2(0f), 0f); + const e = vec3(d); + const f = mat2x2(vec2(1f, 2f), vec2(3f, 4f)); + const g = mat3x3(a, a, a); } diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 6395923ac..56e3a95f7 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -818,6 +818,7 @@ fn convert_wgsl() { "use-gl-ext-over-grad-workaround-if-instructed", Targets::GLSL, ), + ("local-const", Targets::IR | Targets::WGSL), ( "math-functions", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, diff --git a/naga/tests/wgsl_errors.rs b/naga/tests/wgsl_errors.rs index d6d1710f7..76f3d5668 100644 --- a/naga/tests/wgsl_errors.rs +++ b/naga/tests/wgsl_errors.rs @@ -2277,3 +2277,97 @@ fn too_many_unclosed_loops() { .join() .unwrap() } + +#[test] +fn local_const_wrong_type() { + check( + " + fn f() { + const c: i32 = 5u; + } + ", + r###"error: the type of `c` is expected to be `i32`, but got `u32` + ┌─ wgsl:3:19 + │ +3 │ const c: i32 = 5u; + │ ^ definition of `c` + +"###, + ); +} + +#[test] +fn local_const_from_let() { + check( + " + fn f() { + let a = 5; + const c = a; + } + ", + r###"error: this operation is not supported in a const context + ┌─ wgsl:4:23 + │ +4 │ const c = a; + │ ^ operation not supported here + +"###, + ); +} + +#[test] +fn local_const_from_var() { + check( + " + fn f() { + var a = 5; + const c = a; + } + ", + r###"error: this operation is not supported in a const context + ┌─ wgsl:4:23 + │ +4 │ const c = a; + │ ^ operation not supported here + +"###, + ); +} + +#[test] +fn local_const_from_override() { + check( + " + override o: i32; + fn f() { + const c = o; + } + ", + r###"error: Unexpected override-expression + ┌─ wgsl:4:23 + │ +4 │ const c = o; + │ ^ see msg + +"###, + ); +} + +#[test] +fn local_const_from_global_var() { + check( + " + var v: i32; + fn f() { + const c = v; + } + ", + r###"error: Unexpected runtime-expression + ┌─ wgsl:4:23 + │ +4 │ const c = v; + │ ^ see msg + +"###, + ); +}