diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 2935a6ad3..1bae7f2f6 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -695,6 +695,20 @@ impl<'a> ExpressionContext<'a> { ) } + /// See docs for [`proc::index::bounds_check_iter`]. + fn bounds_check_iter( + &self, + chain: Handle<crate::Expression>, + ) -> impl Iterator< + Item = ( + Handle<crate::Expression>, + index::GuardedIndex, + index::IndexableLength, + ), + > + '_ { + index::bounds_check_iter(chain, self.module, self.function, self.info) + } + fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> { match self.function.expressions[expr_handle] { crate::Expression::AccessIndex { base, index } => { @@ -2647,68 +2661,38 @@ impl<W: Write> Writer<W> { #[allow(unused_variables)] fn put_bounds_checks( &mut self, - mut chain: Handle<crate::Expression>, + chain: Handle<crate::Expression>, context: &ExpressionContext, level: back::Level, prefix: &'static str, ) -> Result<bool, Error> { let mut check_written = false; - // Iterate over the access chain, handling each expression. - loop { - // Produce a `GuardedIndex`, so we can shared code between the - // `Access` and `AccessIndex` cases. - let (base, guarded_index) = match context.function.expressions[chain] { - crate::Expression::Access { base, index } => { - (base, Some(index::GuardedIndex::Expression(index))) - } - crate::Expression::AccessIndex { base, index } => { - // Don't try to check indices into structs. Validation already took - // care of them, and index::needs_guard doesn't handle that case. - let mut base_inner = context.resolve_type(base); - if let crate::TypeInner::Pointer { base, .. } = *base_inner { - base_inner = &context.module.types[base].inner; - } - match *base_inner { - crate::TypeInner::Struct { .. } => (base, None), - _ => (base, Some(index::GuardedIndex::Known(index))), - } - } - _ => break, - }; - - if let Some(index) = guarded_index { - if let Some(length) = context.access_needs_check(base, index) { - if check_written { - write!(self.out, " && ")?; - } else { - write!(self.out, "{level}{prefix}")?; - check_written = true; - } - - // Check that the index falls within bounds. Do this with a single - // comparison, by casting the index to `uint` first, so that negative - // indices become large positive values. - write!(self.out, "uint(")?; - self.put_index(index, context, true)?; - self.out.write_str(") < ")?; - match length { - index::IndexableLength::Known(value) => write!(self.out, "{value}")?, - index::IndexableLength::Dynamic => { - let global = - context.function.originating_global(base).ok_or_else(|| { - Error::GenericValidation( - "Could not find originating global".into(), - ) - })?; - write!(self.out, "1 + ")?; - self.put_dynamic_array_max_index(global, context)? - } - } - } + // Iterate over the access chain, handling each required bounds check. + for (base, index, length) in context.bounds_check_iter(chain) { + if check_written { + write!(self.out, " && ")?; + } else { + write!(self.out, "{level}{prefix}")?; + check_written = true; } - chain = base + // Check that the index falls within bounds. Do this with a single + // comparison, by casting the index to `uint` first, so that negative + // indices become large positive values. + write!(self.out, "uint(")?; + self.put_index(index, context, true)?; + self.out.write_str(") < ")?; + match length { + index::IndexableLength::Known(value) => write!(self.out, "{value}")?, + index::IndexableLength::Dynamic => { + let global = context.function.originating_global(base).ok_or_else(|| { + Error::GenericValidation("Could not find originating global".into()) + })?; + write!(self.out, "1 + ")?; + self.put_dynamic_array_max_index(global, context)? + } + } } Ok(check_written) diff --git a/naga/src/proc/index.rs b/naga/src/proc/index.rs index 9f1c0ddb7..550814573 100644 --- a/naga/src/proc/index.rs +++ b/naga/src/proc/index.rs @@ -2,6 +2,8 @@ Definitions for index bounds checking. */ +use core::iter; + use crate::arena::{Handle, HandleSet, UniqueArena}; use crate::valid; @@ -340,6 +342,53 @@ pub fn access_needs_check( Some(length) } +/// Returns an iterator of accesses within the chain of `Access` and +/// `AccessIndex` expressions starting from `chain` that may need to be +/// bounds-checked at runtime. +/// +/// They're yielded as `(base, index)` pairs, where `base` is the type that the +/// access expression will produce and `index` is the index being used. +/// +/// Accesses through a struct are omitted, since you never need a bounds check +/// for accessing a struct field. +/// +/// If `chain` isn't an `Access` or `AccessIndex` expression at all, the +/// iterator is empty. +pub(crate) fn bounds_check_iter<'a>( + mut chain: Handle<crate::Expression>, + module: &'a crate::Module, + function: &'a crate::Function, + info: &'a valid::FunctionInfo, +) -> impl Iterator<Item = (Handle<crate::Expression>, GuardedIndex, IndexableLength)> + 'a { + iter::from_fn(move || { + let (next_expr, result) = match function.expressions[chain] { + crate::Expression::Access { base, index } => { + (base, Some((base, GuardedIndex::Expression(index)))) + } + crate::Expression::AccessIndex { base, index } => { + // Don't try to check indices into structs. Validation already took + // care of them, and access_needs_check doesn't handle that case. + let mut base_inner = info[base].ty.inner_with(&module.types); + if let crate::TypeInner::Pointer { base, .. } = *base_inner { + base_inner = &module.types[base].inner; + } + match *base_inner { + crate::TypeInner::Struct { .. } => (base, None), + _ => (base, Some((base, GuardedIndex::Known(index)))), + } + } + _ => return None, + }; + chain = next_expr; + Some(result) + }) + .flatten() + .filter_map(|(base, index)| { + access_needs_check(base, index, module, &function.expressions, info) + .map(|length| (base, index, length)) + }) +} + impl GuardedIndex { /// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible. ///