mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-02-16 08:53:20 +00:00
[naga wgsl] Impl const_assert
(#6198)
Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com>
This commit is contained in:
parent
ace2e201cf
commit
4e9a2a5003
@ -263,6 +263,8 @@ pub(crate) enum Error<'a> {
|
||||
limit: u8,
|
||||
},
|
||||
PipelineConstantIDValue(Span),
|
||||
NotBool(Span),
|
||||
ConstAssertFailed(Span),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@ -815,6 +817,22 @@ impl<'a> Error<'a> {
|
||||
)],
|
||||
notes: vec![],
|
||||
},
|
||||
Error::NotBool(span) => ParseError {
|
||||
message: "must be a const-expression that resolves to a bool".to_string(),
|
||||
labels: vec![(
|
||||
span,
|
||||
"must resolve to bool".into(),
|
||||
)],
|
||||
notes: vec![],
|
||||
},
|
||||
Error::ConstAssertFailed(span) => ParseError {
|
||||
message: "const_assert failure".to_string(),
|
||||
labels: vec![(
|
||||
span,
|
||||
"evaluates to false".into(),
|
||||
)],
|
||||
notes: vec![],
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -20,13 +20,16 @@ impl<'a> Index<'a> {
|
||||
// While doing so, reject conflicting definitions.
|
||||
let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default());
|
||||
for (handle, decl) in tu.decls.iter() {
|
||||
let ident = decl_ident(decl);
|
||||
let name = ident.name;
|
||||
if let Some(old) = globals.insert(name, handle) {
|
||||
return Err(Error::Redefinition {
|
||||
previous: decl_ident(&tu.decls[old]).span,
|
||||
current: ident.span,
|
||||
});
|
||||
if let Some(ident) = decl_ident(decl) {
|
||||
let name = ident.name;
|
||||
if let Some(old) = globals.insert(name, handle) {
|
||||
return Err(Error::Redefinition {
|
||||
previous: decl_ident(&tu.decls[old])
|
||||
.expect("decl should have ident for redefinition")
|
||||
.span,
|
||||
current: ident.span,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -130,7 +133,7 @@ impl<'a> DependencySolver<'a, '_> {
|
||||
return if dep_id == id {
|
||||
// A declaration refers to itself directly.
|
||||
Err(Error::RecursiveDeclaration {
|
||||
ident: decl_ident(decl).span,
|
||||
ident: decl_ident(decl).expect("decl should have ident").span,
|
||||
usage: dep.usage,
|
||||
})
|
||||
} else {
|
||||
@ -146,14 +149,19 @@ impl<'a> DependencySolver<'a, '_> {
|
||||
.unwrap_or(0);
|
||||
|
||||
Err(Error::CyclicDeclaration {
|
||||
ident: decl_ident(&self.module.decls[dep_id]).span,
|
||||
ident: decl_ident(&self.module.decls[dep_id])
|
||||
.expect("decl should have ident")
|
||||
.span,
|
||||
path: self.path[start_at..]
|
||||
.iter()
|
||||
.map(|curr_dep| {
|
||||
let curr_id = curr_dep.decl;
|
||||
let curr_decl = &self.module.decls[curr_id];
|
||||
|
||||
(decl_ident(curr_decl).span, curr_dep.usage)
|
||||
(
|
||||
decl_ident(curr_decl).expect("decl should have ident").span,
|
||||
curr_dep.usage,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
@ -182,13 +190,14 @@ impl<'a> DependencySolver<'a, '_> {
|
||||
}
|
||||
}
|
||||
|
||||
const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
|
||||
const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> Option<ast::Ident<'a>> {
|
||||
match decl.kind {
|
||||
ast::GlobalDeclKind::Fn(ref f) => f.name,
|
||||
ast::GlobalDeclKind::Var(ref v) => v.name,
|
||||
ast::GlobalDeclKind::Const(ref c) => c.name,
|
||||
ast::GlobalDeclKind::Override(ref o) => o.name,
|
||||
ast::GlobalDeclKind::Struct(ref s) => s.name,
|
||||
ast::GlobalDeclKind::Type(ref t) => t.name,
|
||||
ast::GlobalDeclKind::Fn(ref f) => Some(f.name),
|
||||
ast::GlobalDeclKind::Var(ref v) => Some(v.name),
|
||||
ast::GlobalDeclKind::Const(ref c) => Some(c.name),
|
||||
ast::GlobalDeclKind::Override(ref o) => Some(o.name),
|
||||
ast::GlobalDeclKind::Struct(ref s) => Some(s.name),
|
||||
ast::GlobalDeclKind::Type(ref t) => Some(t.name),
|
||||
ast::GlobalDeclKind::ConstAssert(_) => None,
|
||||
}
|
||||
}
|
||||
|
@ -1204,6 +1204,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
ctx.globals
|
||||
.insert(alias.name.name, LoweredGlobalDecl::Type(ty));
|
||||
}
|
||||
ast::GlobalDeclKind::ConstAssert(condition) => {
|
||||
let condition = self.expression(condition, &mut ctx.as_const())?;
|
||||
|
||||
let span = ctx.module.global_expressions.get_span(condition);
|
||||
match ctx
|
||||
.module
|
||||
.to_ctx()
|
||||
.eval_expr_to_bool_from(condition, &ctx.module.global_expressions)
|
||||
{
|
||||
Some(true) => Ok(()),
|
||||
Some(false) => Err(Error::ConstAssertFailed(span)),
|
||||
_ => Err(Error::NotBool(span)),
|
||||
}?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1742,6 +1756,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
value,
|
||||
}
|
||||
}
|
||||
ast::StatementKind::ConstAssert(condition) => {
|
||||
let mut emitter = Emitter::default();
|
||||
emitter.start(&ctx.function.expressions);
|
||||
|
||||
let condition =
|
||||
self.expression(condition, &mut ctx.as_const(block, &mut emitter))?;
|
||||
|
||||
let span = ctx.function.expressions.get_span(condition);
|
||||
match ctx
|
||||
.module
|
||||
.to_ctx()
|
||||
.eval_expr_to_bool_from(condition, &ctx.function.expressions)
|
||||
{
|
||||
Some(true) => Ok(()),
|
||||
Some(false) => Err(Error::ConstAssertFailed(span)),
|
||||
_ => Err(Error::NotBool(span)),
|
||||
}?;
|
||||
|
||||
block.extend(emitter.finish(&ctx.function.expressions));
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
ast::StatementKind::Ignore(expr) => {
|
||||
let mut emitter = Emitter::default();
|
||||
emitter.start(&ctx.function.expressions);
|
||||
|
@ -85,6 +85,7 @@ pub enum GlobalDeclKind<'a> {
|
||||
Override(Override<'a>),
|
||||
Struct(Struct<'a>),
|
||||
Type(TypeAlias<'a>),
|
||||
ConstAssert(Handle<Expression<'a>>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -284,6 +285,7 @@ pub enum StatementKind<'a> {
|
||||
Increment(Handle<Expression<'a>>),
|
||||
Decrement(Handle<Expression<'a>>),
|
||||
Ignore(Handle<Expression<'a>>),
|
||||
ConstAssert(Handle<Expression<'a>>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -2016,6 +2016,20 @@ impl Parser {
|
||||
lexer.expect(Token::Separator(';'))?;
|
||||
ast::StatementKind::Kill
|
||||
}
|
||||
// https://www.w3.org/TR/WGSL/#const-assert-statement
|
||||
"const_assert" => {
|
||||
let _ = lexer.next();
|
||||
// parentheses are optional
|
||||
let paren = lexer.skip(Token::Paren('('));
|
||||
|
||||
let condition = self.general_expression(lexer, ctx)?;
|
||||
|
||||
if paren {
|
||||
lexer.expect(Token::Paren(')'))?;
|
||||
}
|
||||
lexer.expect(Token::Separator(';'))?;
|
||||
ast::StatementKind::ConstAssert(condition)
|
||||
}
|
||||
// assignment or a function call
|
||||
_ => {
|
||||
self.function_call_or_assignment_statement(lexer, ctx, block)?;
|
||||
@ -2419,6 +2433,18 @@ impl Parser {
|
||||
..function
|
||||
}))
|
||||
}
|
||||
(Token::Word("const_assert"), _) => {
|
||||
// parentheses are optional
|
||||
let paren = lexer.skip(Token::Paren('('));
|
||||
|
||||
let condition = self.general_expression(lexer, &mut ctx)?;
|
||||
|
||||
if paren {
|
||||
lexer.expect(Token::Paren(')'))?;
|
||||
}
|
||||
lexer.expect(Token::Separator(';'))?;
|
||||
Some(ast::GlobalDeclKind::ConstAssert(condition))
|
||||
}
|
||||
(Token::End, _) => return Ok(()),
|
||||
other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)),
|
||||
};
|
||||
|
@ -674,6 +674,19 @@ impl GlobalCtx<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
|
||||
#[allow(dead_code)]
|
||||
pub(super) fn eval_expr_to_bool_from(
|
||||
&self,
|
||||
handle: crate::Handle<crate::Expression>,
|
||||
arena: &crate::Arena<crate::Expression>,
|
||||
) -> Option<bool> {
|
||||
match self.eval_expr_to_literal_from(handle, arena) {
|
||||
Some(crate::Literal::Bool(value)) => Some(value),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn eval_expr_to_literal(
|
||||
&self,
|
||||
|
11
naga/tests/in/const_assert.wgsl
Normal file
11
naga/tests/in/const_assert.wgsl
Normal file
@ -0,0 +1,11 @@
|
||||
// Sourced from https://www.w3.org/TR/WGSL/#const-assert-statement
|
||||
const x = 1;
|
||||
const y = 2;
|
||||
const_assert x < y; // valid at module-scope.
|
||||
const_assert(y != 0); // parentheses are optional.
|
||||
|
||||
fn foo() {
|
||||
const z = x + y - 2;
|
||||
const_assert z > 0; // valid in functions.
|
||||
const_assert(z > 0);
|
||||
}
|
54
naga/tests/out/ir/const_assert.compact.ron
Normal file
54
naga/tests/out/ir/const_assert.compact.ron
Normal file
@ -0,0 +1,54 @@
|
||||
(
|
||||
types: [
|
||||
(
|
||||
name: None,
|
||||
inner: Scalar((
|
||||
kind: Sint,
|
||||
width: 4,
|
||||
)),
|
||||
),
|
||||
],
|
||||
special_types: (
|
||||
ray_desc: None,
|
||||
ray_intersection: None,
|
||||
predeclared_types: {},
|
||||
),
|
||||
constants: [
|
||||
(
|
||||
name: Some("x"),
|
||||
ty: 0,
|
||||
init: 0,
|
||||
),
|
||||
(
|
||||
name: Some("y"),
|
||||
ty: 0,
|
||||
init: 1,
|
||||
),
|
||||
],
|
||||
overrides: [],
|
||||
global_variables: [],
|
||||
global_expressions: [
|
||||
Literal(I32(1)),
|
||||
Literal(I32(2)),
|
||||
],
|
||||
functions: [
|
||||
(
|
||||
name: Some("foo"),
|
||||
arguments: [],
|
||||
result: None,
|
||||
local_variables: [],
|
||||
expressions: [
|
||||
Literal(I32(1)),
|
||||
],
|
||||
named_expressions: {
|
||||
0: "z",
|
||||
},
|
||||
body: [
|
||||
Return(
|
||||
value: None,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
entry_points: [],
|
||||
)
|
54
naga/tests/out/ir/const_assert.ron
Normal file
54
naga/tests/out/ir/const_assert.ron
Normal file
@ -0,0 +1,54 @@
|
||||
(
|
||||
types: [
|
||||
(
|
||||
name: None,
|
||||
inner: Scalar((
|
||||
kind: Sint,
|
||||
width: 4,
|
||||
)),
|
||||
),
|
||||
],
|
||||
special_types: (
|
||||
ray_desc: None,
|
||||
ray_intersection: None,
|
||||
predeclared_types: {},
|
||||
),
|
||||
constants: [
|
||||
(
|
||||
name: Some("x"),
|
||||
ty: 0,
|
||||
init: 0,
|
||||
),
|
||||
(
|
||||
name: Some("y"),
|
||||
ty: 0,
|
||||
init: 1,
|
||||
),
|
||||
],
|
||||
overrides: [],
|
||||
global_variables: [],
|
||||
global_expressions: [
|
||||
Literal(I32(1)),
|
||||
Literal(I32(2)),
|
||||
],
|
||||
functions: [
|
||||
(
|
||||
name: Some("foo"),
|
||||
arguments: [],
|
||||
result: None,
|
||||
local_variables: [],
|
||||
expressions: [
|
||||
Literal(I32(1)),
|
||||
],
|
||||
named_expressions: {
|
||||
0: "z",
|
||||
},
|
||||
body: [
|
||||
Return(
|
||||
value: None,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
entry_points: [],
|
||||
)
|
7
naga/tests/out/wgsl/const_assert.wgsl
Normal file
7
naga/tests/out/wgsl/const_assert.wgsl
Normal file
@ -0,0 +1,7 @@
|
||||
const x: i32 = 1i;
|
||||
const y: i32 = 2i;
|
||||
|
||||
fn foo() {
|
||||
return;
|
||||
}
|
||||
|
@ -868,6 +868,7 @@ fn convert_wgsl() {
|
||||
"const-exprs",
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
),
|
||||
("const_assert", Targets::WGSL | Targets::IR),
|
||||
("separate-entry-points", Targets::SPIRV | Targets::GLSL),
|
||||
(
|
||||
"struct-layout",
|
||||
|
@ -2389,3 +2389,54 @@ fn only_one_swizzle_type() {
|
||||
"###,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn const_assert_must_be_const() {
|
||||
check(
|
||||
"
|
||||
fn foo() {
|
||||
let a = 5;
|
||||
const_assert a != 0;
|
||||
}
|
||||
",
|
||||
r###"error: this operation is not supported in a const context
|
||||
┌─ wgsl:4:26
|
||||
│
|
||||
4 │ const_assert a != 0;
|
||||
│ ^ operation not supported here
|
||||
|
||||
"###,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn const_assert_must_be_bool() {
|
||||
check(
|
||||
"
|
||||
const_assert(5); // 5 is not bool
|
||||
",
|
||||
r###"error: must be a const-expression that resolves to a bool
|
||||
┌─ wgsl:2:26
|
||||
│
|
||||
2 │ const_assert(5); // 5 is not bool
|
||||
│ ^ must resolve to bool
|
||||
|
||||
"###,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn const_assert_failed() {
|
||||
check(
|
||||
"
|
||||
const_assert(false);
|
||||
",
|
||||
r###"error: const_assert failure
|
||||
┌─ wgsl:2:26
|
||||
│
|
||||
2 │ const_assert(false);
|
||||
│ ^^^^^ evaluates to false
|
||||
|
||||
"###,
|
||||
);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user