[naga wgsl] Implement local const declarations (#6156)

This commit is contained in:
Samson 2024-08-30 11:55:03 +02:00 committed by GitHub
parent 4454cbfaab
commit 34bb9e4ceb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 769 additions and 127 deletions

View File

@ -47,7 +47,7 @@ pub use features::Features;
use crate::{
back::{self, Baked},
proc::{self, NameKey},
proc::{self, ExpressionKindTracker, NameKey},
valid, Handle, ShaderStage, TypeInner,
};
use features::FeaturesManager;
@ -1584,6 +1584,7 @@ impl<'a, W: Write> Writer<'a, W> {
info,
expressions: &func.expressions,
named_expressions: &func.named_expressions,
expr_kind_tracker: ExpressionKindTracker::from_arena(&func.expressions),
};
self.named_expressions.clear();

View File

@ -8,7 +8,7 @@ use super::{
};
use crate::{
back::{self, Baked},
proc::{self, NameKey},
proc::{self, ExpressionKindTracker, NameKey},
valid, Handle, Module, Scalar, ScalarKind, ShaderStage, TypeInner,
};
use std::{fmt, mem};
@ -346,6 +346,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
info,
expressions: &function.expressions,
named_expressions: &function.named_expressions,
expr_kind_tracker: ExpressionKindTracker::from_arena(&function.expressions),
};
let name = self.names[&NameKey::Function(handle)].clone();
@ -386,6 +387,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
info,
expressions: &ep.function.expressions,
named_expressions: &ep.function.named_expressions,
expr_kind_tracker: ExpressionKindTracker::from_arena(&ep.function.expressions),
};
self.write_wrapped_functions(module, &ctx)?;

View File

@ -3,6 +3,8 @@ Backend functions that export shader [`Module`](super::Module)s into binary and
*/
#![allow(dead_code)] // can be dead if none of the enabled backends need it
use crate::proc::ExpressionKindTracker;
#[cfg(dot_out)]
pub mod dot;
#[cfg(glsl_out)]
@ -118,6 +120,8 @@ pub struct FunctionCtx<'a> {
pub expressions: &'a crate::Arena<crate::Expression>,
/// Map of expressions that have associated variable names
pub named_expressions: &'a crate::NamedExpressions,
/// For constness checks
pub expr_kind_tracker: ExpressionKindTracker,
}
impl FunctionCtx<'_> {

View File

@ -289,6 +289,7 @@ fn process_function(
&mut local_expression_kind_tracker,
&mut emitter,
&mut block,
false,
);
for (old_h, mut expr, span) in expressions.drain() {

View File

@ -1,7 +1,7 @@
use super::Error;
use crate::{
back::{self, Baked},
proc::{self, NameKey},
proc::{self, ExpressionKindTracker, NameKey},
valid, Handle, Module, ShaderStage, TypeInner,
};
use std::fmt::Write;
@ -166,6 +166,7 @@ impl<W: Write> Writer<W> {
info: fun_info,
expressions: &function.expressions,
named_expressions: &function.named_expressions,
expr_kind_tracker: ExpressionKindTracker::from_arena(&function.expressions),
};
// Write the function
@ -193,6 +194,7 @@ impl<W: Write> Writer<W> {
info: info.get_entry_point(index),
expressions: &ep.function.expressions,
named_expressions: &ep.function.named_expressions,
expr_kind_tracker: ExpressionKindTracker::from_arena(&ep.function.expressions),
};
self.write_function(module, &ep.function, &func_ctx)?;
@ -1115,8 +1117,14 @@ impl<W: Write> Writer<W> {
func_ctx: &back::FunctionCtx,
name: &str,
) -> BackendResult {
// Some functions are marked as const, but are not yet implemented as constant expression
let quantifier = if func_ctx.expr_kind_tracker.is_impl_const(handle) {
"const"
} else {
"let"
};
// Write variable name
write!(self.out, "let {name}")?;
write!(self.out, "{quantifier} {name}")?;
if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
write!(self.out, ": ")?;
let ty = &func_ctx.info[handle].ty;

View File

@ -98,7 +98,7 @@ impl<'source> GlobalContext<'source, '_, '_> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
expr_type: ExpressionContextType::Constant,
expr_type: ExpressionContextType::Constant(None),
global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@ -160,7 +160,8 @@ pub struct StatementContext<'source, 'temp, 'out> {
///
/// [`LocalVariable`]: crate::Expression::LocalVariable
/// [`FunctionArgument`]: crate::Expression::FunctionArgument
local_table: &'temp mut FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>,
local_table:
&'temp mut FastHashMap<Handle<ast::Local>, Declared<Typed<Handle<crate::Expression>>>>,
const_typifier: &'temp mut Typifier,
typifier: &'temp mut Typifier,
@ -184,6 +185,32 @@ pub struct StatementContext<'source, 'temp, 'out> {
}
impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
fn as_const<'t>(
&'t mut self,
block: &'t mut crate::Block,
emitter: &'t mut Emitter,
) -> ExpressionContext<'a, 't, '_>
where
'temp: 't,
{
ExpressionContext {
globals: self.globals,
types: self.types,
ast_expressions: self.ast_expressions,
const_typifier: self.const_typifier,
global_expression_kind_tracker: self.global_expression_kind_tracker,
module: self.module,
expr_type: ExpressionContextType::Constant(Some(LocalExpressionContext {
local_table: self.local_table,
function: self.function,
block,
emitter,
typifier: self.typifier,
local_expression_kind_tracker: self.local_expression_kind_tracker,
})),
}
}
fn as_expression<'t>(
&'t mut self,
block: &'t mut crate::Block,
@ -199,7 +226,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
const_typifier: self.const_typifier,
global_expression_kind_tracker: self.global_expression_kind_tracker,
module: self.module,
expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext {
expr_type: ExpressionContextType::Runtime(LocalExpressionContext {
local_table: self.local_table,
function: self.function,
block,
@ -235,12 +262,12 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
}
}
pub struct RuntimeExpressionContext<'temp, 'out> {
pub struct LocalExpressionContext<'temp, 'out> {
/// A map from [`ast::Local`] handles to the Naga expressions we've built for them.
///
/// This is always [`StatementContext::local_table`] for the
/// enclosing statement; see that documentation for details.
local_table: &'temp FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>,
local_table: &'temp FastHashMap<Handle<ast::Local>, Declared<Typed<Handle<crate::Expression>>>>,
function: &'out mut crate::Function,
block: &'temp mut crate::Block,
@ -262,15 +289,15 @@ pub enum ExpressionContextType<'temp, 'out> {
/// The given [`RuntimeExpressionContext`] holds information about local
/// variables, arguments, and other definitions available only to runtime
/// expressions, not constant or override expressions.
Runtime(RuntimeExpressionContext<'temp, 'out>),
Runtime(LocalExpressionContext<'temp, 'out>),
/// We are lowering to a constant expression, to be included in the module's
/// constant expression arena.
///
/// Everything constant expressions are allowed to refer to is
/// available in the [`ExpressionContext`], so this variant
/// carries no further information.
Constant,
/// Everything global constant expressions are allowed to refer to is
/// available in the [`ExpressionContext`], but local constant expressions can
/// also refer to other
Constant(Option<LocalExpressionContext<'temp, 'out>>),
/// We are lowering to an override expression, to be included in the module's
/// constant expression arena.
@ -352,7 +379,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
ast_expressions: self.ast_expressions,
const_typifier: self.const_typifier,
module: self.module,
expr_type: ExpressionContextType::Constant,
expr_type: ExpressionContextType::Constant(None),
global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@ -376,8 +403,19 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
rctx.local_expression_kind_tracker,
rctx.emitter,
rctx.block,
false,
),
ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(
ExpressionContextType::Constant(Some(ref mut rctx)) => {
ConstantEvaluator::for_wgsl_function(
self.module,
&mut rctx.function.expressions,
rctx.local_expression_kind_tracker,
rctx.emitter,
rctx.block,
true,
)
}
ExpressionContextType::Constant(None) => ConstantEvaluator::for_wgsl_module(
self.module,
self.global_expression_kind_tracker,
false,
@ -412,15 +450,27 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
.eval_expr_to_u32_from(handle, &ctx.function.expressions)
.ok()
}
ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(),
ExpressionContextType::Constant(Some(ref ctx)) => {
assert!(ctx.local_expression_kind_tracker.is_const(handle));
self.module
.to_ctx()
.eval_expr_to_u32_from(handle, &ctx.function.expressions)
.ok()
}
ExpressionContextType::Constant(None) => {
self.module.to_ctx().eval_expr_to_u32(handle).ok()
}
ExpressionContextType::Override => None,
}
}
fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle),
ExpressionContextType::Constant | ExpressionContextType::Override => {
ExpressionContextType::Runtime(ref ctx)
| ExpressionContextType::Constant(Some(ref ctx)) => {
ctx.function.expressions.get_span(handle)
}
ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
self.module.global_expressions.get_span(handle)
}
}
@ -428,20 +478,35 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
fn typifier(&self) -> &Typifier {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
ExpressionContextType::Constant | ExpressionContextType::Override => {
ExpressionContextType::Runtime(ref ctx)
| ExpressionContextType::Constant(Some(ref ctx)) => ctx.typifier,
ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
self.const_typifier
}
}
}
fn local(
&mut self,
local: &Handle<ast::Local>,
span: Span,
) -> Result<Typed<Handle<crate::Expression>>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => Ok(ctx.local_table[local].runtime()),
ExpressionContextType::Constant(Some(ref ctx)) => ctx.local_table[local]
.const_time()
.ok_or(Error::UnexpectedOperationInConstContext(span)),
_ => Err(Error::UnexpectedOperationInConstContext(span)),
}
}
fn runtime_expression_ctx(
&mut self,
span: Span,
) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> {
) -> Result<&mut LocalExpressionContext<'temp, 'out>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx),
ExpressionContextType::Constant | ExpressionContextType::Override => {
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
Err(Error::UnexpectedOperationInConstContext(span))
}
}
@ -480,7 +545,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
// This means a `gather` operation appeared in a constant expression.
// This error refers to the `gather` itself, not its "component" argument.
ExpressionContextType::Constant | ExpressionContextType::Override => {
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
Err(Error::UnexpectedOperationInConstContext(gather_span))
}
}
@ -505,8 +570,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
// except that this lets the borrow checker see that it's okay
// to also borrow self.module.types mutably below.
let typifier = match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
ExpressionContextType::Constant | ExpressionContextType::Override => {
ExpressionContextType::Runtime(ref ctx)
| ExpressionContextType::Constant(Some(ref ctx)) => ctx.typifier,
ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
&*self.const_typifier
}
};
@ -542,7 +608,8 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
let typifier;
let expressions;
match self.expr_type {
ExpressionContextType::Runtime(ref mut ctx) => {
ExpressionContextType::Runtime(ref mut ctx)
| ExpressionContextType::Constant(Some(ref mut ctx)) => {
resolve_ctx = ResolveContext::with_locals(
self.module,
&ctx.function.local_variables,
@ -551,7 +618,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
typifier = &mut *ctx.typifier;
expressions = &ctx.function.expressions;
}
ExpressionContextType::Constant | ExpressionContextType::Override => {
ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]);
typifier = self.const_typifier;
expressions = &self.module.global_expressions;
@ -643,18 +710,20 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
span: Span,
) -> Result<Handle<crate::Expression>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => {
ExpressionContextType::Runtime(ref mut rctx)
| ExpressionContextType::Constant(Some(ref mut rctx)) => {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
}
ExpressionContextType::Constant | ExpressionContextType::Override => {}
ExpressionContextType::Constant(None) | ExpressionContextType::Override => {}
}
let result = self.append_expression(expression, span);
match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => {
ExpressionContextType::Runtime(ref mut rctx)
| ExpressionContextType::Constant(Some(ref mut rctx)) => {
rctx.emitter.start(&rctx.function.expressions);
}
ExpressionContextType::Constant | ExpressionContextType::Override => {}
ExpressionContextType::Constant(None) | ExpressionContextType::Override => {}
}
result
}
@ -718,6 +787,30 @@ impl<'source> ArgumentContext<'_, 'source> {
}
}
#[derive(Debug, Copy, Clone)]
enum Declared<T> {
/// Value declared as const
Const(T),
/// Value declared as non-const
Runtime(T),
}
impl<T> Declared<T> {
fn runtime(self) -> T {
match self {
Declared::Const(t) | Declared::Runtime(t) => t,
}
}
fn const_time(self) -> Option<T> {
match self {
Declared::Const(t) => Some(t),
Declared::Runtime(_) => None,
}
}
}
/// WGSL type annotations on expressions, types, values, etc.
///
/// Naga and WGSL types are very close, but Naga lacks WGSL's `ref` types, which
@ -1120,7 +1213,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let ty = self.resolve_ast_type(arg.ty, ctx)?;
let expr = expressions
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
local_table.insert(arg.handle, Typed::Plain(expr));
local_table.insert(arg.handle, Declared::Runtime(Typed::Plain(expr)));
named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime);
@ -1268,7 +1361,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
block.extend(emitter.finish(&ctx.function.expressions));
ctx.local_table.insert(l.handle, Typed::Plain(value));
ctx.local_table
.insert(l.handle, Declared::Runtime(Typed::Plain(value)));
ctx.named_expressions
.insert(value, (l.name.name.to_string(), l.name.span));
@ -1350,7 +1444,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Span::UNDEFINED,
)?;
block.extend(emitter.finish(&ctx.function.expressions));
ctx.local_table.insert(v.handle, Typed::Reference(handle));
ctx.local_table
.insert(v.handle, Declared::Runtime(Typed::Reference(handle)));
match initializer {
Some(initializer) => crate::Statement::Store {
@ -1360,6 +1455,41 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
None => return Ok(()),
}
}
ast::LocalDecl::Const(ref c) => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);
let ectx = &mut ctx.as_const(block, &mut emitter);
let mut init = self.expression_for_abstract(c.init, ectx)?;
if let Some(explicit_ty) = c.ty {
let explicit_ty =
self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?;
let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty);
init = ectx
.try_automatic_conversions(init, &explicit_ty_res, c.name.span)
.map_err(|error| match error {
Error::AutoConversion(error) => Error::InitializationTypeMismatch {
name: c.name.span,
expected: error.dest_type,
got: error.source_type,
},
other => other,
})?;
} else {
init = ectx.concretize(init)?;
ectx.register_type(init)?;
}
block.extend(emitter.finish(&ctx.function.expressions));
ctx.local_table
.insert(c.handle, Declared::Const(Typed::Plain(init)));
ctx.named_expressions
.insert(init, (c.name.name.to_string(), c.name.span));
return Ok(());
}
},
ast::StatementKind::If {
condition,
@ -1658,8 +1788,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(Typed::Plain(handle));
}
ast::Expression::Ident(ast::IdentExpr::Local(local)) => {
let rctx = ctx.runtime_expression_ctx(span)?;
return Ok(rctx.local_table[&local]);
return ctx.local(&local, span);
}
ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => {
let global = ctx

View File

@ -109,7 +109,7 @@ pub struct EntryPoint<'a> {
}
#[cfg(doc)]
use crate::front::wgsl::lower::{RuntimeExpressionContext, StatementContext};
use crate::front::wgsl::lower::{LocalExpressionContext, StatementContext};
#[derive(Debug)]
pub struct Function<'a> {
@ -460,10 +460,19 @@ pub struct Let<'a> {
pub handle: Handle<Local>,
}
#[derive(Debug)]
pub struct LocalConst<'a> {
pub name: Ident<'a>,
pub ty: Option<Handle<Type<'a>>>,
pub init: Handle<Expression<'a>>,
pub handle: Handle<Local>,
}
#[derive(Debug)]
pub enum LocalDecl<'a> {
Var(LocalVariable<'a>),
Let(Let<'a>),
Const(LocalConst<'a>),
}
#[derive(Debug)]

View File

@ -1688,6 +1688,28 @@ impl Parser {
handle,
}))
}
"const" => {
let _ = lexer.next();
let name = lexer.next_ident()?;
let given_ty = if lexer.skip(Token::Separator(':')) {
let ty = self.type_decl(lexer, ctx)?;
Some(ty)
} else {
None
};
lexer.expect(Token::Operation('='))?;
let expr_id = self.general_expression(lexer, ctx)?;
lexer.expect(Token::Separator(';'))?;
let handle = ctx.declare_local(name)?;
ast::StatementKind::LocalDecl(ast::LocalDecl::Const(ast::LocalConst {
name,
ty: given_ty,
init: expr_id,
handle,
}))
}
"var" => {
let _ = lexer.next();

View File

@ -317,7 +317,7 @@ pub struct ConstantEvaluator<'a> {
#[derive(Debug)]
enum WgslRestrictions<'a> {
/// - const-expressions will be evaluated and inserted in the arena
Const,
Const(Option<FunctionLocalData<'a>>),
/// - const-expressions will be evaluated and inserted in the arena
/// - override-expressions will be inserted in the arena
Override,
@ -347,6 +347,8 @@ struct FunctionLocalData<'a> {
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
/// If const is also implemented as const
ImplConst,
Const,
Override,
Runtime,
@ -372,14 +374,23 @@ impl ExpressionKindTracker {
pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
self.inner.insert(value, expr_type);
}
pub fn is_const(&self, h: Handle<Expression>) -> bool {
matches!(self.type_of(h), ExpressionKind::Const)
matches!(
self.type_of(h),
ExpressionKind::Const | ExpressionKind::ImplConst
)
}
/// Returns `true` if naga can also evaluate expression as const
pub fn is_impl_const(&self, h: Handle<Expression>) -> bool {
matches!(self.type_of(h), ExpressionKind::ImplConst)
}
pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
matches!(
self.type_of(h),
ExpressionKind::Const | ExpressionKind::Override
ExpressionKind::Const | ExpressionKind::Override | ExpressionKind::ImplConst
)
}
@ -400,13 +411,14 @@ impl ExpressionKindTracker {
}
fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
use crate::MathFunction as Mf;
match *expr {
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
ExpressionKind::Const
ExpressionKind::ImplConst
}
Expression::Override(_) => ExpressionKind::Override,
Expression::Compose { ref components, .. } => {
let mut expr_type = ExpressionKind::Const;
let mut expr_type = ExpressionKind::ImplConst;
for component in components {
expr_type = expr_type.max(self.type_of(*component))
}
@ -417,13 +429,16 @@ impl ExpressionKindTracker {
Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
Expression::Swizzle { vector, .. } => self.type_of(vector),
Expression::Unary { expr, .. } => self.type_of(expr),
Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
Expression::Binary { left, right, .. } => self
.type_of(left)
.max(self.type_of(right))
.max(ExpressionKind::Const),
Expression::Math {
fun,
arg,
arg1,
arg2,
arg3,
..
} => self
.type_of(arg)
.max(
@ -437,8 +452,34 @@ impl ExpressionKindTracker {
.max(
arg3.map(|arg| self.type_of(arg))
.unwrap_or(ExpressionKind::Const),
)
.max(
if matches!(
fun,
Mf::Dot
| Mf::Outer
| Mf::Cross
| Mf::Distance
| Mf::Length
| Mf::Normalize
| Mf::FaceForward
| Mf::Reflect
| Mf::Refract
| Mf::Ldexp
| Mf::Modf
| Mf::Mix
| Mf::Frexp
) {
ExpressionKind::Const
} else {
ExpressionKind::ImplConst
},
),
Expression::As { expr, .. } => self.type_of(expr),
Expression::As { convert, expr, .. } => self.type_of(expr).max(if convert.is_some() {
ExpressionKind::ImplConst
} else {
ExpressionKind::Const
}),
Expression::Select {
condition,
accept,
@ -446,7 +487,8 @@ impl ExpressionKindTracker {
} => self
.type_of(condition)
.max(self.type_of(accept))
.max(self.type_of(reject)),
.max(self.type_of(reject))
.max(ExpressionKind::Const),
Expression::Relational { argument, .. } => self.type_of(argument),
Expression::ArrayLength(expr) => self.type_of(expr),
_ => ExpressionKind::Runtime,
@ -556,7 +598,7 @@ impl<'a> ConstantEvaluator<'a> {
Behavior::Wgsl(if in_override_ctx {
WgslRestrictions::Override
} else {
WgslRestrictions::Const
WgslRestrictions::Const(None)
}),
module,
global_expression_kind_tracker,
@ -603,13 +645,19 @@ impl<'a> ConstantEvaluator<'a> {
local_expression_kind_tracker: &'a mut ExpressionKindTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
is_const: bool,
) -> Self {
let local_data = FunctionLocalData {
global_expressions: &module.global_expressions,
emitter,
block,
};
Self {
behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData {
global_expressions: &module.global_expressions,
emitter,
block,
})),
behavior: Behavior::Wgsl(if is_const {
WgslRestrictions::Const(Some(local_data))
} else {
WgslRestrictions::Runtime(local_data)
}),
types: &mut module.types,
constants: &module.constants,
overrides: &module.overrides,
@ -718,6 +766,7 @@ impl<'a> ConstantEvaluator<'a> {
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
match self.expression_kind_tracker.type_of_with_expr(&expr) {
ExpressionKind::ImplConst => self.try_eval_and_append_impl(&expr, span),
ExpressionKind::Const => {
let eval_result = self.try_eval_and_append_impl(&expr, span);
// We should be able to evaluate `Const` expressions at this
@ -740,7 +789,7 @@ impl<'a> ConstantEvaluator<'a> {
Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
Ok(self.append_expr(expr, span, ExpressionKind::Override))
}
Behavior::Wgsl(WgslRestrictions::Const) => {
Behavior::Wgsl(WgslRestrictions::Const(_)) => {
Err(ConstantEvaluatorError::OverrideExpr)
}
Behavior::Glsl(_) => {
@ -761,14 +810,17 @@ impl<'a> ConstantEvaluator<'a> {
const fn is_global_arena(&self) -> bool {
matches!(
self.behavior,
Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override)
Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
| Behavior::Glsl(GlslRestrictions::Const)
)
}
const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
match self.behavior {
Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data))
Behavior::Wgsl(
WgslRestrictions::Runtime(ref function_local_data)
| WgslRestrictions::Const(Some(ref function_local_data)),
)
| Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
Some(function_local_data)
}
@ -2057,7 +2109,10 @@ impl<'a> ConstantEvaluator<'a> {
expr_type: ExpressionKind,
) -> Handle<Expression> {
let h = match self.behavior {
Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data))
Behavior::Wgsl(
WgslRestrictions::Runtime(ref mut function_local_data)
| WgslRestrictions::Const(Some(ref mut function_local_data)),
)
| Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
let is_running = function_local_data.emitter.is_running();
let needs_pre_emit = expr.needs_pre_emit();
@ -2480,7 +2535,7 @@ mod tests {
let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl(WgslRestrictions::Const),
behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
types: &mut types,
constants: &constants,
overrides: &overrides,
@ -2566,7 +2621,7 @@ mod tests {
let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl(WgslRestrictions::Const),
behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
types: &mut types,
constants: &constants,
overrides: &overrides,
@ -2684,7 +2739,7 @@ mod tests {
let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl(WgslRestrictions::Const),
behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
types: &mut types,
constants: &constants,
overrides: &overrides,
@ -2777,7 +2832,7 @@ mod tests {
let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl(WgslRestrictions::Const),
behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
types: &mut types,
constants: &constants,
overrides: &overrides,
@ -2859,7 +2914,7 @@ mod tests {
let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl(WgslRestrictions::Const),
behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
types: &mut types,
constants: &constants,
overrides: &overrides,

View File

@ -0,0 +1 @@
()

View File

@ -0,0 +1,26 @@
const ga = 4; // AbstractInt with a value of 4.
const gb : i32 = 4; // i32 with a value of 4.
const gc : u32 = 4; // u32 with a value of 4.
const gd : f32 = 4; // f32 with a value of 4.
const ge = vec3(ga, ga, ga); // vec3 of AbstractInt with a value of (4, 4, 4).
const gf = 2.0; // AbstractFloat with a value of 2.
fn const_in_fn() {
const a = 4; // AbstractInt with a value of 4.
const b: i32 = 4; // i32 with a value of 4.
const c: u32 = 4; // u32 with a value of 4.
const d: f32 = 4; // f32 with a value of 4.
const e = vec3(a, a, a); // vec3 of AbstractInt with a value of (4, 4, 4).
const f = 2.0; // AbstractFloat with a value of 2.
// TODO: Make it per spec, currently not possible
// because naga does not support automatic conversions
// of Abstract types
// Check that we can access global constants
const ag = ga;
const bg = gb;
const cg = gc;
const dg = gd;
const eg = ge;
const fg = gf;
}

View File

@ -0,0 +1,139 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
(
name: None,
inner: Scalar((
kind: Uint,
width: 4,
)),
),
(
name: None,
inner: Scalar((
kind: Float,
width: 4,
)),
),
(
name: None,
inner: Vector(
size: Tri,
scalar: (
kind: Sint,
width: 4,
),
),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [
(
name: Some("ga"),
ty: 0,
init: 0,
),
(
name: Some("gb"),
ty: 0,
init: 1,
),
(
name: Some("gc"),
ty: 1,
init: 2,
),
(
name: Some("gd"),
ty: 2,
init: 3,
),
(
name: Some("ge"),
ty: 3,
init: 4,
),
(
name: Some("gf"),
ty: 2,
init: 5,
),
],
overrides: [],
global_variables: [],
global_expressions: [
Literal(I32(4)),
Literal(I32(4)),
Literal(U32(4)),
Literal(F32(4.0)),
Compose(
ty: 3,
components: [
0,
0,
0,
],
),
Literal(F32(2.0)),
],
functions: [
(
name: Some("const_in_fn"),
arguments: [],
result: None,
local_variables: [],
expressions: [
Literal(I32(4)),
Literal(I32(4)),
Literal(U32(4)),
Literal(F32(4.0)),
Compose(
ty: 3,
components: [
0,
0,
0,
],
),
Literal(F32(2.0)),
Constant(0),
Constant(1),
Constant(2),
Constant(3),
Constant(4),
Constant(5),
],
named_expressions: {
0: "a",
1: "b",
2: "c",
3: "d",
4: "e",
5: "f",
6: "ag",
7: "bg",
8: "cg",
9: "dg",
10: "eg",
11: "fg",
},
body: [
Emit((
start: 4,
end: 5,
)),
],
),
],
entry_points: [],
)

View File

@ -0,0 +1,139 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
(
name: None,
inner: Scalar((
kind: Uint,
width: 4,
)),
),
(
name: None,
inner: Scalar((
kind: Float,
width: 4,
)),
),
(
name: None,
inner: Vector(
size: Tri,
scalar: (
kind: Sint,
width: 4,
),
),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [
(
name: Some("ga"),
ty: 0,
init: 0,
),
(
name: Some("gb"),
ty: 0,
init: 1,
),
(
name: Some("gc"),
ty: 1,
init: 2,
),
(
name: Some("gd"),
ty: 2,
init: 3,
),
(
name: Some("ge"),
ty: 3,
init: 4,
),
(
name: Some("gf"),
ty: 2,
init: 5,
),
],
overrides: [],
global_variables: [],
global_expressions: [
Literal(I32(4)),
Literal(I32(4)),
Literal(U32(4)),
Literal(F32(4.0)),
Compose(
ty: 3,
components: [
0,
0,
0,
],
),
Literal(F32(2.0)),
],
functions: [
(
name: Some("const_in_fn"),
arguments: [],
result: None,
local_variables: [],
expressions: [
Literal(I32(4)),
Literal(I32(4)),
Literal(U32(4)),
Literal(F32(4.0)),
Compose(
ty: 3,
components: [
0,
0,
0,
],
),
Literal(F32(2.0)),
Constant(0),
Constant(1),
Constant(2),
Constant(3),
Constant(4),
Constant(5),
],
named_expressions: {
0: "a",
1: "b",
2: "c",
3: "d",
4: "e",
5: "f",
6: "ag",
7: "bg",
8: "cg",
9: "dg",
10: "eg",
11: "fg",
},
body: [
Emit((
start: 4,
end: 5,
)),
],
),
],
entry_points: [],
)

View File

@ -34,8 +34,8 @@ fn main(fragment_in: FragmentIn) -> @location(0) vec4<f32> {
let uniform_index = uni.index;
let non_uniform_index = fragment_in.index;
let uv = vec2(0f);
let pix = vec2(0i);
const uv = vec2(0f);
const pix = vec2(0i);
let _e21 = textureDimensions(texture_array_unbounded[0]);
let _e22 = u2_;
u2_ = (_e22 + _e21);

View File

@ -21,11 +21,11 @@ fn main() {
var foo: Foo;
foo = Foo(vec4(1f), 1i);
let m0_ = mat2x2<f32>(vec2<f32>(1f, 0f), vec2<f32>(0f, 1f));
let m1_ = mat4x4<f32>(vec4<f32>(1f, 0f, 0f, 0f), vec4<f32>(0f, 1f, 0f, 0f), vec4<f32>(0f, 0f, 1f, 0f), vec4<f32>(0f, 0f, 0f, 1f));
let cit0_ = vec2(0u);
let cit1_ = mat2x2<f32>(vec2(0f), vec2(0f));
let cit2_ = array<i32, 4>(0i, 1i, 2i, 3i);
let ic4_ = vec2<u32>(0u, 0u);
let ic5_ = mat2x3<f32>(vec3<f32>(0f, 0f, 0f), vec3<f32>(0f, 0f, 0f));
const m0_ = mat2x2<f32>(vec2<f32>(1f, 0f), vec2<f32>(0f, 1f));
const m1_ = mat4x4<f32>(vec4<f32>(1f, 0f, 0f, 0f), vec4<f32>(0f, 1f, 0f, 0f), vec4<f32>(0f, 0f, 1f, 0f), vec4<f32>(0f, 0f, 0f, 1f));
const cit0_ = vec2(0u);
const cit1_ = mat2x2<f32>(vec2(0f), vec2(0f));
const cit2_ = array<i32, 4>(0i, 1i, 2i, 3i);
const ic4_ = vec2<u32>(0u, 0u);
const ic5_ = mat2x3<f32>(vec3<f32>(0f, 0f, 0f), vec3<f32>(0f, 0f, 0f));
}

View File

@ -268,12 +268,12 @@ fn testUnaryOpMat(a_16: mat3x3<f32>) {
let _e3 = a_17;
v_8 = -(_e3);
let _e5 = a_17;
let _e7 = vec3(1f);
const _e7 = vec3(1f);
let _e9 = (_e5 - mat3x3<f32>(_e7, _e7, _e7));
a_17 = _e9;
v_8 = _e9;
let _e10 = a_17;
let _e12 = vec3(1f);
const _e12 = vec3(1f);
a_17 = (_e10 - mat3x3<f32>(_e12, _e12, _e12));
v_8 = _e10;
return;

View File

@ -1,16 +1,16 @@
fn test_fma() -> vec2<f32> {
let a = vec2<f32>(2f, 2f);
let b = vec2<f32>(0.5f, 0.5f);
let c = vec2<f32>(0.5f, 0.5f);
const a = vec2<f32>(2f, 2f);
const b = vec2<f32>(0.5f, 0.5f);
const c = vec2<f32>(0.5f, 0.5f);
return fma(a, b, c);
}
fn test_integer_dot_product() -> i32 {
let a_2_ = vec2(1i);
let b_2_ = vec2(1i);
const a_2_ = vec2(1i);
const b_2_ = vec2(1i);
let c_2_ = dot(a_2_, b_2_);
let a_3_ = vec3(1u);
let b_3_ = vec3(1u);
const a_3_ = vec3(1u);
const b_3_ = vec3(1u);
let c_3_ = dot(a_3_, b_3_);
let c_4_ = dot(vec4(4i), vec4(2i));
return c_4_;

View File

@ -110,8 +110,8 @@ fn levels_queries() -> @builtin(position) vec4<f32> {
fn texture_sample() -> @location(0) vec4<f32> {
var a: vec4<f32>;
let tc = vec2(0.5f);
let tc3_ = vec3(0.5f);
const tc = vec2(0.5f);
const tc3_ = vec3(0.5f);
let _e9 = textureSample(image_1d, sampler_reg, tc.x);
let _e10 = a;
a = (_e10 + _e9);
@ -186,8 +186,8 @@ fn texture_sample() -> @location(0) vec4<f32> {
fn texture_sample_comparison() -> @location(0) f32 {
var a_1: f32;
let tc_1 = vec2(0.5f);
let tc3_1 = vec3(0.5f);
const tc_1 = vec2(0.5f);
const tc3_1 = vec3(0.5f);
let _e8 = textureSampleCompare(image_2d_depth, sampler_cmp, tc_1, 0.5f);
let _e9 = a_1;
a_1 = (_e9 + _e8);
@ -218,7 +218,7 @@ fn texture_sample_comparison() -> @location(0) f32 {
@fragment
fn gather() -> @location(0) vec4<f32> {
let tc_2 = vec2(0.5f);
const tc_2 = vec2(0.5f);
let s2d = textureGather(1, image_2d, sampler_reg, tc_2);
let s2d_offset = textureGather(3, image_2d, sampler_reg, tc_2, vec2<i32>(3i, 1i));
let s2d_depth = textureGatherCompare(image_2d_depth, sampler_cmp, tc_2, 0.5f);
@ -231,7 +231,7 @@ fn gather() -> @location(0) vec4<f32> {
@fragment
fn depth_no_comparison() -> @location(0) vec4<f32> {
let tc_3 = vec2(0.5f);
const tc_3 = vec2(0.5f);
let s2d_1 = textureSample(image_2d_depth, sampler_reg, tc_3);
let s2d_gather = textureGather(image_2d_depth, sampler_reg, tc_3);
return (vec4(s2d_1) + s2d_gather);

View File

@ -0,0 +1,11 @@
const ga: i32 = 4i;
const gb: i32 = 4i;
const gc: u32 = 4u;
const gd: f32 = 4f;
const ge: vec3<i32> = vec3<i32>(4i, 4i, 4i);
const gf: f32 = 2f;
fn const_in_fn() {
const e = vec3<i32>(4i, 4i, 4i);
}

View File

@ -1,25 +1,25 @@
@fragment
fn main() {
let v = vec4(0f);
const v = vec4(0f);
let a = degrees(1f);
let b = radians(1f);
let c = degrees(v);
let d = radians(v);
let e = saturate(v);
let g = refract(v, v, 1f);
let sign_b = vec4<i32>(-1i, -1i, -1i, -1i);
let sign_d = vec4<f32>(-1f, -1f, -1f, -1f);
const sign_b = vec4<i32>(-1i, -1i, -1i, -1i);
const sign_d = vec4<f32>(-1f, -1f, -1f, -1f);
let const_dot = dot(vec2<i32>(), vec2<i32>());
let flb_b = vec2<i32>(-1i, -1i);
let flb_c = vec2<u32>(0u, 0u);
let ftb_c = vec2<i32>(0i, 0i);
let ftb_d = vec2<u32>(0u, 0u);
let ctz_e = vec2<u32>(32u, 32u);
let ctz_f = vec2<i32>(32i, 32i);
let ctz_g = vec2<u32>(0u, 0u);
let ctz_h = vec2<i32>(0i, 0i);
let clz_c = vec2<i32>(0i, 0i);
let clz_d = vec2<u32>(31u, 31u);
const flb_b = vec2<i32>(-1i, -1i);
const flb_c = vec2<u32>(0u, 0u);
const ftb_c = vec2<i32>(0i, 0i);
const ftb_d = vec2<u32>(0u, 0u);
const ctz_e = vec2<u32>(32u, 32u);
const ctz_f = vec2<i32>(32i, 32i);
const ctz_g = vec2<u32>(0u, 0u);
const ctz_h = vec2<i32>(0i, 0i);
const clz_c = vec2<i32>(0i, 0i);
const clz_d = vec2<u32>(31u, 31u);
let lde_a = ldexp(1f, 2i);
let lde_b = ldexp(vec2<f32>(1f, 2f), vec2<i32>(3i, 4i));
let modf_a = modf(1.5f);

View File

@ -11,7 +11,7 @@ fn builtins() -> vec4<f32> {
let m2_ = mix(v_f32_zero, v_f32_one, 0.1f);
let b1_ = bitcast<f32>(1i);
let b2_ = bitcast<vec4<f32>>(v_i32_one);
let v_i32_zero = vec4<i32>(0i, 0i, 0i, 0i);
const v_i32_zero = vec4<i32>(0i, 0i, 0i, 0i);
return (((((vec4<f32>((vec4(s1_) + v_i32_zero)) + s2_) + m1_) + m2_) + vec4(b1_)) + b2_);
}
@ -40,8 +40,8 @@ fn bool_cast(x: vec3<f32>) -> vec3<f32> {
}
fn logical() {
let neg0_ = !(true);
let neg1_ = !(vec2(true));
const neg0_ = !(true);
const neg1_ = !(vec2(true));
let or = (true || false);
let and = (true && false);
let bitwise_or0_ = (true | false);
@ -51,9 +51,9 @@ fn logical() {
}
fn arithmetic() {
let neg0_1 = -(1f);
let neg1_1 = -(vec2(1i));
let neg2_ = -(vec2(1f));
const neg0_1 = -(1f);
const neg1_1 = -(vec2(1i));
const neg2_ = -(vec2(1f));
let add0_ = (2i + 1i);
let add1_ = (2u + 1u);
let add2_ = (2f + 1f);
@ -126,10 +126,10 @@ fn arithmetic() {
}
fn bit() {
let flip0_ = ~(1i);
let flip1_ = ~(1u);
let flip2_ = ~(vec2(1i));
let flip3_ = ~(vec3(1u));
const flip0_ = ~(1i);
const flip1_ = ~(1u);
const flip2_ = ~(vec2(1i));
const flip3_ = ~(vec3(1u));
let or0_ = (2i | 1i);
let or1_ = (2u | 1u);
let or2_ = (vec2(2i) | vec2(1i));
@ -230,14 +230,14 @@ fn assignment() {
}
fn negation_avoids_prefix_decrement() {
let p0_ = -(1i);
let p1_ = -(-(1i));
let p2_ = -(-(1i));
let p3_ = -(-(1i));
let p4_ = -(-(-(1i)));
let p5_ = -(-(-(-(1i))));
let p6_ = -(-(-(-(-(1i)))));
let p7_ = -(-(-(-(-(1i)))));
const p0_ = -(1i);
const p1_ = -(-(1i));
const p2_ = -(-(1i));
const p3_ = -(-(1i));
const p4_ = -(-(-(1i)));
const p5_ = -(-(-(-(1i))));
const p6_ = -(-(-(-(-(1i)))));
const p7_ = -(-(-(-(-(1i)))));
}
@compute @workgroup_size(1, 1, 1)

View File

@ -21,11 +21,11 @@ fn main_1() {
vec = _e21;
vec_target = _e21;
let _e32 = mat;
let _e34 = vec3(1f);
const _e34 = vec3(1f);
mat = (_e32 + mat4x3<f32>(_e34, _e34, _e34, _e34));
mat_target = _e32;
let _e37 = mat;
let _e39 = vec3(1f);
const _e39 = vec3(1f);
let _e41 = (_e37 - mat4x3<f32>(_e39, _e39, _e39, _e39));
mat = _e41;
mat_target = _e41;

View File

@ -40,7 +40,7 @@ fn fetch_shadow(light_id: u32, homogeneous_coords: vec4<f32>) -> f32 {
if (homogeneous_coords.w <= 0f) {
return 1f;
}
let flip_correction = vec2<f32>(0.5f, -0.5f);
const flip_correction = vec2<f32>(0.5f, -0.5f);
let proj_correction = (1f / homogeneous_coords.w);
let light_local = (((homogeneous_coords.xy * flip_correction) * proj_correction) + vec2<f32>(0.5f, 0.5f));
let _e24 = textureSampleCompareLevel(t_shadow, sampler_shadow, light_local, i32(light_id), (homogeneous_coords.z * proj_correction));

View File

@ -1,10 +1,10 @@
fn main() {
let a = vec3<f32>(0f, 0f, 0f);
let c = vec3(0f);
let b = vec3<f32>(vec2(0f), 0f);
let d = vec3<f32>(vec2(0f), 0f);
let e = vec3<i32>(d);
let f = mat2x2<f32>(vec2<f32>(1f, 2f), vec2<f32>(3f, 4f));
let g = mat3x3<f32>(a, a, a);
const a = vec3<f32>(0f, 0f, 0f);
const c = vec3(0f);
const b = vec3<f32>(vec2(0f), 0f);
const d = vec3<f32>(vec2(0f), 0f);
const e = vec3<i32>(d);
const f = mat2x2<f32>(vec2<f32>(1f, 2f), vec2<f32>(3f, 4f));
const g = mat3x3<f32>(a, a, a);
}

View File

@ -818,6 +818,7 @@ fn convert_wgsl() {
"use-gl-ext-over-grad-workaround-if-instructed",
Targets::GLSL,
),
("local-const", Targets::IR | Targets::WGSL),
(
"math-functions",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,

View File

@ -2277,3 +2277,97 @@ fn too_many_unclosed_loops() {
.join()
.unwrap()
}
#[test]
fn local_const_wrong_type() {
check(
"
fn f() {
const c: i32 = 5u;
}
",
r###"error: the type of `c` is expected to be `i32`, but got `u32`
wgsl:3:19
3 const c: i32 = 5u;
^ definition of `c`
"###,
);
}
#[test]
fn local_const_from_let() {
check(
"
fn f() {
let a = 5;
const c = a;
}
",
r###"error: this operation is not supported in a const context
wgsl:4:23
4 const c = a;
^ operation not supported here
"###,
);
}
#[test]
fn local_const_from_var() {
check(
"
fn f() {
var a = 5;
const c = a;
}
",
r###"error: this operation is not supported in a const context
wgsl:4:23
4 const c = a;
^ operation not supported here
"###,
);
}
#[test]
fn local_const_from_override() {
check(
"
override o: i32;
fn f() {
const c = o;
}
",
r###"error: Unexpected override-expression
wgsl:4:23
4 const c = o;
^ see msg
"###,
);
}
#[test]
fn local_const_from_global_var() {
check(
"
var v: i32;
fn f() {
const c = v;
}
",
r###"error: Unexpected runtime-expression
wgsl:4:23
4 const c = v;
^ see msg
"###,
);
}