[naga] Ensure that FooResult expressions are correctly populated.

Make Naga module validation require that `CallResult` and
`AtomicResult` expressions are indeed visited by exactly one `Call` /
`Atomic` statement.
This commit is contained in:
Jim Blandy 2024-06-03 07:59:25 -07:00 committed by Teodor Tanasoaia
parent 9a27ba53ca
commit 583cc6ab04
3 changed files with 145 additions and 55 deletions

View File

@ -22,6 +22,8 @@ pub enum CallError {
},
#[error("Result expression {0:?} has already been introduced earlier")]
ResultAlreadyInScope(Handle<crate::Expression>),
#[error("Result expression {0:?} is populated by multiple `Call` statements")]
ResultAlreadyPopulated(Handle<crate::Expression>),
#[error("Result value is invalid")]
ResultValue(#[source] ExpressionError),
#[error("Requires {required} arguments, but {seen} are provided")]
@ -45,6 +47,8 @@ pub enum AtomicError {
InvalidOperand(Handle<crate::Expression>),
#[error("Result type for {0:?} doesn't match the statement")]
ResultTypeMismatch(Handle<crate::Expression>),
#[error("Result expression {0:?} is populated by multiple `Atomic` statements")]
ResultAlreadyPopulated(Handle<crate::Expression>),
}
#[derive(Clone, Debug, thiserror::Error)]
@ -174,6 +178,8 @@ pub enum FunctionError {
InvalidSubgroup(#[from] SubgroupError),
#[error("Emit statement should not cover \"result\" expressions like {0:?}")]
EmitResult(Handle<crate::Expression>),
#[error("Expression not visited by the appropriate statement")]
UnvisitedExpression(Handle<crate::Expression>),
}
bitflags::bitflags! {
@ -305,7 +311,13 @@ impl super::Validator {
}
match context.expressions[expr] {
crate::Expression::CallResult(callee)
if fun.result.is_some() && callee == function => {}
if fun.result.is_some() && callee == function =>
{
if !self.needs_visit.remove(expr.index()) {
return Err(CallError::ResultAlreadyPopulated(expr)
.with_span_handle(expr, context.expressions));
}
}
_ => {
return Err(CallError::ExpressionMismatch(result)
.with_span_handle(expr, context.expressions))
@ -397,7 +409,14 @@ impl super::Validator {
}
_ => false,
}
} => {}
} =>
{
if !self.needs_visit.remove(result.index()) {
return Err(AtomicError::ResultAlreadyPopulated(result)
.with_span_handle(result, context.expressions)
.into_other());
}
}
_ => {
return Err(AtomicError::ResultTypeMismatch(result)
.with_span_handle(result, context.expressions)
@ -1290,11 +1309,20 @@ impl super::Validator {
self.valid_expression_set.clear();
self.valid_expression_list.clear();
self.needs_visit.clear();
for (handle, expr) in fun.expressions.iter() {
if expr.needs_pre_emit() {
self.valid_expression_set.insert(handle.index());
}
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
// Mark expressions that need to be visited by a particular kind of
// statement.
if let crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } =
*expr
{
self.needs_visit.insert(handle.index());
}
match self.validate_expression(
handle,
expr,
@ -1321,6 +1349,15 @@ impl super::Validator {
)?
.stages;
info.available_stages &= stages;
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
if let Some(unvisited) = self.needs_visit.iter().next() {
let index = std::num::NonZeroU32::new(unvisited as u32 + 1).unwrap();
let handle = Handle::new(index);
return Err(FunctionError::UnvisitedExpression(handle)
.with_span_handle(handle, &fun.expressions));
}
}
}
Ok(info)
}

View File

@ -246,6 +246,26 @@ pub struct Validator {
valid_expression_set: BitSet,
override_ids: FastHashSet<u16>,
allow_overrides: bool,
/// A checklist of expressions that must be visited by a specific kind of
/// statement.
///
/// For example:
///
/// - [`CallResult`] expressions must be visited by a [`Call`] statement.
/// - [`AtomicResult`] expressions must be visited by an [`Atomic`] statement.
///
/// Be sure not to remove any [`Expression`] handle from this set unless
/// you've explicitly checked that it is the right kind of expression for
/// the visiting [`Statement`].
///
/// [`CallResult`]: crate::Expression::CallResult
/// [`Call`]: crate::Statement::Call
/// [`AtomicResult`]: crate::Expression::AtomicResult
/// [`Atomic`]: crate::Statement::Atomic
/// [`Expression`]: crate::Expression
/// [`Statement`]: crate::Statement
needs_visit: BitSet,
}
#[derive(Clone, Debug, thiserror::Error)]
@ -398,6 +418,7 @@ impl Validator {
valid_expression_set: BitSet::new(),
override_ids: FastHashSet::default(),
allow_overrides: true,
needs_visit: BitSet::new(),
}
}

View File

@ -1,18 +1,30 @@
use naga::{valid, Expression, Function, Scalar};
/// Validation should fail if `AtomicResult` expressions are not
/// populated by `Atomic` statements.
#[test]
fn emit_atomic_result() {
fn populate_atomic_result() {
use naga::{Module, Type, TypeInner};
// We want to ensure that the *only* problem with the code is the
// use of an `Emit` statement instead of an `Atomic` statement. So
// validate two versions of the module varying only in that
// aspect.
//
// Looking at uses of the `atomic` makes it easy to identify the
// differences between the two variants.
fn variant(
atomic: bool,
/// Different variants of the test case that we want to exercise.
enum Variant {
/// An `AtomicResult` expression with an `Atomic` statement
/// that populates it: valid.
Atomic,
/// An `AtomicResult` expression visited by an `Emit`
/// statement: invalid.
Emit,
/// An `AtomicResult` expression visited by no statement at
/// all: invalid
None,
}
// Looking at uses of `variant` should make it easy to identify
// the differences between the test cases.
fn try_variant(
variant: Variant,
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
let span = naga::Span::default();
let mut module = Module::default();
@ -56,21 +68,25 @@ fn emit_atomic_result() {
span,
);
if atomic {
fun.body.push(
naga::Statement::Atomic {
pointer: ex_global,
fun: naga::AtomicFunction::Add,
value: ex_42,
result: ex_result,
},
span,
);
} else {
fun.body.push(
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
span,
);
match variant {
Variant::Atomic => {
fun.body.push(
naga::Statement::Atomic {
pointer: ex_global,
fun: naga::AtomicFunction::Add,
value: ex_42,
result: ex_result,
},
span,
);
}
Variant::Emit => {
fun.body.push(
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
span,
);
}
Variant::None => {}
}
module.functions.append(fun, span);
@ -82,23 +98,34 @@ fn emit_atomic_result() {
.validate(&module)
}
variant(true).expect("module should validate");
assert!(variant(false).is_err());
try_variant(Variant::Atomic).expect("module should validate");
assert!(try_variant(Variant::Emit).is_err());
assert!(try_variant(Variant::None).is_err());
}
#[test]
fn emit_call_result() {
fn populate_call_result() {
use naga::{Module, Type, TypeInner};
// We want to ensure that the *only* problem with the code is the
// use of an `Emit` statement instead of a `Call` statement. So
// validate two versions of the module varying only in that
// aspect.
//
// Looking at uses of the `call` makes it easy to identify the
// differences between the two variants.
fn variant(
call: bool,
/// Different variants of the test case that we want to exercise.
enum Variant {
/// A `CallResult` expression with an `Call` statement that
/// populates it: valid.
Call,
/// A `CallResult` expression visited by an `Emit` statement:
/// invalid.
Emit,
/// A `CallResult` expression visited by no statement at all:
/// invalid
None,
}
// Looking at uses of `variant` should make it easy to identify
// the differences between the test cases.
fn try_variant(
variant: Variant,
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
let span = naga::Span::default();
let mut module = Module::default();
@ -130,20 +157,24 @@ fn emit_call_result() {
.expressions
.append(Expression::CallResult(fun_callee), span);
if call {
fun_caller.body.push(
naga::Statement::Call {
function: fun_callee,
arguments: vec![],
result: Some(ex_result),
},
span,
);
} else {
fun_caller.body.push(
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
span,
);
match variant {
Variant::Call => {
fun_caller.body.push(
naga::Statement::Call {
function: fun_callee,
arguments: vec![],
result: Some(ex_result),
},
span,
);
}
Variant::Emit => {
fun_caller.body.push(
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
span,
);
}
Variant::None => {}
}
module.functions.append(fun_caller, span);
@ -155,8 +186,9 @@ fn emit_call_result() {
.validate(&module)
}
variant(true).expect("should validate");
assert!(variant(false).is_err());
try_variant(Variant::Call).expect("should validate");
assert!(try_variant(Variant::Emit).is_err());
assert!(try_variant(Variant::None).is_err());
}
#[test]