From f0f0f318ec2efb4c7ea86b46670503fb8d5ce9d3 Mon Sep 17 00:00:00 2001 From: Ashley Hauck Date: Thu, 1 Apr 2021 10:03:00 +0200 Subject: [PATCH] Don't treat OpSelectionMerge as an edge in mem2reg (#571) * Don't treat OpSelectionMerge as an edge in mem2reg * Make outgoing_edges not allocate * Update test --- .../rustc_codegen_spirv/src/linker/inline.rs | 16 +++- .../rustc_codegen_spirv/src/linker/mem2reg.rs | 78 ++++++++++++------- .../src/linker/new_structurizer.rs | 4 +- .../src/linker/simple_passes.rs | 29 +++---- .../src/linker/structurizer.rs | 30 +++---- crates/spirv-builder/src/test/basic.rs | 6 +- 6 files changed, 92 insertions(+), 71 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 0195e0077d..5e0f164ebb 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -5,7 +5,7 @@ //! run mem2reg (see mem2reg.rs) on the result to "unwrap" the Function pointer. use super::apply_rewrite_rules; -use super::mem2reg::compute_preds; +use super::simple_passes::outgoing_edges; use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand}; use rspirv::spirv::{FunctionControl, Op, StorageClass, Word}; use std::collections::{HashMap, HashSet}; @@ -413,3 +413,17 @@ fn fuse_trivial_branches(function: &mut Function) { } function.blocks.retain(|b| !b.instructions.is_empty()); } + +fn compute_preds(blocks: &[Block]) -> Vec> { + let mut result = vec![vec![]; blocks.len()]; + for (source_idx, source) in blocks.iter().enumerate() { + for dest_id in outgoing_edges(source) { + let dest_idx = blocks + .iter() + .position(|b| b.label_id().unwrap() == dest_id) + .unwrap(); + result[dest_idx].push(source_idx); + } + } + result +} diff --git a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs index 391b11c4d1..f92e970d1c 100644 --- a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs +++ b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs @@ -22,8 +22,9 @@ pub fn mem2reg( constants: &HashMap, func: &mut Function, ) { - let preds = compute_preds(&func.blocks); - let idom = compute_idom(&preds); + let reachable = compute_reachable(&func.blocks); + let preds = compute_preds(&func.blocks, &reachable); + let idom = compute_idom(&preds, &reachable); let dominance_frontier = compute_dominance_frontier(&preds, &idom); insert_phis_all( header, @@ -35,24 +36,38 @@ pub fn mem2reg( ); } -pub fn compute_preds(blocks: &[Block]) -> Vec> { - let mut result = vec![vec![]; blocks.len()]; - for (source_idx, source) in blocks.iter().enumerate() { - let mut edges = outgoing_edges(source); - // HACK(eddyb) treat `OpSelectionMerge` as an edge, in case it points - // to an otherwise-unreachable block. - if let Some(before_last_idx) = source.instructions.len().checked_sub(2) { - if let Some(before_last) = source.instructions.get(before_last_idx) { - if before_last.class.opcode == Op::SelectionMerge { - edges.push(before_last.operands[0].unwrap_id_ref()); - } +fn label_to_index(blocks: &[Block], id: Word) -> usize { + blocks + .iter() + .position(|b| b.label_id().unwrap() == id) + .unwrap() +} + +fn compute_reachable(blocks: &[Block]) -> Vec { + fn recurse(blocks: &[Block], reachable: &mut [bool], block: usize) { + if !reachable[block] { + reachable[block] = true; + for dest_id in outgoing_edges(&blocks[block]) { + let dest_idx = label_to_index(blocks, dest_id); + recurse(blocks, reachable, dest_idx); } } - for dest_id in edges { - let dest_idx = blocks - .iter() - .position(|b| b.label_id().unwrap() == dest_id) - .unwrap(); + } + let mut reachable = vec![false; blocks.len()]; + recurse(blocks, &mut reachable, 0); + reachable +} + +fn compute_preds(blocks: &[Block], reachable_blocks: &[bool]) -> Vec> { + let mut result = vec![vec![]; blocks.len()]; + // Do not count unreachable blocks as valid preds of blocks + for (source_idx, source) in blocks + .iter() + .enumerate() + .filter(|&(b, _)| reachable_blocks[b]) + { + for dest_id in outgoing_edges(source) { + let dest_idx = label_to_index(blocks, dest_id); result[dest_idx].push(source_idx); } } @@ -62,7 +77,8 @@ pub fn compute_preds(blocks: &[Block]) -> Vec> { // Paper: A Simple, Fast Dominance Algorithm // https://www.cs.rice.edu/~keith/EMBED/dom.pdf // Note: requires nodes in reverse postorder -fn compute_idom(preds: &[Vec]) -> Vec { +// If a result is None, that means the block is unreachable, and therefore has no idom. +fn compute_idom(preds: &[Vec], reachable_blocks: &[bool]) -> Vec> { fn intersect(doms: &[Option], mut finger1: usize, mut finger2: usize) -> usize { // TODO: This may return an optional result? while finger1 != finger2 { @@ -83,7 +99,8 @@ fn compute_idom(preds: &[Vec]) -> Vec { let mut changed = true; while changed { changed = false; - for node in 1..(preds.len()) { + // Unreachable blocks have no preds, and therefore no idom + for node in (1..(preds.len())).filter(|&i| reachable_blocks[i]) { let mut new_idom: Option = None; for &pred in &preds[node] { if idom[pred].is_some() { @@ -99,20 +116,25 @@ fn compute_idom(preds: &[Vec]) -> Vec { } } } - idom.iter().map(|x| x.unwrap()).collect() + assert!(idom + .iter() + .enumerate() + .all(|(i, x)| x.is_some() == reachable_blocks[i])); + idom } // Same paper as above -fn compute_dominance_frontier(preds: &[Vec], idom: &[usize]) -> Vec> { +fn compute_dominance_frontier(preds: &[Vec], idom: &[Option]) -> Vec> { assert_eq!(preds.len(), idom.len()); let mut dominance_frontier = vec![HashSet::new(); preds.len()]; for node in 0..preds.len() { if preds[node].len() >= 2 { + let node_idom = idom[node].unwrap(); for &pred in &preds[node] { let mut runner = pred; - while runner != idom[node] { + while runner != node_idom { dominance_frontier[runner].insert(node); - runner = idom[runner]; + runner = idom[runner].unwrap(); } } } @@ -442,13 +464,9 @@ impl Renamer<'_> { } } - for dest_id in outgoing_edges(&self.blocks[block]) { + for dest_id in outgoing_edges(&self.blocks[block]).collect::>() { // TODO: Don't do this find - let dest_idx = self - .blocks - .iter() - .position(|b| b.label_id().unwrap() == dest_id) - .unwrap(); + let dest_idx = label_to_index(self.blocks, dest_id); self.rename(dest_idx, Some(block)); } diff --git a/crates/rustc_codegen_spirv/src/linker/new_structurizer.rs b/crates/rustc_codegen_spirv/src/linker/new_structurizer.rs index d36a2b175c..5cd7b70283 100644 --- a/crates/rustc_codegen_spirv/src/linker/new_structurizer.rs +++ b/crates/rustc_codegen_spirv/src/linker/new_structurizer.rs @@ -561,7 +561,9 @@ impl Structurizer<'_> { } visited[block] = true; - for target in super::simple_passes::outgoing_edges(&self.func.blocks()[block]) { + for target in + super::simple_passes::outgoing_edges(&self.func.blocks()[block]).collect::>() + { self.post_order_step(self.block_id_to_idx[&target], visited, post_order) } diff --git a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs index 70eed5d48a..9dfb987e93 100644 --- a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs +++ b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs @@ -1,7 +1,6 @@ use rspirv::dr::{Block, Function, Module}; use rspirv::spirv::{Op, Word}; use std::collections::{HashMap, HashSet}; -use std::iter::once; use std::mem::replace; pub fn shift_ids(module: &mut Module, add: u32) { @@ -44,7 +43,7 @@ pub fn block_ordering_pass(func: &mut Function) { .iter() .find(|b| b.label_id().unwrap() == current) .unwrap(); - let mut edges = outgoing_edges(current_block); + let mut edges = outgoing_edges(current_block).collect::>(); // HACK(eddyb) treat `OpSelectionMerge` as an edge, in case it points // to an otherwise-unreachable block. if let Some(before_last_idx) = current_block.instructions.len().checked_sub(2) { @@ -80,27 +79,17 @@ pub fn block_ordering_pass(func: &mut Function) { assert_eq!(func.blocks[0].label_id().unwrap(), entry_label); } -// FIXME(eddyb) use `Either`, `Cow`, and/or `SmallVec`. -pub fn outgoing_edges(block: &Block) -> Vec { +pub fn outgoing_edges(block: &Block) -> impl Iterator + '_ { let terminator = block.instructions.last().unwrap(); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Termination - match terminator.class.opcode { - Op::Branch => vec![terminator.operands[0].unwrap_id_ref()], - Op::BranchConditional => vec![ - terminator.operands[1].unwrap_id_ref(), - terminator.operands[2].unwrap_id_ref(), - ], - Op::Switch => once(terminator.operands[1].unwrap_id_ref()) - .chain( - terminator.operands[3..] - .iter() - .step_by(2) - .map(|op| op.unwrap_id_ref()), - ) - .collect(), - Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => Vec::new(), + let operand_indices = match terminator.class.opcode { + Op::Branch => (0..1).step_by(1), + Op::BranchConditional => (1..3).step_by(1), + Op::Switch => (1..terminator.operands.len()).step_by(2), + Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => (0..0).step_by(1), _ => panic!("Invalid block terminator: {:?}", terminator), - } + }; + operand_indices.map(move |i| terminator.operands[i].unwrap_id_ref()) } pub fn compact_ids(module: &mut Module) -> u32 { diff --git a/crates/rustc_codegen_spirv/src/linker/structurizer.rs b/crates/rustc_codegen_spirv/src/linker/structurizer.rs index 5413d6f74c..510c70bd06 100644 --- a/crates/rustc_codegen_spirv/src/linker/structurizer.rs +++ b/crates/rustc_codegen_spirv/src/linker/structurizer.rs @@ -208,7 +208,8 @@ fn retarget_loop_children_if_needed(builder: &mut Builder, cf_info: &ControlFlow while let Some(front) = next.pop_front() { let block_idx = find_block_index_from_id(builder, &front); - let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)); + let mut new_edges = + outgoing_edges(get_block_ref!(builder, block_idx)).collect::>(); // Make sure we are not looping or going into child loops. for loop_info in &cf_info.loops { @@ -249,8 +250,7 @@ fn retarget_loop_children_if_needed(builder: &mut Builder, cf_info: &ControlFlow fn incoming_edges(id: Word, builder: &mut Builder) -> Vec { let mut incoming_edges = Vec::new(); for block in get_blocks_ref(builder) { - let out = outgoing_edges(block); - if out.contains(&id) { + if outgoing_edges(block).any(|x| x == id) { incoming_edges.push(block.label_id().unwrap()); } } @@ -292,8 +292,8 @@ fn defer_loop_internals(builder: &mut Builder, cf_info: &ControlFlowInfo) { // find all blocks that branch to a merge block. let mut possible_intermediate_block_idexes = Vec::new(); for (i, block) in get_blocks_ref(builder).iter().enumerate() { - let out = outgoing_edges(block); - if out.len() == 1 && out[0] == loop_info.merge_id { + let mut out = outgoing_edges(block); + if out.next() == Some(loop_info.merge_id) && out.next() == None { possible_intermediate_block_idexes.push(i) } } @@ -390,7 +390,7 @@ fn eliminate_multiple_continue_blocks(builder: &mut Builder, header: Word) -> Wo for block in get_blocks_ref(builder) { let block_id = block.label_id().unwrap(); if ends_in_branch(block) { - let edge = outgoing_edges(block)[0]; + let edge = outgoing_edges(block).next().unwrap(); if edge == header && block_is_parent_of(builder, header, block_id) { continue_blocks.push(block_id); } @@ -421,7 +421,7 @@ fn block_leads_into_break(builder: &Builder, cf_info: &ControlFlowInfo, start: W while let Some(front) = next.pop_front() { let block_idx = find_block_index_from_id(builder, &front); - let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)); + let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::>(); // Make sure we are not looping. for loop_info in &cf_info.loops { @@ -456,7 +456,7 @@ fn block_leads_into_break(builder: &Builder, cf_info: &ControlFlowInfo, start: W fn block_leads_into_continue(builder: &Builder, cf_info: &ControlFlowInfo, start: Word) -> bool { let start_idx = find_block_index_from_id(builder, &start); - let new_edges = outgoing_edges(get_block_ref!(builder, start_idx)); + let new_edges = outgoing_edges(get_block_ref!(builder, start_idx)).collect::>(); for loop_info in &cf_info.loops { if new_edges.len() == 1 && loop_info.continue_id == new_edges[0] { return true; @@ -485,7 +485,7 @@ fn block_is_reverse_idom_of( continue; } - let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)); + let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::>(); // Skip loop bodies by jumping to the merge block is we hit a header block. for loop_info in &cf_info.loops { @@ -552,7 +552,7 @@ fn block_is_parent_of(builder: &Builder, parent: Word, child: Word) -> bool { while let Some(front) = next.pop_front() { let block_idx = find_block_index_from_id(builder, &front); - let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)); + let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::>(); for id in &processed { if let Some(i) = new_edges.iter().position(|x| x == id) { @@ -589,7 +589,7 @@ fn get_looping_branch_from_block( } let block_idx = find_block_index_from_id(builder, &front); - let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)); + let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::>(); let edge_it = new_edges.iter().find(|&x| x == &start); // Check if the new_edges contain the start if new_edges.len() == 1 { @@ -598,8 +598,8 @@ fn get_looping_branch_from_block( let start_idx = find_block_index_from_id(builder, &front); let start_edges = outgoing_edges(get_block_ref!(builder, start_idx)); - for (i, start_edge) in start_edges.iter().enumerate() { - if start_edge == edge_it || block_is_parent_of(builder, *start_edge, *edge_it) { + for (i, start_edge) in start_edges.enumerate() { + if start_edge == *edge_it || block_is_parent_of(builder, start_edge, *edge_it) { return Some(i); } } @@ -696,7 +696,7 @@ pub fn insert_selection_merge_on_conditional_branch( }; let bi = find_block_index_from_id(builder, id); - let out = outgoing_edges(&get_blocks_ref(builder)[bi]); + let out = outgoing_edges(&get_blocks_ref(builder)[bi]).collect::>(); let id = idx_to_id(builder, bi); let a_nexts = get_possible_merge_positions(builder, cf_info, out[0]); let b_nexts = get_possible_merge_positions(builder, cf_info, out[1]); @@ -809,7 +809,7 @@ pub fn insert_loop_merge_on_conditional_branch( let merge_branch_idx = (looping_branch_idx + 1) % 2; let bi = find_block_index_from_id(builder, &id); - let out = outgoing_edges(&get_blocks_ref(builder)[bi]); + let out = outgoing_edges(&get_blocks_ref(builder)[bi]).collect::>(); let continue_block_id = eliminate_multiple_continue_blocks(builder, id); let merge_block_id = out[merge_branch_idx]; diff --git a/crates/spirv-builder/src/test/basic.rs b/crates/spirv-builder/src/test/basic.rs index 1963be7f45..d65dc83f41 100644 --- a/crates/spirv-builder/src/test/basic.rs +++ b/crates/spirv-builder/src/test/basic.rs @@ -242,14 +242,12 @@ OpSelectionMerge %24 None OpBranchConditional %22 %25 %26 %25 = OpLabel %27 = OpIMul %2 %28 %14 -%29 = OpIAdd %2 %27 %5 -%30 = OpIAdd %10 %9 %31 +%15 = OpIAdd %2 %27 %5 +%12 = OpIAdd %10 %9 %29 OpBranch %24 %26 = OpLabel OpReturnValue %14 %24 = OpLabel -%12 = OpPhi %10 %30 %25 -%15 = OpPhi %2 %29 %25 %19 = OpPhi %17 %18 %25 OpBranch %13 %13 = OpLabel