[naga wgsl] Impl const_assert (#6198)

Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com>
This commit is contained in:
Samson 2024-09-02 19:37:04 +02:00 committed by GitHub
parent ace2e201cf
commit 4e9a2a5003
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 299 additions and 17 deletions

View File

@ -263,6 +263,8 @@ pub(crate) enum Error<'a> {
limit: u8,
},
PipelineConstantIDValue(Span),
NotBool(Span),
ConstAssertFailed(Span),
}
#[derive(Clone, Debug)]
@ -815,6 +817,22 @@ impl<'a> Error<'a> {
)],
notes: vec![],
},
Error::NotBool(span) => ParseError {
message: "must be a const-expression that resolves to a bool".to_string(),
labels: vec![(
span,
"must resolve to bool".into(),
)],
notes: vec![],
},
Error::ConstAssertFailed(span) => ParseError {
message: "const_assert failure".to_string(),
labels: vec![(
span,
"evaluates to false".into(),
)],
notes: vec![],
},
}
}
}

View File

@ -20,13 +20,16 @@ impl<'a> Index<'a> {
// While doing so, reject conflicting definitions.
let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default());
for (handle, decl) in tu.decls.iter() {
let ident = decl_ident(decl);
let name = ident.name;
if let Some(old) = globals.insert(name, handle) {
return Err(Error::Redefinition {
previous: decl_ident(&tu.decls[old]).span,
current: ident.span,
});
if let Some(ident) = decl_ident(decl) {
let name = ident.name;
if let Some(old) = globals.insert(name, handle) {
return Err(Error::Redefinition {
previous: decl_ident(&tu.decls[old])
.expect("decl should have ident for redefinition")
.span,
current: ident.span,
});
}
}
}
@ -130,7 +133,7 @@ impl<'a> DependencySolver<'a, '_> {
return if dep_id == id {
// A declaration refers to itself directly.
Err(Error::RecursiveDeclaration {
ident: decl_ident(decl).span,
ident: decl_ident(decl).expect("decl should have ident").span,
usage: dep.usage,
})
} else {
@ -146,14 +149,19 @@ impl<'a> DependencySolver<'a, '_> {
.unwrap_or(0);
Err(Error::CyclicDeclaration {
ident: decl_ident(&self.module.decls[dep_id]).span,
ident: decl_ident(&self.module.decls[dep_id])
.expect("decl should have ident")
.span,
path: self.path[start_at..]
.iter()
.map(|curr_dep| {
let curr_id = curr_dep.decl;
let curr_decl = &self.module.decls[curr_id];
(decl_ident(curr_decl).span, curr_dep.usage)
(
decl_ident(curr_decl).expect("decl should have ident").span,
curr_dep.usage,
)
})
.collect(),
})
@ -182,13 +190,14 @@ impl<'a> DependencySolver<'a, '_> {
}
}
const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> Option<ast::Ident<'a>> {
match decl.kind {
ast::GlobalDeclKind::Fn(ref f) => f.name,
ast::GlobalDeclKind::Var(ref v) => v.name,
ast::GlobalDeclKind::Const(ref c) => c.name,
ast::GlobalDeclKind::Override(ref o) => o.name,
ast::GlobalDeclKind::Struct(ref s) => s.name,
ast::GlobalDeclKind::Type(ref t) => t.name,
ast::GlobalDeclKind::Fn(ref f) => Some(f.name),
ast::GlobalDeclKind::Var(ref v) => Some(v.name),
ast::GlobalDeclKind::Const(ref c) => Some(c.name),
ast::GlobalDeclKind::Override(ref o) => Some(o.name),
ast::GlobalDeclKind::Struct(ref s) => Some(s.name),
ast::GlobalDeclKind::Type(ref t) => Some(t.name),
ast::GlobalDeclKind::ConstAssert(_) => None,
}
}

View File

@ -1204,6 +1204,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ctx.globals
.insert(alias.name.name, LoweredGlobalDecl::Type(ty));
}
ast::GlobalDeclKind::ConstAssert(condition) => {
let condition = self.expression(condition, &mut ctx.as_const())?;
let span = ctx.module.global_expressions.get_span(condition);
match ctx
.module
.to_ctx()
.eval_expr_to_bool_from(condition, &ctx.module.global_expressions)
{
Some(true) => Ok(()),
Some(false) => Err(Error::ConstAssertFailed(span)),
_ => Err(Error::NotBool(span)),
}?;
}
}
}
@ -1742,6 +1756,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
value,
}
}
ast::StatementKind::ConstAssert(condition) => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);
let condition =
self.expression(condition, &mut ctx.as_const(block, &mut emitter))?;
let span = ctx.function.expressions.get_span(condition);
match ctx
.module
.to_ctx()
.eval_expr_to_bool_from(condition, &ctx.function.expressions)
{
Some(true) => Ok(()),
Some(false) => Err(Error::ConstAssertFailed(span)),
_ => Err(Error::NotBool(span)),
}?;
block.extend(emitter.finish(&ctx.function.expressions));
return Ok(());
}
ast::StatementKind::Ignore(expr) => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);

View File

@ -85,6 +85,7 @@ pub enum GlobalDeclKind<'a> {
Override(Override<'a>),
Struct(Struct<'a>),
Type(TypeAlias<'a>),
ConstAssert(Handle<Expression<'a>>),
}
#[derive(Debug)]
@ -284,6 +285,7 @@ pub enum StatementKind<'a> {
Increment(Handle<Expression<'a>>),
Decrement(Handle<Expression<'a>>),
Ignore(Handle<Expression<'a>>),
ConstAssert(Handle<Expression<'a>>),
}
#[derive(Debug)]

View File

@ -2016,6 +2016,20 @@ impl Parser {
lexer.expect(Token::Separator(';'))?;
ast::StatementKind::Kill
}
// https://www.w3.org/TR/WGSL/#const-assert-statement
"const_assert" => {
let _ = lexer.next();
// parentheses are optional
let paren = lexer.skip(Token::Paren('('));
let condition = self.general_expression(lexer, ctx)?;
if paren {
lexer.expect(Token::Paren(')'))?;
}
lexer.expect(Token::Separator(';'))?;
ast::StatementKind::ConstAssert(condition)
}
// assignment or a function call
_ => {
self.function_call_or_assignment_statement(lexer, ctx, block)?;
@ -2419,6 +2433,18 @@ impl Parser {
..function
}))
}
(Token::Word("const_assert"), _) => {
// parentheses are optional
let paren = lexer.skip(Token::Paren('('));
let condition = self.general_expression(lexer, &mut ctx)?;
if paren {
lexer.expect(Token::Paren(')'))?;
}
lexer.expect(Token::Separator(';'))?;
Some(ast::GlobalDeclKind::ConstAssert(condition))
}
(Token::End, _) => return Ok(()),
other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)),
};

View File

@ -674,6 +674,19 @@ impl GlobalCtx<'_> {
}
}
/// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
#[allow(dead_code)]
pub(super) fn eval_expr_to_bool_from(
&self,
handle: crate::Handle<crate::Expression>,
arena: &crate::Arena<crate::Expression>,
) -> Option<bool> {
match self.eval_expr_to_literal_from(handle, arena) {
Some(crate::Literal::Bool(value)) => Some(value),
_ => None,
}
}
#[allow(dead_code)]
pub(crate) fn eval_expr_to_literal(
&self,

View File

@ -0,0 +1,11 @@
// Sourced from https://www.w3.org/TR/WGSL/#const-assert-statement
const x = 1;
const y = 2;
const_assert x < y; // valid at module-scope.
const_assert(y != 0); // parentheses are optional.
fn foo() {
const z = x + y - 2;
const_assert z > 0; // valid in functions.
const_assert(z > 0);
}

View File

@ -0,0 +1,54 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [
(
name: Some("x"),
ty: 0,
init: 0,
),
(
name: Some("y"),
ty: 0,
init: 1,
),
],
overrides: [],
global_variables: [],
global_expressions: [
Literal(I32(1)),
Literal(I32(2)),
],
functions: [
(
name: Some("foo"),
arguments: [],
result: None,
local_variables: [],
expressions: [
Literal(I32(1)),
],
named_expressions: {
0: "z",
},
body: [
Return(
value: None,
),
],
),
],
entry_points: [],
)

View File

@ -0,0 +1,54 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [
(
name: Some("x"),
ty: 0,
init: 0,
),
(
name: Some("y"),
ty: 0,
init: 1,
),
],
overrides: [],
global_variables: [],
global_expressions: [
Literal(I32(1)),
Literal(I32(2)),
],
functions: [
(
name: Some("foo"),
arguments: [],
result: None,
local_variables: [],
expressions: [
Literal(I32(1)),
],
named_expressions: {
0: "z",
},
body: [
Return(
value: None,
),
],
),
],
entry_points: [],
)

View File

@ -0,0 +1,7 @@
const x: i32 = 1i;
const y: i32 = 2i;
fn foo() {
return;
}

View File

@ -868,6 +868,7 @@ fn convert_wgsl() {
"const-exprs",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
("const_assert", Targets::WGSL | Targets::IR),
("separate-entry-points", Targets::SPIRV | Targets::GLSL),
(
"struct-layout",

View File

@ -2389,3 +2389,54 @@ fn only_one_swizzle_type() {
"###,
);
}
#[test]
fn const_assert_must_be_const() {
check(
"
fn foo() {
let a = 5;
const_assert a != 0;
}
",
r###"error: this operation is not supported in a const context
wgsl:4:26
4 const_assert a != 0;
^ operation not supported here
"###,
);
}
#[test]
fn const_assert_must_be_bool() {
check(
"
const_assert(5); // 5 is not bool
",
r###"error: must be a const-expression that resolves to a bool
wgsl:2:26
2 const_assert(5); // 5 is not bool
^ must resolve to bool
"###,
);
}
#[test]
fn const_assert_failed() {
check(
"
const_assert(false);
",
r###"error: const_assert failure
wgsl:2:26
2 const_assert(false);
^^^^^ evaluates to false
"###,
);
}