refactor(msl-out): extract bounds_check_iter helper

--
Co-authored-by: Liam Murphy <liampm32@gmail.com>
Co-Authored-By: Erich Gubler <erichdongubler@gmail.com>
This commit is contained in:
Andy Leiserson 2025-03-11 18:27:02 -07:00 committed by Erich Gubler
parent aad187f52f
commit 587aea2da6
2 changed files with 87 additions and 54 deletions

View File

@ -695,6 +695,20 @@ impl<'a> ExpressionContext<'a> {
)
}
/// See docs for [`proc::index::bounds_check_iter`].
fn bounds_check_iter(
&self,
chain: Handle<crate::Expression>,
) -> impl Iterator<
Item = (
Handle<crate::Expression>,
index::GuardedIndex,
index::IndexableLength,
),
> + '_ {
index::bounds_check_iter(chain, self.module, self.function, self.info)
}
fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
match self.function.expressions[expr_handle] {
crate::Expression::AccessIndex { base, index } => {
@ -2647,68 +2661,38 @@ impl<W: Write> Writer<W> {
#[allow(unused_variables)]
fn put_bounds_checks(
&mut self,
mut chain: Handle<crate::Expression>,
chain: Handle<crate::Expression>,
context: &ExpressionContext,
level: back::Level,
prefix: &'static str,
) -> Result<bool, Error> {
let mut check_written = false;
// Iterate over the access chain, handling each expression.
loop {
// Produce a `GuardedIndex`, so we can shared code between the
// `Access` and `AccessIndex` cases.
let (base, guarded_index) = match context.function.expressions[chain] {
crate::Expression::Access { base, index } => {
(base, Some(index::GuardedIndex::Expression(index)))
}
crate::Expression::AccessIndex { base, index } => {
// Don't try to check indices into structs. Validation already took
// care of them, and index::needs_guard doesn't handle that case.
let mut base_inner = context.resolve_type(base);
if let crate::TypeInner::Pointer { base, .. } = *base_inner {
base_inner = &context.module.types[base].inner;
}
match *base_inner {
crate::TypeInner::Struct { .. } => (base, None),
_ => (base, Some(index::GuardedIndex::Known(index))),
}
}
_ => break,
};
if let Some(index) = guarded_index {
if let Some(length) = context.access_needs_check(base, index) {
if check_written {
write!(self.out, " && ")?;
} else {
write!(self.out, "{level}{prefix}")?;
check_written = true;
}
// Check that the index falls within bounds. Do this with a single
// comparison, by casting the index to `uint` first, so that negative
// indices become large positive values.
write!(self.out, "uint(")?;
self.put_index(index, context, true)?;
self.out.write_str(") < ")?;
match length {
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
index::IndexableLength::Dynamic => {
let global =
context.function.originating_global(base).ok_or_else(|| {
Error::GenericValidation(
"Could not find originating global".into(),
)
})?;
write!(self.out, "1 + ")?;
self.put_dynamic_array_max_index(global, context)?
}
}
}
// Iterate over the access chain, handling each required bounds check.
for (base, index, length) in context.bounds_check_iter(chain) {
if check_written {
write!(self.out, " && ")?;
} else {
write!(self.out, "{level}{prefix}")?;
check_written = true;
}
chain = base
// Check that the index falls within bounds. Do this with a single
// comparison, by casting the index to `uint` first, so that negative
// indices become large positive values.
write!(self.out, "uint(")?;
self.put_index(index, context, true)?;
self.out.write_str(") < ")?;
match length {
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
index::IndexableLength::Dynamic => {
let global = context.function.originating_global(base).ok_or_else(|| {
Error::GenericValidation("Could not find originating global".into())
})?;
write!(self.out, "1 + ")?;
self.put_dynamic_array_max_index(global, context)?
}
}
}
Ok(check_written)

View File

@ -2,6 +2,8 @@
Definitions for index bounds checking.
*/
use core::iter;
use crate::arena::{Handle, HandleSet, UniqueArena};
use crate::valid;
@ -340,6 +342,53 @@ pub fn access_needs_check(
Some(length)
}
/// Returns an iterator of accesses within the chain of `Access` and
/// `AccessIndex` expressions starting from `chain` that may need to be
/// bounds-checked at runtime.
///
/// They're yielded as `(base, index)` pairs, where `base` is the type that the
/// access expression will produce and `index` is the index being used.
///
/// Accesses through a struct are omitted, since you never need a bounds check
/// for accessing a struct field.
///
/// If `chain` isn't an `Access` or `AccessIndex` expression at all, the
/// iterator is empty.
pub(crate) fn bounds_check_iter<'a>(
mut chain: Handle<crate::Expression>,
module: &'a crate::Module,
function: &'a crate::Function,
info: &'a valid::FunctionInfo,
) -> impl Iterator<Item = (Handle<crate::Expression>, GuardedIndex, IndexableLength)> + 'a {
iter::from_fn(move || {
let (next_expr, result) = match function.expressions[chain] {
crate::Expression::Access { base, index } => {
(base, Some((base, GuardedIndex::Expression(index))))
}
crate::Expression::AccessIndex { base, index } => {
// Don't try to check indices into structs. Validation already took
// care of them, and access_needs_check doesn't handle that case.
let mut base_inner = info[base].ty.inner_with(&module.types);
if let crate::TypeInner::Pointer { base, .. } = *base_inner {
base_inner = &module.types[base].inner;
}
match *base_inner {
crate::TypeInner::Struct { .. } => (base, None),
_ => (base, Some((base, GuardedIndex::Known(index)))),
}
}
_ => return None,
};
chain = next_expr;
Some(result)
})
.flatten()
.filter_map(|(base, index)| {
access_needs_check(base, index, module, &function.expressions, info)
.map(|length| (base, index, length))
})
}
impl GuardedIndex {
/// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible.
///