diff --git a/crates/rustc_codegen_spirv/src/linker/structurizer.rs b/crates/rustc_codegen_spirv/src/linker/structurizer.rs index ec4b0040b8..d3ece57adb 100644 --- a/crates/rustc_codegen_spirv/src/linker/structurizer.rs +++ b/crates/rustc_codegen_spirv/src/linker/structurizer.rs @@ -132,6 +132,14 @@ impl ControlFlowInfo { } } +fn emit_compiler_error(sess: Option<&Session>, msg: &'static str) -> ! { + if let Some(sess) = sess { + sess.fatal(msg); + } else { + panic!(msg); + } +} + pub fn structurize(sess: Option<&Session>, module: &mut Module) { let mut debug_names = Vec::new(); @@ -228,7 +236,39 @@ fn retarget_loop_children_if_needed(blocks: &mut [Block], cf_info: &ControlFlowI } } -fn block_leads_into_break(blocks: &mut [Block], cf_info: &ControlFlowInfo, start: Word) -> bool { +// "Combines" all continue blocks into 1 and returns the ID of the continue block. +fn eliminate_multiple_continue_blocks(blocks: &mut Vec, header: Word) -> Word { + // Find all possible continue blocks. + let mut continue_blocks = Vec::new(); + for block in blocks.iter() { + let block_id = block.label_id().unwrap(); + if ends_in_branch(block) { + let edge = outgoing_edges(block)[0]; + if edge == header && block_is_parent_of(header, block_id, blocks) { + continue_blocks.push(block_id); + } + } + } + // if there are multiple continue blocks we need to retarget towards a single continue. + if continue_blocks.len() > 1 { + let continue_block_id = continue_blocks.last().unwrap(); + for block_id in continue_blocks.iter().take(continue_blocks.len() - 1) { + let idx = find_block_index_from_id(blocks, block_id); + let block = &mut blocks[idx]; + for op in &mut block.instructions.last_mut().unwrap().operands { + if *op == Operand::IdRef(header) { + *op = Operand::IdRef(*continue_block_id); + } + } + } + + *continue_block_id + } else { + *continue_blocks.last().unwrap() + } +} + +fn block_leads_into_break(blocks: &[Block], cf_info: &ControlFlowInfo, start: Word) -> bool { let mut next: VecDeque = VecDeque::new(); next.push_back(start); @@ -266,6 +306,18 @@ fn block_leads_into_break(blocks: &mut [Block], cf_info: &ControlFlowInfo, start false } +fn block_leads_into_continue(blocks: &[Block], cf_info: &ControlFlowInfo, start: Word) -> bool { + let start_idx = find_block_index_from_id(blocks, &start); + let new_edges = outgoing_edges(&blocks[start_idx]); + for loop_info in &cf_info.loops { + if new_edges.len() == 1 && loop_info.continue_id == new_edges[0] { + return true; + } + } + + false +} + fn get_possible_merge_positions( blocks: &[Block], cf_info: &ControlFlowInfo, @@ -295,8 +347,9 @@ fn get_possible_merge_positions( .unwrap(); new_edges.remove(index); } + // Make sure we are not continuing after a merge. - if block_is_parent_of(loop_info.merge_id, start, blocks) && front == loop_info.merge_id + if block_is_parent_of(loop_info.header_id, start, blocks) && front == loop_info.merge_id { new_edges.clear(); } @@ -318,21 +371,22 @@ fn block_is_parent_of(parent: Word, child: Word, blocks: &[Block]) -> bool { next.push_back(parent); let mut processed = Vec::new(); + processed.push(parent); // ensures we are not looping. while let Some(front) = next.pop_front() { let block_idx = find_block_index_from_id(blocks, &front); let mut new_edges = outgoing_edges(&blocks[block_idx]); - if new_edges.contains(&child) { - return true; - } - for id in &processed { if let Some(i) = new_edges.iter().position(|x| x == id) { new_edges.remove(i); } } + if new_edges.contains(&child) { + return true; + } + processed.push(front); next.extend(new_edges); } @@ -340,12 +394,12 @@ fn block_is_parent_of(parent: Word, child: Word, blocks: &[Block]) -> bool { false } -// Returns the idx of the branch that loops and the idx to the block that branches to the original block. +// Returns the idx of the branch that loops. fn get_looping_branch_from_block( blocks: &[Block], cf_info: &ControlFlowInfo, start: Word, -) -> Option<(usize, usize)> { +) -> Option { let mut next: VecDeque = VecDeque::new(); next.push_back(start); @@ -361,13 +415,15 @@ fn get_looping_branch_from_block( let mut new_edges = outgoing_edges(&blocks[block_idx]); let edge_it = new_edges.iter().find(|&x| x == &start); // Check if the new_edges contain the start - if let Some(edge_it) = edge_it { - // loop over the orginal edges to find which branch is looping - let start_edges = outgoing_edges(&blocks[find_block_index_from_id(blocks, &start)]); + if new_edges.len() == 1 { + if let Some(edge_it) = edge_it { + // loop over the orginal edges to find which branch is looping + let start_edges = outgoing_edges(&blocks[find_block_index_from_id(blocks, &start)]); - for (i, start_edge) in start_edges.iter().enumerate() { - if start_edge == edge_it || block_is_parent_of(*start_edge, *edge_it, blocks) { - return Some((i, block_idx)); + for (i, start_edge) in start_edges.iter().enumerate() { + if start_edge == edge_it || block_is_parent_of(*start_edge, *edge_it, blocks) { + return Some(i); + } } } } @@ -390,6 +446,11 @@ fn ends_in_branch_conditional(block: &Block) -> bool { last_inst.class.opcode == Op::BranchConditional } +fn ends_in_branch(block: &Block) -> bool { + let last_inst = block.instructions.last().unwrap(); + last_inst.class.opcode == Op::Branch +} + fn ends_in_return(block: &Block) -> bool { let last_inst = block.instructions.last().unwrap(); last_inst.class.opcode == Op::Return || last_inst.class.opcode == Op::ReturnValue @@ -432,6 +493,20 @@ fn split_block(header: &mut ModuleHeader, blocks: &mut Vec, block_to_spli new_original_block_id } +fn make_unreachable_block(header: &mut ModuleHeader, blocks: &mut Vec) -> Word { + let id = id(header); + let mut new_block = Block::new(); + new_block.label = Some(Instruction::new(Op::Label, None, Some(id), vec![])); + // new block is unreachable + new_block + .instructions + .push(Instruction::new(Op::Unreachable, None, None, vec![])); + + // insert new block at the end + blocks.push(new_block); + id +} + pub fn insert_selection_merge_on_conditional_branch( sess: Option<&Session>, header: &mut ModuleHeader, @@ -441,7 +516,7 @@ pub fn insert_selection_merge_on_conditional_branch( let mut branch_conditional_ops = Vec::new(); // Find conditional branches that are not loops - for block in &blocks.clone() { + for block in blocks.iter() { if ends_in_branch_conditional(block) && !cf_info.id_is_loops_header(block.label_id().unwrap()) { @@ -450,7 +525,7 @@ pub fn insert_selection_merge_on_conditional_branch( } // Find convergence point. - for id in branch_conditional_ops.clone() { + for id in branch_conditional_ops { let bi = find_block_index_from_id(blocks, &id); let out = outgoing_edges(&blocks[bi]); let id = &blocks[bi].label_id().unwrap(); @@ -487,41 +562,37 @@ pub fn insert_selection_merge_on_conditional_branch( let branch_a_breaks = block_leads_into_break(blocks, cf_info, a_first_id); let branch_b_breaks = block_leads_into_break(blocks, cf_info, b_first_id); - let branch_a_continues = false; - let branch_b_continues = false; + let branch_a_continues = block_leads_into_continue(blocks, cf_info, a_first_id); + let branch_b_continues = block_leads_into_continue(blocks, cf_info, b_first_id); let branch_a_returns = ends_in_return(&blocks[a_last_idx]); let branch_b_returns = ends_in_return(&blocks[b_last_idx]); - if (branch_a_breaks || branch_a_continues) && (branch_b_breaks || branch_b_continues) { + if ((branch_a_breaks || branch_a_continues) && (branch_b_breaks || branch_b_continues)) + || branch_a_returns && branch_b_returns + { // (fully unreachable) insert a rando block and mark as merge. - if let Some(sess) = sess { - sess.err("UNIMPLEMENTED, A fully unreachable case was detected."); - } - return; - } else if branch_a_breaks || branch_a_continues { + make_unreachable_block(header, blocks) + } else if branch_a_breaks || branch_a_continues || branch_a_returns { // (partially unreachable) merge block becomes branch b immediatly blocks[b_first_idx].label_id().unwrap() - } else if branch_b_breaks || branch_a_continues { + } else if branch_b_breaks || branch_b_continues || branch_b_returns { // (partially unreachable) merge block becomes branch a immediatly blocks[a_first_idx].label_id().unwrap() } else if branch_a_returns { // (partially unreachable) merge block becomes end/start of b. - if let Some(sess) = sess { - sess.err("UNIMPLEMENTED, A partially unreachable case was detected on a."); - } - return; + emit_compiler_error( + sess, + "UNIMPLEMENTED, A partially unreachable case was detected on a.", + ); } else if branch_b_returns { // (partially unreachable) merge block becomes end/start of a. - if let Some(sess) = sess { - sess.err("UNIMPLEMENTED, A partially unreachable case was detected on b."); - } - return; + emit_compiler_error( + sess, + "UNIMPLEMENTED, A partially unreachable case was detected on b.", + ); } else { - // (fully unreachable) insert a rando block and mark as merge. - if let Some(sess) = sess { - sess.err("UNIMPLEMENTED, A fully unreachable case was detected."); - } - return; + // In theory this should never happen. + emit_compiler_error(sess, "UNEXPECTED, Unknown exit detected."); } }; @@ -572,12 +643,12 @@ pub fn insert_loop_merge_on_conditional_branch( } // Figure out which branch loops and which branch should merge, also find any potential break ops. - for (bi, (looping_branch_idx, continue_block_idx)) in branch_conditional_ops { + for (bi, looping_branch_idx) in branch_conditional_ops { let merge_branch_idx = (looping_branch_idx + 1) % 2; let id = &blocks[bi].label_id().unwrap(); let out = outgoing_edges(&blocks[bi]); - let continue_block_id = blocks[continue_block_idx].label_id().unwrap(); + let continue_block_id = eliminate_multiple_continue_blocks(blocks, *id); let merge_block_id = out[merge_branch_idx]; if cf_info.used(continue_block_id) { diff --git a/crates/spirv-builder/src/test/control_flow.rs b/crates/spirv-builder/src/test/control_flow.rs index 9286e60731..72e740d7a1 100644 --- a/crates/spirv-builder/src/test/control_flow.rs +++ b/crates/spirv-builder/src/test/control_flow.rs @@ -103,6 +103,143 @@ pub fn main(i: Input) { "#); } +#[test] +fn cf_while_if_break_if_break() { + val(r#" +#[allow(unused_attributes)] +#[spirv(fragment)] +pub fn main(i: Input) { + while i.load() < 10 { + if i.load() == 0 { + break; + } + if i.load() == 1 { + break; + } + } +} +"#); +} + +#[test] +fn cf_while_while_continue() { + val(r#" +#[allow(unused_attributes)] +#[spirv(fragment)] +pub fn main(i: Input) { + while i.load() < 20 { + while i.load() < 10 { + continue; + } + } +} +"#); +} + +#[test] +fn cf_while_while_if_continue() { + val(r#" +#[allow(unused_attributes)] +#[spirv(fragment)] +pub fn main(i: Input) { + while i.load() < 20 { + while i.load() < 10 { + if i.load() > 5 { + continue; + } + } + } +} +"#); +} + +#[test] +fn cf_while_continue() { + val(r#" +#[allow(unused_attributes)] +#[spirv(fragment)] +pub fn main(i: Input) { + while i.load() < 10 { + continue; + } +} +"#); +} + +#[test] +fn cf_while_if_continue() { + val(r#" +#[allow(unused_attributes)] +#[spirv(fragment)] +pub fn main(i: Input) { + while i.load() < 10 { + if i.load() == 0 { + continue; + } + } +} +"#); +} + +#[test] +fn cf_while_if_continue_else_continue() { + val(r#" +#[allow(unused_attributes)] +#[spirv(fragment)] +pub fn main(i: Input) { + while i.load() < 10 { + if i.load() == 0 { + continue; + } else { + continue; + } + } +} +"#); +} + +#[test] +fn cf_while_return() { + val(r#" +#[allow(unused_attributes)] +#[spirv(fragment)] +pub fn main(i: Input) { + while i.load() < 10 { + return; + } +} +"#); +} + +#[test] +fn cf_if_return_else() { + val(r#" +#[allow(unused_attributes)] +#[spirv(fragment)] +pub fn main(i: Input) { + if i.load() < 10 { + return; + } else { + } +} +"#); +} + +#[test] +fn cf_if_return_else_return() { + val(r#" +#[allow(unused_attributes)] +#[spirv(fragment)] +pub fn main(i: Input) { + if i.load() < 10 { + return; + } else { + return; + } +} +"#); +} + #[test] fn cf_if_while() { val(r#"