diff --git a/crates/rustc_codegen_spirv/src/link.rs b/crates/rustc_codegen_spirv/src/link.rs index 9137bfe7af..1b15e59085 100644 --- a/crates/rustc_codegen_spirv/src/link.rs +++ b/crates/rustc_codegen_spirv/src/link.rs @@ -539,7 +539,6 @@ fn do_link( inline: legalize, mem2reg: legalize, structurize: env::var("NO_STRUCTURIZE").is_err(), - use_new_structurizer: env::var("OLD_STRUCTURIZER").is_err(), emit_multiple_modules, }; diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index a2bd47d78f..638115aedf 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -6,7 +6,6 @@ mod duplicates; mod import_export_link; mod inline; mod mem2reg; -mod new_structurizer; mod peephole_opts; mod simple_passes; mod specializer; @@ -29,7 +28,6 @@ pub struct Options { pub inline: bool, pub mem2reg: bool, pub structurize: bool, - pub use_new_structurizer: bool, pub emit_multiple_modules: bool, } @@ -190,11 +188,7 @@ pub fn link(sess: &Session, mut inputs: Vec, opts: &Options) -> Result { - builder: &'a mut Builder, -} - -impl FuncBuilder<'_> { - fn function(&self) -> &Function { - let func_idx = self.builder.selected_function().unwrap(); - &self.builder.module_ref().functions[func_idx] - } - - fn function_mut(&mut self) -> &mut Function { - let func_idx = self.builder.selected_function().unwrap(); - &mut self.builder.module_mut().functions[func_idx] - } - - fn blocks(&self) -> &[Block] { - &self.function().blocks - } - - fn blocks_mut(&mut self) -> &mut [Block] { - &mut self.function_mut().blocks - } -} - -pub fn structurize( - module: Module, - unroll_loops_decorations: FxHashMap, -) -> Module { - let mut builder = Builder::new_from_module(module); - - // Get the `OpTypeBool` type (it will only be created if it's missing). - let type_bool = builder.type_bool(); - - // Find already present `OpConstant{False,True}` (if they're in the module). - let mut existing_const_false = None; - let mut existing_const_true = None; - for inst in &builder.module_ref().types_global_values { - let existing = match inst.class.opcode { - Op::ConstantFalse => &mut existing_const_false, - Op::ConstantTrue => &mut existing_const_true, - _ => continue, - }; - - if existing.is_none() { - *existing = Some(inst.result_id.unwrap()); - } - - if existing_const_false.is_some() && existing_const_true.is_some() { - break; - } - } - - // Create new `OpConstant{False,True}` if they're missing. - let const_false = existing_const_false.unwrap_or_else(|| builder.constant_false(type_bool)); - let const_true = existing_const_true.unwrap_or_else(|| builder.constant_true(type_bool)); - - for func_idx in 0..builder.module_ref().functions.len() { - builder.select_function(Some(func_idx)).unwrap(); - let func = FuncBuilder { - builder: &mut builder, - }; - - let func_id = func.function().def_id().unwrap(); - - let loop_control = match unroll_loops_decorations.get(&func_id) { - Some(UnrollLoopsDecoration {}) => LoopControl::UNROLL, - None => LoopControl::NONE, - }; - - let block_id_to_idx = func - .blocks() - .iter() - .enumerate() - .map(|(i, block)| (block.label_id().unwrap(), i)) - .collect(); - - Structurizer { - globals: Globals { - type_bool, - const_false, - const_true, - }, - func, - block_id_to_idx, - loop_control, - incoming_edge_count: vec![], - regions: FxHashMap::default(), - } - .structurize_func(); - } - - builder.module() -} - -// FIXME(eddyb) use newtyped indices and `IndexVec`. -type BlockIdx = usize; -type BlockId = Word; - -/// Regions are made up of their entry block and all other blocks dominated -/// by that block. All edges leaving a region are considered "exits". -struct Region { - /// After structurizing a region, all paths through it must lead to a single - /// "merge" block (i.e. `merge` post-dominates the entire region). - /// The `merge` block must be terminated by one of `OpReturn`, `OpReturnValue`, - /// `OpKill`, `OpIgnoreIntersectionKHR`, `OpTerminateRayKHR` or - /// `OpUnreachable`. If `exits` isn't empty, `merge` will receive an - /// `OpBranch` from its parent region (to an outer merge block). - merge: BlockIdx, - merge_id: BlockId, - - exits: IndexMap, -} - -#[derive(Default)] -struct Exit { - /// Number of total edges to this target (a subset of the target's predecessors). - edge_count: usize, - - /// If this is a deferred exit, `condition` is a boolean value which must - /// be `true` in order to execute this exit. - condition: Option, -} - -struct Structurizer<'a> { - globals: Globals, - - func: FuncBuilder<'a>, - block_id_to_idx: FxHashMap, - - /// `LoopControl` to use in all loops' `OpLoopMerge` instruction. - /// Currently only affected by function-scoped `#[spirv(unroll_loops)]`. - loop_control: LoopControl, - - /// Number of edges pointing to each block. - /// Computed by `post_order` and updated when structuring loops - /// (backedge count is subtracted to hide them from outer regions). - incoming_edge_count: Vec, - - regions: FxHashMap, -} - -impl Structurizer<'_> { - fn structurize_func(&mut self) { - let Globals { - const_false, - const_true, - type_bool, - } = self.globals; - - // By iterating in post-order, we are guaranteed to visit "inner" regions - // before "outer" ones. - for block in self.post_order() { - let block_id = self.func.blocks()[block].label_id().unwrap(); - let terminator = self.func.blocks()[block].instructions.last().unwrap(); - let mut region = match terminator.class.opcode { - Op::Return - | Op::ReturnValue - | Op::Kill - | Op::IgnoreIntersectionKHR - | Op::TerminateRayKHR - | Op::Unreachable => Region { - merge: block, - merge_id: block_id, - exits: indexmap! {}, - }, - - Op::Branch => { - let target = self.block_id_to_idx[&terminator.operands[0].unwrap_id_ref()]; - self.child_region(target).unwrap_or_else(|| { - self.func.builder.select_block(Some(block)).unwrap(); - self.func.builder.pop_instruction().unwrap(); - // Default all merges to `OpUnreachable`, in case they're unused. - self.func.builder.unreachable().unwrap(); - Region { - merge: block, - merge_id: block_id, - exits: indexmap! { - target => Exit { edge_count: 1, condition: None } - }, - } - }) - } - - Op::BranchConditional | Op::Switch => { - let target_operand_indices = match terminator.class.opcode { - Op::BranchConditional => (1..3).step_by(1), - Op::Switch => (1..terminator.operands.len()).step_by(2), - _ => unreachable!(), - }; - - // FIXME(eddyb) avoid wasteful allocation. - let child_regions: Vec<_> = target_operand_indices - .map(|i| { - let target_id = self.func.blocks()[block] - .instructions - .last() - .unwrap() - .operands[i] - .unwrap_id_ref(); - let target = self.block_id_to_idx[&target_id]; - self.child_region(target).unwrap_or_else(|| { - // Synthesize a single-block region for every edge that - // doesn't already enter a child region, so that the - // merge block we later generate has an unique source for - // every single arm of this conditional branch or switch, - // to attach per-exit condition phis to. - let new_block_id = self.func.builder.begin_block(None).unwrap(); - let new_block = self.func.builder.selected_block().unwrap(); - // Default all merges to `OpUnreachable`, in case they're unused. - self.func.builder.unreachable().unwrap(); - self.func.blocks_mut()[block] - .instructions - .last_mut() - .unwrap() - .operands[i] = Operand::IdRef(new_block_id); - Region { - merge: new_block, - merge_id: new_block_id, - exits: indexmap! { - target => Exit { edge_count: 1, condition: None } - }, - } - }) - }) - .collect(); - - self.selection_merge_regions(block, &child_regions) - } - _ => panic!("Invalid block terminator: {:?}", terminator), - }; - - // Peel off deferred exits which have all their edges accounted for - // already, within this region. Repeat until no such exits are left. - while let Some((&target, _)) = region - .exits - .iter() - .find(|&(&target, exit)| exit.edge_count == self.incoming_edge_count[target]) - { - let taken_block_id = self.func.blocks()[target].label_id().unwrap(); - let exit = region.exits.remove(&target).unwrap(); - - // Special-case the last exit as unconditional - regardless of - // what might end up in `exit.condition`, what we'd generate is - // `if exit.condition { branch target; } else { unreachable; }` - // which is just `branch target;` with an extra assumption that - // `exit.condition` is `true` (which we can just ignore). - if region.exits.is_empty() { - self.func.builder.select_block(Some(region.merge)).unwrap(); - assert_eq!( - self.func.builder.pop_instruction().unwrap().class.opcode, - Op::Unreachable - ); - self.func.builder.branch(taken_block_id).unwrap(); - region = self.regions.remove(&target).unwrap(); - continue; - } - - // Create a new block for the "`exit` not taken" path. - let not_taken_block_id = self.func.builder.begin_block(None).unwrap(); - let not_taken_block = self.func.builder.selected_block().unwrap(); - // Default all merges to `OpUnreachable`, in case they're unused. - self.func.builder.unreachable().unwrap(); - - // Choose whether to take this `exit`, in the previous merge block. - let branch_block = region.merge; - self.func.builder.select_block(Some(branch_block)).unwrap(); - assert_eq!( - self.func.builder.pop_instruction().unwrap().class.opcode, - Op::Unreachable - ); - self.func - .builder - .branch_conditional( - exit.condition.unwrap(), - taken_block_id, - not_taken_block_id, - iter::empty(), - ) - .unwrap(); - - // Merge the "taken" and "not taken" paths. - let taken_region = self.regions.remove(&target).unwrap(); - let not_taken_region = Region { - merge: not_taken_block, - merge_id: not_taken_block_id, - exits: region.exits, - }; - region = - self.selection_merge_regions(branch_block, &[taken_region, not_taken_region]); - } - - // Peel off a backedge exit, which indicates this region is a loop. - if let Some(mut backedge_exit) = region.exits.remove(&block) { - // Inject a `while`-like loop header just before the start of the - // loop body. This is needed because our "`break` vs `continue`" - // choice is *after* the loop body, like in a `do`-`while` loop, - // but SPIR-V requires it at the start, like in a `while` loop. - let while_header_block_id = self.func.builder.begin_block(None).unwrap(); - let while_header_block = self.func.builder.selected_block().unwrap(); - self.func.builder.select_block(None).unwrap(); - let while_exit_block_id = self.func.builder.begin_block(None).unwrap(); - let while_exit_block = self.func.builder.selected_block().unwrap(); - // Default all merges to `OpUnreachable`, in case they're unused. - self.func.builder.unreachable().unwrap(); - let while_body_block_id = self.func.builder.begin_block(None).unwrap(); - let while_body_block = self.func.builder.selected_block().unwrap(); - self.func.builder.select_block(None).unwrap(); - - // Move all of the contents of the original `block` into the - // new loop body, but keep labels and indices intact. - // Also update the existing merge if it happens to be the `block` - // we just moved (this should only be relevant to infinite loops). - self.func.blocks_mut()[while_body_block].instructions = - mem::take(&mut self.func.blocks_mut()[block].instructions); - if region.merge == block { - region.merge = while_body_block; - region.merge_id = while_body_block_id; - } - - // Create a separate merge block for the loop body, as the original - // one might be used by an `OpSelectionMerge` and cannot be reused. - let while_body_merge_id = self.func.builder.begin_block(None).unwrap(); - let while_body_merge = self.func.builder.selected_block().unwrap(); - self.func.builder.select_block(None).unwrap(); - self.func.builder.select_block(Some(region.merge)).unwrap(); - assert_eq!( - self.func.builder.pop_instruction().unwrap().class.opcode, - Op::Unreachable - ); - self.func.builder.branch(while_body_merge_id).unwrap(); - - // Point both the original block and the merge of the loop body, - // at the new loop header, and compute phis for all the exit - // conditions (including the backedge, which indicates "continue"). - self.func.builder.select_block(Some(block)).unwrap(); - self.func.builder.branch(while_header_block_id).unwrap(); - self.func - .builder - .select_block(Some(while_body_merge)) - .unwrap(); - self.func.builder.branch(while_header_block_id).unwrap(); - self.func - .builder - .select_block(Some(while_header_block)) - .unwrap(); - - for (&target, exit) in region - .exits - .iter_mut() - .chain(iter::once((&while_body_block, &mut backedge_exit))) - { - let first_entry_case = ( - if target == while_body_block { - const_true - } else { - const_false - }, - block_id, - ); - let repeat_case = (exit.condition.unwrap_or(const_true), while_body_merge_id); - let phi_cases = [first_entry_case, repeat_case]; - exit.condition = Some( - self.func - .builder - .phi(type_bool, None, phi_cases.iter().copied()) - .unwrap(), - ); - } - - // Choose whether to keep looping, in the `while`-like loop header. - self.func - .builder - .select_block(Some(while_header_block)) - .unwrap(); - self.func - .builder - .loop_merge( - while_exit_block_id, - while_body_merge_id, - self.loop_control, - iter::empty(), - ) - .unwrap(); - self.func - .builder - .select_block(Some(while_header_block)) - .unwrap(); - self.func - .builder - .branch_conditional( - backedge_exit.condition.unwrap(), - while_body_block_id, - while_exit_block_id, - iter::empty(), - ) - .unwrap(); - region.merge = while_exit_block; - region.merge_id = while_exit_block_id; - - // Remove the backedge count from the total incoming count of `block`. - // This will allow outer regions to treat the loop opaquely. - self.incoming_edge_count[block] -= backedge_exit.edge_count; - } - - self.regions.insert(block, region); - } - - assert_eq!(self.regions.len(), 1); - assert_eq!(self.regions.values().next().unwrap().exits.len(), 0); - } - - fn child_region(&mut self, target: BlockIdx) -> Option { - // An "entry" edge is the unique edge into a region. - if self.incoming_edge_count[target] == 1 { - Some(self.regions.remove(&target).unwrap()) - } else { - None - } - } - - fn selection_merge_regions(&mut self, block: BlockIdx, child_regions: &[Region]) -> Region { - let Globals { - const_false, - const_true, - type_bool, - } = self.globals; - - // HACK(eddyb) this special-cases the easy case where we can - // just reuse a merge block, and don't have to create our own. - let unconditional_single_exit = |region: &Region| { - region.exits.len() == 1 && region.exits.get_index(0).unwrap().1.condition.is_none() - }; - let structural_merge = if child_regions.iter().all(unconditional_single_exit) { - let merge = *child_regions[0].exits.get_index(0).unwrap().0; - if child_regions - .iter() - .all(|region| *region.exits.get_index(0).unwrap().0 == merge) - && child_regions - .iter() - .map(|region| region.exits.get_index(0).unwrap().1.edge_count) - .sum::() - == self.incoming_edge_count[merge] - { - Some(merge) - } else { - None - } - } else { - None - }; - - // Reuse or create a merge block, and use it as the selection merge. - let merge = structural_merge.unwrap_or_else(|| { - self.func.builder.begin_block(None).unwrap(); - self.func.builder.selected_block().unwrap() - }); - let merge_id = self.func.blocks()[merge].label_id().unwrap(); - self.func.builder.select_block(Some(block)).unwrap(); - self.func - .builder - .insert_selection_merge(InsertPoint::FromEnd(1), merge_id, SelectionControl::NONE) - .unwrap(); - - // Branch all the child regions into our merge block. - for region in child_regions { - // HACK(eddyb) empty `region.exits` indicate diverging control-flow, - // and that we should ignore `region.merge`. - if !region.exits.is_empty() { - self.func.builder.select_block(Some(region.merge)).unwrap(); - assert_eq!( - self.func.builder.pop_instruction().unwrap().class.opcode, - Op::Unreachable - ); - self.func.builder.branch(merge_id).unwrap(); - } - } - - if let Some(merge) = structural_merge { - self.regions.remove(&merge).unwrap() - } else { - self.func.builder.select_block(Some(merge)).unwrap(); - - // Gather all the potential exits. - let mut exits: IndexMap = indexmap! {}; - for region in child_regions { - for (&target, exit) in ®ion.exits { - exits.entry(target).or_default().edge_count += exit.edge_count; - } - } - - // Update conditions using phis. - for (&target, exit) in &mut exits { - let phi_cases = child_regions - .iter() - .filter(|region| { - // HACK(eddyb) empty `region.exits` indicate diverging control-flow, - // and that we should ignore `region.merge`. - !region.exits.is_empty() - }) - .map(|region| { - ( - match region.exits.get(&target) { - Some(exit) => exit.condition.unwrap_or(const_true), - None => const_false, - }, - region.merge_id, - ) - }); - exit.condition = Some(self.func.builder.phi(type_bool, None, phi_cases).unwrap()); - } - - // Default all merges to `OpUnreachable`, in case they're unused. - self.func.builder.unreachable().unwrap(); - - Region { - merge, - merge_id, - exits, - } - } - } - - // FIXME(eddyb) replace this with `rustc_data_structures::graph::iterate` - // (or similar). - fn post_order(&mut self) -> Vec { - let blocks = self.func.blocks(); - - // HACK(eddyb) compute edge counts through the post-order traversal. - assert!(self.incoming_edge_count.is_empty()); - self.incoming_edge_count = vec![0; blocks.len()]; - - // FIXME(eddyb) use a proper bitset. - let mut visited = vec![false; blocks.len()]; - let mut post_order = Vec::with_capacity(blocks.len()); - - self.post_order_step(0, &mut visited, &mut post_order); - - post_order - } - - fn post_order_step( - &mut self, - block: BlockIdx, - visited: &mut [bool], - post_order: &mut Vec, - ) { - self.incoming_edge_count[block] += 1; - - if visited[block] { - return; - } - visited[block] = true; - - for target in - super::simple_passes::outgoing_edges(&self.func.blocks()[block]).collect::>() - { - self.post_order_step(self.block_id_to_idx[&target], visited, post_order) - } - - post_order.push(block); - } -} diff --git a/crates/rustc_codegen_spirv/src/linker/structurizer.rs b/crates/rustc_codegen_spirv/src/linker/structurizer.rs index d2f59bb896..2ddbef7fc4 100644 --- a/crates/rustc_codegen_spirv/src/linker/structurizer.rs +++ b/crates/rustc_codegen_spirv/src/linker/structurizer.rs @@ -1,850 +1,578 @@ -// This pass inserts merge instructions for structured control flow with the assumption the spir-v is reducible. - -use super::simple_passes::outgoing_edges; use crate::decorations::UnrollLoopsDecoration; -use rspirv::spirv::{Op, SelectionControl, Word}; -use rspirv::{ - dr::{Block, Builder, InsertPoint, Module, Operand}, - spirv::LoopControl, -}; +use indexmap::{indexmap, IndexMap}; +use rspirv::dr::{Block, Builder, Function, InsertPoint, Module, Operand}; +use rspirv::spirv::{LoopControl, Op, SelectionControl, Word}; use rustc_data_structures::fx::FxHashMap; -use rustc_session::Session; -use std::collections::VecDeque; +use std::{iter, mem}; -pub struct LoopInfo { - merge_id: Word, - continue_id: Word, - header_id: Word, -} -pub struct ControlFlowInfo { - loops: Vec, - if_merge_ids: Vec, - switch_merge_ids: Vec, +/// Cached IDs of `OpTypeBool`, `OpConstantFalse`, and `OpConstantTrue`. +struct Globals { + type_bool: Word, + const_false: Word, + const_true: Word, } -impl ControlFlowInfo { - fn new() -> Self { - Self { - loops: Vec::new(), - if_merge_ids: Vec::new(), - switch_merge_ids: Vec::new(), - } +// FIXME(eddyb) move this into some common module. Also consider whether we +// actually need a "builder" or could just operate on a `&mut Function`. +struct FuncBuilder<'a> { + builder: &'a mut Builder, +} + +impl FuncBuilder<'_> { + fn function(&self) -> &Function { + let func_idx = self.builder.selected_function().unwrap(); + &self.builder.module_ref().functions[func_idx] } - fn id_is_loops_merge(&self, id: Word) -> bool { - for loop_info in &self.loops { - if loop_info.merge_id == id { - return true; - } - } - - false + fn function_mut(&mut self) -> &mut Function { + let func_idx = self.builder.selected_function().unwrap(); + &mut self.builder.module_mut().functions[func_idx] } - fn id_is_loops_header(&self, id: Word) -> bool { - for loop_info in &self.loops { - if loop_info.header_id == id { - return true; - } - } - - false + fn blocks(&self) -> &[Block] { + &self.function().blocks } - fn id_is_loops_continue(&self, id: Word) -> bool { - for loop_info in &self.loops { - if loop_info.continue_id == id { - return true; - } - } - - false - } - - fn id_is_ifs_merge(&self, id: Word) -> bool { - for merge in &self.if_merge_ids { - if *merge == id { - return true; - } - } - - false - } - - fn set_loops_continue_and_merge(&mut self, header_id: Word, merge_id: Word, continue_id: Word) { - for loop_info in &mut self.loops { - if loop_info.header_id == header_id { - loop_info.merge_id = merge_id; - loop_info.continue_id = continue_id; - return; - } - } - - panic!("tried to set the continue and merge of a header block that does not exist"); - } - - fn used(&self, id: Word) -> bool { - // I don't believe it is nessessary to check if the block is used as a loop header. - self.id_is_loops_merge(id) || self.id_is_loops_continue(id) || self.id_is_ifs_merge(id) - } - - fn retarget(&mut self, old: Word, new: Word) { - for loop_info in &mut self.loops { - if loop_info.header_id == old { - loop_info.header_id = new; - } else if loop_info.merge_id == old { - loop_info.merge_id = new; - } else if loop_info.continue_id == old { - loop_info.continue_id = new; - } - } - for merge_id in &mut self.if_merge_ids { - if *merge_id == old { - *merge_id = new; - } - } - for merge_id in &mut self.switch_merge_ids { - if *merge_id == old { - *merge_id = new; - } - } - } - - fn set_names(&self, builder: &mut Builder) { - for loop_info in &self.loops { - builder.name(loop_info.header_id, "loop_header".to_string()); - builder.name(loop_info.merge_id, "loop_merge".to_string()); - builder.name(loop_info.continue_id, "loop_continue".to_string()); - } - for id in &self.if_merge_ids { - builder.name(*id, "if_merge".to_string()); - } - for id in &self.switch_merge_ids { - builder.name(*id, "switch_merge".to_string()); - } + fn blocks_mut(&mut self) -> &mut [Block] { + &mut self.function_mut().blocks } } pub fn structurize( - sess: &Session, module: Module, unroll_loops_decorations: FxHashMap, ) -> Module { let mut builder = Builder::new_from_module(module); + // Get the `OpTypeBool` type (it will only be created if it's missing). + let type_bool = builder.type_bool(); + + // Find already present `OpConstant{False,True}` (if they're in the module). + let mut existing_const_false = None; + let mut existing_const_true = None; + for inst in &builder.module_ref().types_global_values { + let existing = match inst.class.opcode { + Op::ConstantFalse => &mut existing_const_false, + Op::ConstantTrue => &mut existing_const_true, + _ => continue, + }; + + if existing.is_none() { + *existing = Some(inst.result_id.unwrap()); + } + + if existing_const_false.is_some() && existing_const_true.is_some() { + break; + } + } + + // Create new `OpConstant{False,True}` if they're missing. + let const_false = existing_const_false.unwrap_or_else(|| builder.constant_false(type_bool)); + let const_true = existing_const_true.unwrap_or_else(|| builder.constant_true(type_bool)); + for func_idx in 0..builder.module_ref().functions.len() { - let mut cf_info = ControlFlowInfo::new(); - builder.select_function(Some(func_idx)).unwrap(); + let func = FuncBuilder { + builder: &mut builder, + }; - let func_id = builder.module_ref().functions[func_idx] - .def - .as_ref() - .unwrap() - .result_id - .unwrap(); + let func_id = func.function().def_id().unwrap(); let loop_control = match unroll_loops_decorations.get(&func_id) { Some(UnrollLoopsDecoration {}) => LoopControl::UNROLL, None => LoopControl::NONE, }; - insert_loop_merge_on_conditional_branch(&mut builder, &mut cf_info, loop_control); - retarget_loop_children_if_needed(&mut builder, &cf_info); - insert_selection_merge_on_conditional_branch(sess, &mut builder, &mut cf_info); - defer_loop_internals(&mut builder, &cf_info); - cf_info.set_names(&mut builder); + let block_id_to_idx = func + .blocks() + .iter() + .enumerate() + .map(|(i, block)| (block.label_id().unwrap(), i)) + .collect(); + + Structurizer { + globals: Globals { + type_bool, + const_false, + const_true, + }, + func, + block_id_to_idx, + loop_control, + incoming_edge_count: vec![], + regions: FxHashMap::default(), + } + .structurize_func(); } builder.module() } -fn get_blocks_mut(builder: &mut Builder) -> &mut Vec { - let function = builder.selected_function().unwrap(); - &mut builder.module_mut().functions[function].blocks +// FIXME(eddyb) use newtyped indices and `IndexVec`. +type BlockIdx = usize; +type BlockId = Word; + +/// Regions are made up of their entry block and all other blocks dominated +/// by that block. All edges leaving a region are considered "exits". +struct Region { + /// After structurizing a region, all paths through it must lead to a single + /// "merge" block (i.e. `merge` post-dominates the entire region). + /// The `merge` block must be terminated by one of `OpReturn`, `OpReturnValue`, + /// `OpKill`, `OpIgnoreIntersectionKHR`, `OpTerminateRayKHR` or + /// `OpUnreachable`. If `exits` isn't empty, `merge` will receive an + /// `OpBranch` from its parent region (to an outer merge block). + merge: BlockIdx, + merge_id: BlockId, + + exits: IndexMap, } -fn get_blocks_ref(builder: &Builder) -> &[Block] { - let function = builder.selected_function().unwrap(); - &builder.module_ref().functions[function].blocks +#[derive(Default)] +struct Exit { + /// Number of total edges to this target (a subset of the target's predecessors). + edge_count: usize, + + /// If this is a deferred exit, `condition` is a boolean value which must + /// be `true` in order to execute this exit. + condition: Option, } -fn find_block_index_from_id(builder: &Builder, id: &Word) -> usize { - for (i, block) in get_blocks_ref(builder).iter().enumerate() { - if block.label_id() == Some(*id) { - return i; - } - } +struct Structurizer<'a> { + globals: Globals, - panic!("Failed to find block from id {}", id); -} + func: FuncBuilder<'a>, + block_id_to_idx: FxHashMap, -macro_rules! get_block_mut { - ($builder:expr, $idx:expr) => { - &mut get_blocks_mut($builder)[$idx] - }; -} - -macro_rules! get_block_ref { - ($builder:expr, $idx:expr) => { - &get_blocks_ref($builder)[$idx] - }; -} - -fn idx_to_id(builder: &mut Builder, idx: usize) -> Word { - get_blocks_ref(builder)[idx].label_id().unwrap() -} - -// some times break will yeet themselfs out of a parent loop by skipping the merge block. This prevents that. -fn retarget_loop_children_if_needed(builder: &mut Builder, cf_info: &ControlFlowInfo) { - for loop_info in &cf_info.loops { - let LoopInfo { - header_id: header, - merge_id: merge, - .. - } = loop_info; - - let mut next: VecDeque = VecDeque::new(); - next.push_back(*header); - - while let Some(front) = next.pop_front() { - let block_idx = find_block_index_from_id(builder, &front); - let mut new_edges = - outgoing_edges(get_block_ref!(builder, block_idx)).collect::>(); - - // Make sure we are not looping or going into child loops. - for loop_info in &cf_info.loops { - if new_edges.contains(&loop_info.header_id) { - let index = new_edges - .iter() - .position(|x| *x == loop_info.header_id) - .unwrap(); - new_edges.remove(index); - } - } - - // don't continue after merge - if front == *merge { - new_edges.clear(); - } - - if new_edges.len() == 1 { - // if front branches to a block that is the child of a merge, retarget it. - if block_is_parent_of(builder, *merge, new_edges[0]) { - // retarget front to branch to merge. - let front_block = get_block_mut!(builder, block_idx); - (*front_block - .instructions - .last_mut() - .unwrap() - .operands - .last_mut() - .unwrap()) = Operand::IdRef(*merge); - } - } - - next.extend(new_edges); - } - } -} - -fn incoming_edges(id: Word, builder: &mut Builder) -> Vec { - let mut incoming_edges = Vec::new(); - for block in get_blocks_ref(builder) { - if outgoing_edges(block).any(|x| x == id) { - incoming_edges.push(block.label_id().unwrap()); - } - } - - incoming_edges -} - -fn num_incoming_edges(id: Word, builder: &mut Builder) -> usize { - incoming_edges(id, builder).len() -} - -// Turn a block into a conditional branch that either goes to yes or goes to merge. -fn change_block_to_switch( - builder: &mut Builder, - block_id: Word, - cases: &[Word], - merge: Word, - condition: Word, -) { - let cb_idx = find_block_index_from_id(builder, &block_id); - builder.select_block(Some(cb_idx)).unwrap(); - - let target: Vec<(Operand, Word)> = cases - .iter() - .enumerate() - .map(|(i, id)| (Operand::LiteralInt32(i as u32 + 1), *id)) - .collect(); - - builder.pop_instruction().unwrap(); - builder - .selection_merge(merge, SelectionControl::NONE) - .unwrap(); - builder.switch(condition, merge, target).unwrap(); -} - -// detect the intermediate break block by checking whether a block that branches to a merge block has 2 parents. -fn defer_loop_internals(builder: &mut Builder, cf_info: &ControlFlowInfo) { - for loop_info in &cf_info.loops { - // find all blocks that branch to a merge block. - let mut possible_intermediate_block_idexes = Vec::new(); - for (i, block) in get_blocks_ref(builder).iter().enumerate() { - let mut out = outgoing_edges(block); - if out.next() == Some(loop_info.merge_id) && out.next() == None { - possible_intermediate_block_idexes.push(i) - } - } - // check how many incoming edges the branch has and use that to collect a list of intermediate blocks. - let mut intermediate_block_ids = Vec::new(); - for i in possible_intermediate_block_idexes { - let intermediate_block_id = idx_to_id(builder, i); - let num_incoming_edges = num_incoming_edges(intermediate_block_id, builder); - if num_incoming_edges > 1 { - intermediate_block_ids.push(intermediate_block_id); - } - } - - if !intermediate_block_ids.is_empty() { - // Create a new empty block. - let old_merge_block_id = split_block(builder, loop_info.merge_id, false); - - // Create Phi - let phi_result_id = builder.id(); - let int_type_id = builder.type_int(32, 1); - let const_0_id = builder.constant_u32(int_type_id, 0); - - let mut phi_operands = vec![]; - for (intermediate_i, intermediate_block_id) in intermediate_block_ids.iter().enumerate() - { - let intermediate_i = intermediate_i as u32 + 1; - let const_x_id = builder.constant_u32(int_type_id, intermediate_i); - let t = incoming_edges(*intermediate_block_id, builder); - for blocks_that_go_to_intermediate in t { - phi_operands.push((const_x_id, blocks_that_go_to_intermediate)); - } - builder.name(*intermediate_block_id, "deferred".to_string()); - } - phi_operands.push((const_0_id, loop_info.header_id)); - - builder - .select_block(Some(find_block_index_from_id(builder, &loop_info.merge_id))) - .unwrap(); - builder - .insert_phi( - InsertPoint::Begin, - int_type_id, - Some(phi_result_id), - phi_operands, - ) - .unwrap(); - - // point all intermediate blocks to the new empty merge block. - for intermediate_block_id in intermediate_block_ids.iter() { - for incoming_id in incoming_edges(*intermediate_block_id, builder) { - let incoming_idx = find_block_index_from_id(builder, &incoming_id); - let incoming_block = get_block_mut!(builder, incoming_idx); - - for operand in &mut incoming_block.instructions.last_mut().unwrap().operands { - if *operand == Operand::IdRef(*intermediate_block_id) { - *operand = Operand::IdRef(loop_info.merge_id); // loop_info.merge_id is the same block as the new empty block from the last step. - } - } - } - } - - // Create a switch statement of all intermediate blocks. - change_block_to_switch( - builder, - loop_info.merge_id, - &intermediate_block_ids, - old_merge_block_id, - phi_result_id, - ); - - // point intermediate blocks to the old merge block. - for intermediate_block_id in intermediate_block_ids.iter() { - let intermediate_block_idx = - find_block_index_from_id(builder, intermediate_block_id); - for operand in &mut get_block_mut!(builder, intermediate_block_idx) - .instructions - .last_mut() - .unwrap() - .operands - { - if *operand == Operand::IdRef(loop_info.merge_id) { - *operand = Operand::IdRef(old_merge_block_id); - } - } - } - } - } -} - -// "Combines" all continue blocks into 1 and returns the ID of the continue block. -fn eliminate_multiple_continue_blocks(builder: &mut Builder, header: Word) -> Word { - // Find all possible continue blocks. - let mut continue_blocks = Vec::new(); - for block in get_blocks_ref(builder) { - let block_id = block.label_id().unwrap(); - if ends_in_branch(block) { - let edge = outgoing_edges(block).next().unwrap(); - if edge == header && block_is_parent_of(builder, header, block_id) { - continue_blocks.push(block_id); - } - } - } - // if there are multiple continue blocks we need to retarget towards a single continue. - if continue_blocks.len() > 1 { - let continue_block_id = continue_blocks.last().unwrap(); - for block_id in continue_blocks.iter().take(continue_blocks.len() - 1) { - let idx = find_block_index_from_id(builder, block_id); - let block = get_block_mut!(builder, idx); - for op in &mut block.instructions.last_mut().unwrap().operands { - if *op == Operand::IdRef(header) { - *op = Operand::IdRef(*continue_block_id); - } - } - } - - *continue_block_id - } else { - *continue_blocks.last().unwrap() - } -} - -fn block_leads_into_break(builder: &Builder, cf_info: &ControlFlowInfo, start: Word) -> bool { - let mut next: VecDeque = VecDeque::new(); - next.push_back(start); - - while let Some(front) = next.pop_front() { - let block_idx = find_block_index_from_id(builder, &front); - let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::>(); - - // Make sure we are not looping. - for loop_info in &cf_info.loops { - if new_edges.contains(&loop_info.header_id) { - let index = new_edges - .iter() - .position(|x| *x == loop_info.header_id) - .unwrap(); - new_edges.remove(index); - } - } - - // Skip inner branches. TODO: is this correct? - if ends_in_branch_conditional(get_block_ref!(builder, block_idx)) { - new_edges.clear(); - } - - // if front is a merge block return true - for loop_info in &cf_info.loops { - if front == loop_info.merge_id - && block_is_parent_of(builder, loop_info.header_id, start) - { - return true; - } - } - - next.extend(new_edges); - } - - false -} - -fn block_leads_into_continue(builder: &Builder, cf_info: &ControlFlowInfo, start: Word) -> bool { - let start_idx = find_block_index_from_id(builder, &start); - let new_edges = outgoing_edges(get_block_ref!(builder, start_idx)).collect::>(); - for loop_info in &cf_info.loops { - if new_edges.len() == 1 && loop_info.continue_id == new_edges[0] { - return true; - } - } - - false -} - -// every branch from a reaches b. -fn block_is_reverse_idom_of( - builder: &Builder, - cf_info: &ControlFlowInfo, - a: Word, - b: Word, -) -> bool { - let mut next: VecDeque = VecDeque::new(); - next.push_back(a); - - let mut processed = vec![a]; // ensures we are not looping. - - while let Some(front) = next.pop_front() { - let block_idx = find_block_index_from_id(builder, &front); - - if front == b { - continue; - } - - let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::>(); - - // Skip loop bodies by jumping to the merge block is we hit a header block. - for loop_info in &cf_info.loops { - if front == loop_info.header_id { - // TODO: should only do this for children i guess. - new_edges = vec![loop_info.merge_id]; - } - } - - for loop_info in &cf_info.loops { - // Make sure we are not looping. - if block_is_parent_of(builder, loop_info.header_id, a) - && new_edges.contains(&loop_info.header_id) - { - let index = new_edges - .iter() - .position(|x| *x == loop_info.header_id) - .unwrap(); - new_edges.remove(index); - } - - // Make sure we are not continuing after a merge. - if block_is_parent_of(builder, loop_info.header_id, a) && front == loop_info.merge_id { - new_edges.clear(); - } - } - - if new_edges.is_empty() { - return false; - } - - for id in &processed { - if let Some(i) = new_edges.iter().position(|x| x == id) { - new_edges.remove(i); - } - } - - processed.push(front); - next.extend(new_edges); - } - - true -} -fn get_possible_merge_positions( - builder: &Builder, - cf_info: &ControlFlowInfo, - start: Word, -) -> Vec { - let mut retval = Vec::new(); - for (idx, block) in get_blocks_ref(builder).iter().enumerate() { - if block_is_reverse_idom_of(builder, cf_info, start, block.label_id().unwrap()) { - retval.push(idx); - } - } - - retval -} - -fn block_is_parent_of(builder: &Builder, parent: Word, child: Word) -> bool { - let mut next: VecDeque = VecDeque::new(); - next.push_back(parent); - - let mut processed = vec![parent]; // ensures we are not looping. - - while let Some(front) = next.pop_front() { - let block_idx = find_block_index_from_id(builder, &front); - let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::>(); - - for id in &processed { - if let Some(i) = new_edges.iter().position(|x| x == id) { - new_edges.remove(i); - } - } - - if new_edges.contains(&child) { - return true; - } - - processed.push(front); - next.extend(new_edges); - } - - false -} - -// Returns the idx of the branch that loops. -fn get_looping_branch_from_block( - builder: &Builder, - cf_info: &ControlFlowInfo, - start: Word, -) -> Option { - let mut next: VecDeque = VecDeque::new(); - next.push_back(start); - - let mut processed = Vec::new(); - - while let Some(front) = next.pop_front() { - // make sure we separate inner from outer loops. - if front != start && cf_info.id_is_loops_header(front) { - continue; - } - - let block_idx = find_block_index_from_id(builder, &front); - let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::>(); - - let edge_it = new_edges.iter().find(|&x| x == &start); // Check if the new_edges contain the start - if new_edges.len() == 1 { - if let Some(edge_it) = edge_it { - // loop over the orginal edges to find which branch is looping - let start_idx = find_block_index_from_id(builder, &front); - let start_edges = outgoing_edges(get_block_ref!(builder, start_idx)); - - for (i, start_edge) in start_edges.enumerate() { - if start_edge == *edge_it || block_is_parent_of(builder, start_edge, *edge_it) { - return Some(i); - } - } - } - } - - for id in &processed { - if let Some(i) = new_edges.iter().position(|x| x == id) { - new_edges.remove(i); - } - } - processed.push(front); - - next.extend(new_edges); - } - - None -} - -fn ends_in_branch_conditional(block: &Block) -> bool { - let last_inst = block.instructions.last().unwrap(); - last_inst.class.opcode == Op::BranchConditional -} - -fn ends_in_branch(block: &Block) -> bool { - let last_inst = block.instructions.last().unwrap(); - last_inst.class.opcode == Op::Branch -} - -fn ends_in_return(block: &Block) -> bool { - let last_inst = block.instructions.last().unwrap(); - last_inst.class.opcode == Op::Return || last_inst.class.opcode == Op::ReturnValue -} - -// Returns the new id assigned to the original block. -fn split_block(builder: &mut Builder, block_to_split: Word, retarget: bool) -> Word { - // assign old block new id. - let new_original_block_id = builder.id(); - let block_to_split_index = find_block_index_from_id(builder, &block_to_split); - let orignial_block = get_block_mut!(builder, block_to_split_index); - orignial_block.label.as_mut().unwrap().result_id = Some(new_original_block_id); - // create new block with old id. - builder.begin_block(Some(block_to_split)).unwrap(); - // new block branches to old block. - builder.branch(new_original_block_id).unwrap(); - if retarget { - // update all merge ops to point the the old block with its new id. - for block in get_blocks_mut(builder) { - for inst in &mut block.instructions { - if inst.class.opcode == Op::LoopMerge || inst.class.opcode == Op::SelectionMerge { - for operand in &mut inst.operands { - if *operand == Operand::IdRef(block_to_split) { - *operand = Operand::IdRef(new_original_block_id); - } - } - } - } - } - } - - new_original_block_id -} - -fn make_unreachable_block(builder: &mut Builder) -> Word { - let id = builder.id(); - builder.begin_block(Some(id)).unwrap(); - builder.unreachable().unwrap(); - id -} - -pub fn insert_selection_merge_on_conditional_branch( - sess: &Session, - builder: &mut Builder, - cf_info: &mut ControlFlowInfo, -) { - let mut branch_conditional_ops = Vec::new(); - - // Find conditional branches that are not loops - for block in get_blocks_ref(builder) { - if ends_in_branch_conditional(block) - && !cf_info.id_is_loops_header(block.label_id().unwrap()) - { - branch_conditional_ops.push(block.label_id().unwrap()); - } - } - - let mut modified_ids = FxHashMap::default(); - - // Find convergence point. - for id in branch_conditional_ops.iter() { - let id = match modified_ids.get_key_value(id) { - Some((_, value)) => value, - None => id, - }; - - let bi = find_block_index_from_id(builder, id); - let out = outgoing_edges(&get_blocks_ref(builder)[bi]).collect::>(); - let id = idx_to_id(builder, bi); - let a_nexts = get_possible_merge_positions(builder, cf_info, out[0]); - let b_nexts = get_possible_merge_positions(builder, cf_info, out[1]); - - // Check for a matching possible merge position. - let mut first_merge = None; - 'outer: for a in &a_nexts { - for b in &b_nexts { - if *a == *b { - first_merge = Some(*a); - break 'outer; - } - } - } - - let merge_block_id = if let Some(idx) = first_merge { - // We found a existing block that we can use as a merge block! - idx_to_id(builder, idx) - } else { - let a_first_id = out[0]; - let b_first_id = out[1]; - let a_last_idx = match a_nexts.last() { - Some(last) => *last, - None => find_block_index_from_id(builder, &out[0]), - }; - let b_last_idx = match b_nexts.last() { - Some(last) => *last, - None => find_block_index_from_id(builder, &out[1]), - }; - - let branch_a_breaks = block_leads_into_break(builder, cf_info, a_first_id); - let branch_b_breaks = block_leads_into_break(builder, cf_info, b_first_id); - let branch_a_continues = block_leads_into_continue(builder, cf_info, a_first_id); - let branch_b_continues = block_leads_into_continue(builder, cf_info, b_first_id); - let branch_a_returns = ends_in_return(get_block_ref!(builder, a_last_idx)); - let branch_b_returns = ends_in_return(get_block_ref!(builder, b_last_idx)); - - if ((branch_a_breaks || branch_a_continues) && (branch_b_breaks || branch_b_continues)) - || branch_a_returns && branch_b_returns - { - // (fully unreachable) insert a rando block and mark as merge. - make_unreachable_block(builder) - } else if branch_a_breaks || branch_a_continues || branch_a_returns { - // (partially unreachable) merge block becomes branch b immediatly - b_first_id - } else if branch_b_breaks || branch_b_continues || branch_b_returns { - // (partially unreachable) merge block becomes branch a immediatly - a_first_id - } else { - // In theory this should never happen. - sess.fatal("UNEXPECTED, Unknown exit detected."); - } - }; - - if cf_info.used(merge_block_id) { - let new_id = split_block(builder, merge_block_id, true); - cf_info.retarget(merge_block_id, new_id); - - if branch_conditional_ops.contains(&merge_block_id) { - modified_ids.insert(merge_block_id, new_id); - } - } - - cf_info.if_merge_ids.push(merge_block_id); - - // Insert the merge instruction - let bi = find_block_index_from_id(builder, &id); // after this we don't insert or remove blocks - builder.select_block(Some(bi)).unwrap(); - builder - .insert_selection_merge( - InsertPoint::FromEnd(1), - merge_block_id, - SelectionControl::NONE, - ) - .unwrap(); - } -} - -pub fn insert_loop_merge_on_conditional_branch( - builder: &mut Builder, - cf_info: &mut ControlFlowInfo, + /// `LoopControl` to use in all loops' `OpLoopMerge` instruction. + /// Currently only affected by function-scoped `#[spirv(unroll_loops)]`. loop_control: LoopControl, -) { - let mut branch_conditional_ops = Vec::new(); - // Find conditional branches that are loops, and find which branch is the one that loops. - for block in get_blocks_ref(builder) { - if ends_in_branch_conditional(block) { - let block_id = block.label_id().unwrap(); - if let Some(looping_branch_idx) = - get_looping_branch_from_block(builder, cf_info, block_id) + /// Number of edges pointing to each block. + /// Computed by `post_order` and updated when structuring loops + /// (backedge count is subtracted to hide them from outer regions). + incoming_edge_count: Vec, + + regions: FxHashMap, +} + +impl Structurizer<'_> { + fn structurize_func(&mut self) { + let Globals { + const_false, + const_true, + type_bool, + } = self.globals; + + // By iterating in post-order, we are guaranteed to visit "inner" regions + // before "outer" ones. + for block in self.post_order() { + let block_id = self.func.blocks()[block].label_id().unwrap(); + let terminator = self.func.blocks()[block].instructions.last().unwrap(); + let mut region = match terminator.class.opcode { + Op::Return + | Op::ReturnValue + | Op::Kill + | Op::IgnoreIntersectionKHR + | Op::TerminateRayKHR + | Op::Unreachable => Region { + merge: block, + merge_id: block_id, + exits: indexmap! {}, + }, + + Op::Branch => { + let target = self.block_id_to_idx[&terminator.operands[0].unwrap_id_ref()]; + self.child_region(target).unwrap_or_else(|| { + self.func.builder.select_block(Some(block)).unwrap(); + self.func.builder.pop_instruction().unwrap(); + // Default all merges to `OpUnreachable`, in case they're unused. + self.func.builder.unreachable().unwrap(); + Region { + merge: block, + merge_id: block_id, + exits: indexmap! { + target => Exit { edge_count: 1, condition: None } + }, + } + }) + } + + Op::BranchConditional | Op::Switch => { + let target_operand_indices = match terminator.class.opcode { + Op::BranchConditional => (1..3).step_by(1), + Op::Switch => (1..terminator.operands.len()).step_by(2), + _ => unreachable!(), + }; + + // FIXME(eddyb) avoid wasteful allocation. + let child_regions: Vec<_> = target_operand_indices + .map(|i| { + let target_id = self.func.blocks()[block] + .instructions + .last() + .unwrap() + .operands[i] + .unwrap_id_ref(); + let target = self.block_id_to_idx[&target_id]; + self.child_region(target).unwrap_or_else(|| { + // Synthesize a single-block region for every edge that + // doesn't already enter a child region, so that the + // merge block we later generate has an unique source for + // every single arm of this conditional branch or switch, + // to attach per-exit condition phis to. + let new_block_id = self.func.builder.begin_block(None).unwrap(); + let new_block = self.func.builder.selected_block().unwrap(); + // Default all merges to `OpUnreachable`, in case they're unused. + self.func.builder.unreachable().unwrap(); + self.func.blocks_mut()[block] + .instructions + .last_mut() + .unwrap() + .operands[i] = Operand::IdRef(new_block_id); + Region { + merge: new_block, + merge_id: new_block_id, + exits: indexmap! { + target => Exit { edge_count: 1, condition: None } + }, + } + }) + }) + .collect(); + + self.selection_merge_regions(block, &child_regions) + } + _ => panic!("Invalid block terminator: {:?}", terminator), + }; + + // Peel off deferred exits which have all their edges accounted for + // already, within this region. Repeat until no such exits are left. + while let Some((&target, _)) = region + .exits + .iter() + .find(|&(&target, exit)| exit.edge_count == self.incoming_edge_count[target]) { - branch_conditional_ops.push((block_id, looping_branch_idx)); - cf_info.loops.push(LoopInfo { - header_id: block_id, - merge_id: 0, - continue_id: 0, - }) + let taken_block_id = self.func.blocks()[target].label_id().unwrap(); + let exit = region.exits.remove(&target).unwrap(); + + // Special-case the last exit as unconditional - regardless of + // what might end up in `exit.condition`, what we'd generate is + // `if exit.condition { branch target; } else { unreachable; }` + // which is just `branch target;` with an extra assumption that + // `exit.condition` is `true` (which we can just ignore). + if region.exits.is_empty() { + self.func.builder.select_block(Some(region.merge)).unwrap(); + assert_eq!( + self.func.builder.pop_instruction().unwrap().class.opcode, + Op::Unreachable + ); + self.func.builder.branch(taken_block_id).unwrap(); + region = self.regions.remove(&target).unwrap(); + continue; + } + + // Create a new block for the "`exit` not taken" path. + let not_taken_block_id = self.func.builder.begin_block(None).unwrap(); + let not_taken_block = self.func.builder.selected_block().unwrap(); + // Default all merges to `OpUnreachable`, in case they're unused. + self.func.builder.unreachable().unwrap(); + + // Choose whether to take this `exit`, in the previous merge block. + let branch_block = region.merge; + self.func.builder.select_block(Some(branch_block)).unwrap(); + assert_eq!( + self.func.builder.pop_instruction().unwrap().class.opcode, + Op::Unreachable + ); + self.func + .builder + .branch_conditional( + exit.condition.unwrap(), + taken_block_id, + not_taken_block_id, + iter::empty(), + ) + .unwrap(); + + // Merge the "taken" and "not taken" paths. + let taken_region = self.regions.remove(&target).unwrap(); + let not_taken_region = Region { + merge: not_taken_block, + merge_id: not_taken_block_id, + exits: region.exits, + }; + region = + self.selection_merge_regions(branch_block, &[taken_region, not_taken_region]); } + + // Peel off a backedge exit, which indicates this region is a loop. + if let Some(mut backedge_exit) = region.exits.remove(&block) { + // Inject a `while`-like loop header just before the start of the + // loop body. This is needed because our "`break` vs `continue`" + // choice is *after* the loop body, like in a `do`-`while` loop, + // but SPIR-V requires it at the start, like in a `while` loop. + let while_header_block_id = self.func.builder.begin_block(None).unwrap(); + let while_header_block = self.func.builder.selected_block().unwrap(); + self.func.builder.select_block(None).unwrap(); + let while_exit_block_id = self.func.builder.begin_block(None).unwrap(); + let while_exit_block = self.func.builder.selected_block().unwrap(); + // Default all merges to `OpUnreachable`, in case they're unused. + self.func.builder.unreachable().unwrap(); + let while_body_block_id = self.func.builder.begin_block(None).unwrap(); + let while_body_block = self.func.builder.selected_block().unwrap(); + self.func.builder.select_block(None).unwrap(); + + // Move all of the contents of the original `block` into the + // new loop body, but keep labels and indices intact. + // Also update the existing merge if it happens to be the `block` + // we just moved (this should only be relevant to infinite loops). + self.func.blocks_mut()[while_body_block].instructions = + mem::take(&mut self.func.blocks_mut()[block].instructions); + if region.merge == block { + region.merge = while_body_block; + region.merge_id = while_body_block_id; + } + + // Create a separate merge block for the loop body, as the original + // one might be used by an `OpSelectionMerge` and cannot be reused. + let while_body_merge_id = self.func.builder.begin_block(None).unwrap(); + let while_body_merge = self.func.builder.selected_block().unwrap(); + self.func.builder.select_block(None).unwrap(); + self.func.builder.select_block(Some(region.merge)).unwrap(); + assert_eq!( + self.func.builder.pop_instruction().unwrap().class.opcode, + Op::Unreachable + ); + self.func.builder.branch(while_body_merge_id).unwrap(); + + // Point both the original block and the merge of the loop body, + // at the new loop header, and compute phis for all the exit + // conditions (including the backedge, which indicates "continue"). + self.func.builder.select_block(Some(block)).unwrap(); + self.func.builder.branch(while_header_block_id).unwrap(); + self.func + .builder + .select_block(Some(while_body_merge)) + .unwrap(); + self.func.builder.branch(while_header_block_id).unwrap(); + self.func + .builder + .select_block(Some(while_header_block)) + .unwrap(); + + for (&target, exit) in region + .exits + .iter_mut() + .chain(iter::once((&while_body_block, &mut backedge_exit))) + { + let first_entry_case = ( + if target == while_body_block { + const_true + } else { + const_false + }, + block_id, + ); + let repeat_case = (exit.condition.unwrap_or(const_true), while_body_merge_id); + let phi_cases = [first_entry_case, repeat_case]; + exit.condition = Some( + self.func + .builder + .phi(type_bool, None, phi_cases.iter().copied()) + .unwrap(), + ); + } + + // Choose whether to keep looping, in the `while`-like loop header. + self.func + .builder + .select_block(Some(while_header_block)) + .unwrap(); + self.func + .builder + .loop_merge( + while_exit_block_id, + while_body_merge_id, + self.loop_control, + iter::empty(), + ) + .unwrap(); + self.func + .builder + .select_block(Some(while_header_block)) + .unwrap(); + self.func + .builder + .branch_conditional( + backedge_exit.condition.unwrap(), + while_body_block_id, + while_exit_block_id, + iter::empty(), + ) + .unwrap(); + region.merge = while_exit_block; + region.merge_id = while_exit_block_id; + + // Remove the backedge count from the total incoming count of `block`. + // This will allow outer regions to treat the loop opaquely. + self.incoming_edge_count[block] -= backedge_exit.edge_count; + } + + self.regions.insert(block, region); + } + + assert_eq!(self.regions.len(), 1); + assert_eq!(self.regions.values().next().unwrap().exits.len(), 0); + } + + fn child_region(&mut self, target: BlockIdx) -> Option { + // An "entry" edge is the unique edge into a region. + if self.incoming_edge_count[target] == 1 { + Some(self.regions.remove(&target).unwrap()) + } else { + None } } - let mut modified_ids = FxHashMap::default(); - // Figure out which branch loops and which branch should merge, also find any potential break ops. - for (id, looping_branch_idx) in branch_conditional_ops.iter() { - let id = match modified_ids.get_key_value(id) { - Some((_, value)) => *value, - None => *id, + fn selection_merge_regions(&mut self, block: BlockIdx, child_regions: &[Region]) -> Region { + let Globals { + const_false, + const_true, + type_bool, + } = self.globals; + + // HACK(eddyb) this special-cases the easy case where we can + // just reuse a merge block, and don't have to create our own. + let unconditional_single_exit = |region: &Region| { + region.exits.len() == 1 && region.exits.get_index(0).unwrap().1.condition.is_none() + }; + let structural_merge = if child_regions.iter().all(unconditional_single_exit) { + let merge = *child_regions[0].exits.get_index(0).unwrap().0; + if child_regions + .iter() + .all(|region| *region.exits.get_index(0).unwrap().0 == merge) + && child_regions + .iter() + .map(|region| region.exits.get_index(0).unwrap().1.edge_count) + .sum::() + == self.incoming_edge_count[merge] + { + Some(merge) + } else { + None + } + } else { + None }; - let merge_branch_idx = (looping_branch_idx + 1) % 2; - let bi = find_block_index_from_id(builder, &id); - let out = outgoing_edges(&get_blocks_ref(builder)[bi]).collect::>(); - - let continue_block_id = eliminate_multiple_continue_blocks(builder, id); - let merge_block_id = out[merge_branch_idx]; - - if cf_info.used(continue_block_id) { - let new_id = split_block(builder, continue_block_id, true); - cf_info.retarget(continue_block_id, new_id); - - if branch_conditional_ops.contains(&(continue_block_id, *looping_branch_idx)) { - modified_ids.insert(continue_block_id, new_id); - } - } - if cf_info.used(merge_block_id) { - let new_id = split_block(builder, merge_block_id, true); - cf_info.retarget(merge_block_id, new_id); - - if branch_conditional_ops.contains(&(merge_block_id, *looping_branch_idx)) { - modified_ids.insert(merge_block_id, new_id); - } - } - - cf_info.set_loops_continue_and_merge(id, merge_block_id, continue_block_id); - - // Insert the merge instruction - let bi = find_block_index_from_id(builder, &id); // after this we don't insert or remove blocks - builder.select_block(Some(bi)).unwrap(); - builder - .insert_loop_merge( - InsertPoint::FromEnd(1), - merge_block_id, - continue_block_id, - loop_control, - None, - ) + // Reuse or create a merge block, and use it as the selection merge. + let merge = structural_merge.unwrap_or_else(|| { + self.func.builder.begin_block(None).unwrap(); + self.func.builder.selected_block().unwrap() + }); + let merge_id = self.func.blocks()[merge].label_id().unwrap(); + self.func.builder.select_block(Some(block)).unwrap(); + self.func + .builder + .insert_selection_merge(InsertPoint::FromEnd(1), merge_id, SelectionControl::NONE) .unwrap(); + + // Branch all the child regions into our merge block. + for region in child_regions { + // HACK(eddyb) empty `region.exits` indicate diverging control-flow, + // and that we should ignore `region.merge`. + if !region.exits.is_empty() { + self.func.builder.select_block(Some(region.merge)).unwrap(); + assert_eq!( + self.func.builder.pop_instruction().unwrap().class.opcode, + Op::Unreachable + ); + self.func.builder.branch(merge_id).unwrap(); + } + } + + if let Some(merge) = structural_merge { + self.regions.remove(&merge).unwrap() + } else { + self.func.builder.select_block(Some(merge)).unwrap(); + + // Gather all the potential exits. + let mut exits: IndexMap = indexmap! {}; + for region in child_regions { + for (&target, exit) in ®ion.exits { + exits.entry(target).or_default().edge_count += exit.edge_count; + } + } + + // Update conditions using phis. + for (&target, exit) in &mut exits { + let phi_cases = child_regions + .iter() + .filter(|region| { + // HACK(eddyb) empty `region.exits` indicate diverging control-flow, + // and that we should ignore `region.merge`. + !region.exits.is_empty() + }) + .map(|region| { + ( + match region.exits.get(&target) { + Some(exit) => exit.condition.unwrap_or(const_true), + None => const_false, + }, + region.merge_id, + ) + }); + exit.condition = Some(self.func.builder.phi(type_bool, None, phi_cases).unwrap()); + } + + // Default all merges to `OpUnreachable`, in case they're unused. + self.func.builder.unreachable().unwrap(); + + Region { + merge, + merge_id, + exits, + } + } + } + + // FIXME(eddyb) replace this with `rustc_data_structures::graph::iterate` + // (or similar). + fn post_order(&mut self) -> Vec { + let blocks = self.func.blocks(); + + // HACK(eddyb) compute edge counts through the post-order traversal. + assert!(self.incoming_edge_count.is_empty()); + self.incoming_edge_count = vec![0; blocks.len()]; + + // FIXME(eddyb) use a proper bitset. + let mut visited = vec![false; blocks.len()]; + let mut post_order = Vec::with_capacity(blocks.len()); + + self.post_order_step(0, &mut visited, &mut post_order); + + post_order + } + + fn post_order_step( + &mut self, + block: BlockIdx, + visited: &mut [bool], + post_order: &mut Vec, + ) { + self.incoming_edge_count[block] += 1; + + if visited[block] { + return; + } + visited[block] = true; + + for target in + super::simple_passes::outgoing_edges(&self.func.blocks()[block]).collect::>() + { + self.post_order_step(self.block_id_to_idx[&target], visited, post_order) + } + + post_order.push(block); } } diff --git a/crates/rustc_codegen_spirv/src/linker/test.rs b/crates/rustc_codegen_spirv/src/linker/test.rs index c5e8580586..4573d94b1f 100644 --- a/crates/rustc_codegen_spirv/src/linker/test.rs +++ b/crates/rustc_codegen_spirv/src/linker/test.rs @@ -94,7 +94,6 @@ fn assemble_and_link(binaries: &[&[u8]]) -> Result { inline: false, mem2reg: false, structurize: false, - use_new_structurizer: false, emit_multiple_modules: false, }, );