[naga spv-out] Consolidate code to find index values.

Let the SPIR-V backend use `GuardedIndex::try_resolve_to_constant`,
rather than writing out its definition in `write_restricted_index` and
`write_index_comparison`.

Call `try_resolve_to_constant` in one place, in `write_bounds_check`,
and simply pass the `GuardedIndex` into subroutines.

Reduce `write_restricted_index` and `write_index_comparison` to case
analysis and code generation.

Note that this commit does have a benign effect on SPIR-V snapshot
output for programs like this:

    let one_i = 1i;
    var vec0 = vec3<i32>();
    vec0[one_i] = 1;

The value indexing `vec0` here is an `i32`, but after this commit, the
operand to `OpAccessChain` becomes a `u32` constant (with the same
value).

This is because `write_bounds_check` now calls
`try_resolve_to_constant` itself, rather than deferring this work to
its callees, so it may return `BoundsCheckResult::KnownInBounds` even
when the `Unchecked` policy is in force. This directs the caller,
`write_expression_pointer`, to treat the `OpAccessChain` operand as a
fresh `u32` constant, rather than simply passing through the original
`i32` expression.
This commit is contained in:
Jim Blandy 2024-10-01 10:11:41 -07:00
parent 287ca16b52
commit 3d85781f05
3 changed files with 62 additions and 68 deletions

View File

@ -7,13 +7,17 @@ use super::{
selection::Selection,
Block, BlockContext, Error, IdGenerator, Instruction, Word,
};
use crate::{arena::Handle, proc::BoundsCheckPolicy};
use crate::{
arena::Handle,
proc::{index::GuardedIndex, BoundsCheckPolicy},
};
/// The results of performing a bounds check.
///
/// On success, [`write_bounds_check`](BlockContext::write_bounds_check)
/// returns a value of this type. The caller can assume that the right
/// policy has been applied, and simply do what the variant says.
#[derive(Debug)]
pub(super) enum BoundsCheckResult {
/// The index is statically known and in bounds, with the given value.
KnownInBounds(u32),
@ -40,6 +44,7 @@ pub(super) enum BoundsCheckResult {
}
/// A value that we either know at translation time, or need to compute at runtime.
#[derive(Copy, Clone)]
pub(super) enum MaybeKnown<T> {
/// The value is known at shader translation time.
Known(T),
@ -329,33 +334,26 @@ impl<'w> BlockContext<'w> {
pub(super) fn write_restricted_index(
&mut self,
sequence: Handle<crate::Expression>,
index: Handle<crate::Expression>,
index: GuardedIndex,
block: &mut Block,
) -> Result<BoundsCheckResult, Error> {
let index_id = self.cached[index];
let max_index = self.write_sequence_max_index(sequence, block)?;
// Get the sequence's maximum valid index. Return early if we've already
// done the bounds check.
let max_index_id = match self.write_sequence_max_index(sequence, block)? {
MaybeKnown::Known(known_max_index) => {
if let Ok(known_index) = self
.ir_module
.to_ctx()
.eval_expr_to_u32_from(index, &self.ir_function.expressions)
{
// Both the index and length are known at compile time.
//
// In strict WGSL compliance mode, out-of-bounds indices cannot be
// reported at shader translation time, and must be replaced with
// in-bounds indices at run time. So we cannot assume that
// validation ensured the index was in bounds. Restrict now.
let restricted = std::cmp::min(known_index, known_max_index);
// If both are known, we can compute the index to be used
// right now.
if let (GuardedIndex::Known(index), MaybeKnown::Known(max_index)) = (index, max_index) {
let restricted = std::cmp::min(index, max_index);
return Ok(BoundsCheckResult::KnownInBounds(restricted));
}
self.get_index_constant(known_max_index)
}
MaybeKnown::Computed(max_index_id) => max_index_id,
let index_id = match index {
GuardedIndex::Known(value) => self.get_index_constant(value),
GuardedIndex::Expression(expr) => self.cached[expr],
};
let max_index_id = match max_index {
MaybeKnown::Known(value) => self.get_index_constant(value),
MaybeKnown::Computed(id) => id,
};
// One or the other of the index or length is dynamic, so emit code for
@ -393,48 +391,33 @@ impl<'w> BlockContext<'w> {
fn write_index_comparison(
&mut self,
sequence: Handle<crate::Expression>,
index: Handle<crate::Expression>,
index: GuardedIndex,
block: &mut Block,
) -> Result<BoundsCheckResult, Error> {
let index_id = self.cached[index];
let length = self.write_sequence_length(sequence, block)?;
// Get the sequence's length. Return early if we've already done the
// bounds check.
let length_id = match self.write_sequence_length(sequence, block)? {
MaybeKnown::Known(known_length) => {
if let Ok(known_index) = self
.ir_module
.to_ctx()
.eval_expr_to_u32_from(index, &self.ir_function.expressions)
{
// Both the index and length are known at compile time.
//
// It would be nice to assume that, since we are using the
// `ReadZeroSkipWrite` policy, we are not in strict WGSL
// compliance mode, and thus we can count on the validator to have
// rejected any programs with known out-of-bounds indices, and
// thus just return `KnownInBounds` here without actually
// checking.
//
// But it's also reasonable to expect that bounds check policies
// and error reporting policies should be able to vary
// independently without introducing security holes. So, we should
// support the case where bad indices do not cause validation
// errors, and are handled via `ReadZeroSkipWrite`.
//
// In theory, when `known_index` is bad, we could return a new
// If both are known, we can decide whether the index is in
// bounds right now.
if let (GuardedIndex::Known(index), MaybeKnown::Known(length)) = (index, length) {
if index < length {
return Ok(BoundsCheckResult::KnownInBounds(index));
}
// In theory, when `index` is bad, we could return a new
// `KnownOutOfBounds` variant here. But it's simpler just to fall
// through and let the bounds check take place. The shader is
// broken anyway, so it doesn't make sense to invest in emitting
// the ideal code for it.
if known_index < known_length {
return Ok(BoundsCheckResult::KnownInBounds(known_index));
}
// through and let the bounds check take place. The shader is broken
// anyway, so it doesn't make sense to invest in emitting the ideal
// code for it.
}
self.get_index_constant(known_length)
}
MaybeKnown::Computed(length_id) => length_id,
let index_id = match index {
GuardedIndex::Known(value) => self.get_index_constant(value),
GuardedIndex::Expression(expr) => self.cached[expr],
};
let length_id = match length {
MaybeKnown::Known(value) => self.get_index_constant(value),
MaybeKnown::Computed(id) => id,
};
// Compare the index against the length.
@ -519,6 +502,10 @@ impl<'w> BlockContext<'w> {
index: Handle<crate::Expression>,
block: &mut Block,
) -> Result<BoundsCheckResult, Error> {
// If the value of `index` is known at compile time, find it now.
let mut index = GuardedIndex::Expression(index);
index.try_resolve_to_constant(self.ir_function, self.ir_module);
let policy = self.writer.bounds_check_policies.choose_policy(
base,
&self.ir_module.types,
@ -530,7 +517,10 @@ impl<'w> BlockContext<'w> {
BoundsCheckPolicy::ReadZeroSkipWrite => {
self.write_index_comparison(base, index, block)?
}
BoundsCheckPolicy::Unchecked => BoundsCheckResult::Computed(self.cached[index]),
BoundsCheckPolicy::Unchecked => match index {
GuardedIndex::Known(value) => BoundsCheckResult::KnownInBounds(value),
GuardedIndex::Expression(expr) => BoundsCheckResult::Computed(self.cached[expr]),
},
})
}

View File

@ -334,7 +334,11 @@ impl GuardedIndex {
/// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible.
///
/// Return values that are already `Known` unchanged.
fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) {
pub(crate) fn try_resolve_to_constant(
&mut self,
function: &crate::Function,
module: &crate::Module,
) {
if let GuardedIndex::Expression(expr) = *self {
if let Ok(value) = module
.to_ctx()

View File

@ -387,15 +387,15 @@ OpStore %302 %331
%332 = OpLoad %5 %302
%333 = OpISub %5 %332 %23
OpStore %302 %333
%335 = OpAccessChain %334 %305 %23
%335 = OpAccessChain %334 %305 %122
%336 = OpLoad %5 %335
%337 = OpIAdd %5 %336 %23
%338 = OpAccessChain %334 %305 %23
%338 = OpAccessChain %334 %305 %122
OpStore %338 %337
%339 = OpAccessChain %334 %305 %23
%339 = OpAccessChain %334 %305 %122
%340 = OpLoad %5 %339
%341 = OpISub %5 %340 %23
%342 = OpAccessChain %334 %305 %23
%342 = OpAccessChain %334 %305 %122
OpStore %342 %341
OpReturn
OpFunctionEnd