From a0dbe5ebc6fa24422fb84b2e0fea1cc94dee5109 Mon Sep 17 00:00:00 2001 From: Andy Leiserson <aleiserson@mozilla.com> Date: Tue, 8 Apr 2025 16:08:57 -0700 Subject: [PATCH] refactor(msl-out): create a type for bounds check iter items Co-Authored-By: Erich Gubler <erichdongubler@gmail.com> --- naga/src/back/msl/writer.rs | 22 +++++++++++++--------- naga/src/proc/index.rs | 30 +++++++++++++++++++++++++----- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 721595a37..2903ad80e 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -18,7 +18,11 @@ use crate::{ arena::{Handle, HandleSet}, back::{self, Baked}, common, - proc::{self, index, NameKey, TypeResolution}, + proc::{ + self, + index::{self, BoundsCheck}, + NameKey, TypeResolution, + }, valid, FastHashMap, FastHashSet, }; @@ -723,13 +727,7 @@ impl<'a> ExpressionContext<'a> { fn bounds_check_iter( &self, chain: Handle<crate::Expression>, - ) -> impl Iterator< - Item = ( - Handle<crate::Expression>, - index::GuardedIndex, - index::IndexableLength, - ), - > + '_ { + ) -> impl Iterator<Item = BoundsCheck> + '_ { index::bounds_check_iter(chain, self.module, self.function, self.info) } @@ -2778,7 +2776,13 @@ impl<W: Write> Writer<W> { let mut check_written = false; // Iterate over the access chain, handling each required bounds check. - for (base, index, length) in context.bounds_check_iter(chain) { + for item in context.bounds_check_iter(chain) { + let BoundsCheck { + base, + index, + length, + } = item; + if check_written { write!(self.out, " && ")?; } else { diff --git a/naga/src/proc/index.rs b/naga/src/proc/index.rs index 8e76c8524..cf6a127ac 100644 --- a/naga/src/proc/index.rs +++ b/naga/src/proc/index.rs @@ -342,12 +342,27 @@ pub fn access_needs_check( Some(length) } +/// Items returned by the [`bounds_check_iter`] iterator. +#[cfg_attr(not(feature = "msl-out"), allow(dead_code))] +pub(crate) struct BoundsCheck { + /// The base of the [`Access`] or [`AccessIndex`] expression. + /// + /// [`Access`]: crate::Expression::Access + /// [`AccessIndex`]: crate::Expression::AccessIndex + pub base: Handle<crate::Expression>, + + /// The index being accessed. + pub index: GuardedIndex, + + /// The length of `base`. + pub length: IndexableLength, +} + /// Returns an iterator of accesses within the chain of `Access` and /// `AccessIndex` expressions starting from `chain` that may need to be /// bounds-checked at runtime. /// -/// They're yielded as `(base, index)` pairs, where `base` is the type that the -/// access expression will produce and `index` is the index being used. +/// Items are yielded as [`BoundsCheck`] instances. /// /// Accesses through a struct are omitted, since you never need a bounds check /// for accessing a struct field. @@ -359,7 +374,7 @@ pub(crate) fn bounds_check_iter<'a>( module: &'a crate::Module, function: &'a crate::Function, info: &'a valid::FunctionInfo, -) -> impl Iterator<Item = (Handle<crate::Expression>, GuardedIndex, IndexableLength)> + 'a { +) -> impl Iterator<Item = BoundsCheck> + 'a { iter::from_fn(move || { let (next_expr, result) = match function.expressions[chain] { crate::Expression::Access { base, index } => { @@ -384,8 +399,13 @@ pub(crate) fn bounds_check_iter<'a>( }) .flatten() .filter_map(|(base, index)| { - access_needs_check(base, index, module, &function.expressions, info) - .map(|length| (base, index, length)) + access_needs_check(base, index, module, &function.expressions, info).map(|length| { + BoundsCheck { + base, + index, + length, + } + }) }) }