[naga] Let filter_emits_with_block operate on a &mut Block.

This removes some clones and collects, simplifies call sites, and
isn't any more complicated to implement.
This commit is contained in:
Jim Blandy 2024-03-23 07:47:49 -07:00 committed by Teodor Tanasoaia
parent aaf3b17623
commit bb15286df2
2 changed files with 39 additions and 43 deletions

View File

@ -3,7 +3,7 @@ use crate::{
proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
Span, Statement, SwitchCase, TypeInner, WithSpan,
Span, Statement, TypeInner, WithSpan,
};
use std::{borrow::Cow, collections::HashSet, mem};
use thiserror::Error;
@ -302,8 +302,7 @@ fn process_function(
adjust_block(&adjusted_local_expressions, &mut function.body);
let new_body = filter_emits_in_block(&function.body, &function.expressions);
function.body = new_body;
filter_emits_in_block(&mut function.body, &function.expressions);
// We've changed the keys of `function.named_expression`, so we have to
// rebuild it from scratch.
@ -620,16 +619,16 @@ fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
/// [`Emit`]: Statement::Emit
/// [`needs_pre_emit`]: Expression::needs_pre_emit
/// [`Override`]: Expression::Override
fn filter_emits_in_block(block: &Block, expressions: &Arena<Expression>) -> Block {
let mut out = Block::with_capacity(block.len());
for (stmt, span) in block.span_iter() {
fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
let original = std::mem::replace(block, Block::with_capacity(block.len()));
for (stmt, span) in original.span_into_iter() {
match stmt {
&Statement::Emit(ref range) => {
Statement::Emit(range) => {
let mut current = None;
for expr_h in range.clone() {
for expr_h in range {
if expressions[expr_h].needs_pre_emit() {
if let Some((first, last)) = current {
out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span);
block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
}
current = None;
@ -640,66 +639,57 @@ fn filter_emits_in_block(block: &Block, expressions: &Arena<Expression>) -> Bloc
}
}
if let Some((first, last)) = current {
out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span);
block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
}
}
&Statement::Block(ref block) => {
let block = filter_emits_in_block(block, expressions);
out.push(Statement::Block(block), *span);
Statement::Block(mut child) => {
filter_emits_in_block(&mut child, expressions);
block.push(Statement::Block(child), span);
}
&Statement::If {
Statement::If {
condition,
ref accept,
ref reject,
mut accept,
mut reject,
} => {
let accept = filter_emits_in_block(accept, expressions);
let reject = filter_emits_in_block(reject, expressions);
out.push(
filter_emits_in_block(&mut accept, expressions);
filter_emits_in_block(&mut reject, expressions);
block.push(
Statement::If {
condition,
accept,
reject,
},
*span,
span,
);
}
&Statement::Switch {
Statement::Switch {
selector,
ref cases,
mut cases,
} => {
let cases = cases
.iter()
.map(|case| {
let body = filter_emits_in_block(&case.body, expressions);
SwitchCase {
value: case.value,
body,
fall_through: case.fall_through,
}
})
.collect();
out.push(Statement::Switch { selector, cases }, *span);
for case in &mut cases {
filter_emits_in_block(&mut case.body, expressions);
}
block.push(Statement::Switch { selector, cases }, span);
}
&Statement::Loop {
ref body,
ref continuing,
Statement::Loop {
mut body,
mut continuing,
break_if,
} => {
let body = filter_emits_in_block(body, expressions);
let continuing = filter_emits_in_block(continuing, expressions);
out.push(
filter_emits_in_block(&mut body, expressions);
filter_emits_in_block(&mut continuing, expressions);
block.push(
Statement::Loop {
body,
continuing,
break_if,
},
*span,
span,
);
}
stmt => out.push(stmt.clone(), *span),
stmt => block.push(stmt.clone(), span),
}
}
out
}
fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {

View File

@ -65,6 +65,12 @@ impl Block {
self.span_info.splice(range.clone(), other.span_info);
self.body.splice(range, other.body);
}
pub fn span_into_iter(self) -> impl Iterator<Item = (Statement, Span)> {
let Block { body, span_info } = self;
body.into_iter().zip(span_info)
}
pub fn span_iter(&self) -> impl Iterator<Item = (&Statement, &Span)> {
let span_iter = self.span_info.iter();
self.body.iter().zip(span_iter)