Continue and minor break fixes (#202)

* Continue and break fixes

* Ashley Clippy

* Incorporated Feedback

* Clippy lint

* clippy

* minor fix
This commit is contained in:
Viktor Zoutman 2020-11-02 16:14:06 +01:00 committed by GitHub
parent 72e1373e1c
commit 8d2b8ce5f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 248 additions and 40 deletions

View File

@ -132,6 +132,14 @@ impl ControlFlowInfo {
}
}
fn emit_compiler_error(sess: Option<&Session>, msg: &'static str) -> ! {
if let Some(sess) = sess {
sess.fatal(msg);
} else {
panic!(msg);
}
}
pub fn structurize(sess: Option<&Session>, module: &mut Module) {
let mut debug_names = Vec::new();
@ -228,7 +236,39 @@ fn retarget_loop_children_if_needed(blocks: &mut [Block], cf_info: &ControlFlowI
}
}
fn block_leads_into_break(blocks: &mut [Block], cf_info: &ControlFlowInfo, start: Word) -> bool {
// "Combines" all continue blocks into 1 and returns the ID of the continue block.
fn eliminate_multiple_continue_blocks(blocks: &mut Vec<Block>, header: Word) -> Word {
// Find all possible continue blocks.
let mut continue_blocks = Vec::new();
for block in blocks.iter() {
let block_id = block.label_id().unwrap();
if ends_in_branch(block) {
let edge = outgoing_edges(block)[0];
if edge == header && block_is_parent_of(header, block_id, blocks) {
continue_blocks.push(block_id);
}
}
}
// if there are multiple continue blocks we need to retarget towards a single continue.
if continue_blocks.len() > 1 {
let continue_block_id = continue_blocks.last().unwrap();
for block_id in continue_blocks.iter().take(continue_blocks.len() - 1) {
let idx = find_block_index_from_id(blocks, block_id);
let block = &mut blocks[idx];
for op in &mut block.instructions.last_mut().unwrap().operands {
if *op == Operand::IdRef(header) {
*op = Operand::IdRef(*continue_block_id);
}
}
}
*continue_block_id
} else {
*continue_blocks.last().unwrap()
}
}
fn block_leads_into_break(blocks: &[Block], cf_info: &ControlFlowInfo, start: Word) -> bool {
let mut next: VecDeque<Word> = VecDeque::new();
next.push_back(start);
@ -266,6 +306,18 @@ fn block_leads_into_break(blocks: &mut [Block], cf_info: &ControlFlowInfo, start
false
}
fn block_leads_into_continue(blocks: &[Block], cf_info: &ControlFlowInfo, start: Word) -> bool {
let start_idx = find_block_index_from_id(blocks, &start);
let new_edges = outgoing_edges(&blocks[start_idx]);
for loop_info in &cf_info.loops {
if new_edges.len() == 1 && loop_info.continue_id == new_edges[0] {
return true;
}
}
false
}
fn get_possible_merge_positions(
blocks: &[Block],
cf_info: &ControlFlowInfo,
@ -295,8 +347,9 @@ fn get_possible_merge_positions(
.unwrap();
new_edges.remove(index);
}
// Make sure we are not continuing after a merge.
if block_is_parent_of(loop_info.merge_id, start, blocks) && front == loop_info.merge_id
if block_is_parent_of(loop_info.header_id, start, blocks) && front == loop_info.merge_id
{
new_edges.clear();
}
@ -318,21 +371,22 @@ fn block_is_parent_of(parent: Word, child: Word, blocks: &[Block]) -> bool {
next.push_back(parent);
let mut processed = Vec::new();
processed.push(parent); // ensures we are not looping.
while let Some(front) = next.pop_front() {
let block_idx = find_block_index_from_id(blocks, &front);
let mut new_edges = outgoing_edges(&blocks[block_idx]);
if new_edges.contains(&child) {
return true;
}
for id in &processed {
if let Some(i) = new_edges.iter().position(|x| x == id) {
new_edges.remove(i);
}
}
if new_edges.contains(&child) {
return true;
}
processed.push(front);
next.extend(new_edges);
}
@ -340,12 +394,12 @@ fn block_is_parent_of(parent: Word, child: Word, blocks: &[Block]) -> bool {
false
}
// Returns the idx of the branch that loops and the idx to the block that branches to the original block.
// Returns the idx of the branch that loops.
fn get_looping_branch_from_block(
blocks: &[Block],
cf_info: &ControlFlowInfo,
start: Word,
) -> Option<(usize, usize)> {
) -> Option<usize> {
let mut next: VecDeque<Word> = VecDeque::new();
next.push_back(start);
@ -361,13 +415,15 @@ fn get_looping_branch_from_block(
let mut new_edges = outgoing_edges(&blocks[block_idx]);
let edge_it = new_edges.iter().find(|&x| x == &start); // Check if the new_edges contain the start
if let Some(edge_it) = edge_it {
// loop over the orginal edges to find which branch is looping
let start_edges = outgoing_edges(&blocks[find_block_index_from_id(blocks, &start)]);
if new_edges.len() == 1 {
if let Some(edge_it) = edge_it {
// loop over the orginal edges to find which branch is looping
let start_edges = outgoing_edges(&blocks[find_block_index_from_id(blocks, &start)]);
for (i, start_edge) in start_edges.iter().enumerate() {
if start_edge == edge_it || block_is_parent_of(*start_edge, *edge_it, blocks) {
return Some((i, block_idx));
for (i, start_edge) in start_edges.iter().enumerate() {
if start_edge == edge_it || block_is_parent_of(*start_edge, *edge_it, blocks) {
return Some(i);
}
}
}
}
@ -390,6 +446,11 @@ fn ends_in_branch_conditional(block: &Block) -> bool {
last_inst.class.opcode == Op::BranchConditional
}
fn ends_in_branch(block: &Block) -> bool {
let last_inst = block.instructions.last().unwrap();
last_inst.class.opcode == Op::Branch
}
fn ends_in_return(block: &Block) -> bool {
let last_inst = block.instructions.last().unwrap();
last_inst.class.opcode == Op::Return || last_inst.class.opcode == Op::ReturnValue
@ -432,6 +493,20 @@ fn split_block(header: &mut ModuleHeader, blocks: &mut Vec<Block>, block_to_spli
new_original_block_id
}
fn make_unreachable_block(header: &mut ModuleHeader, blocks: &mut Vec<Block>) -> Word {
let id = id(header);
let mut new_block = Block::new();
new_block.label = Some(Instruction::new(Op::Label, None, Some(id), vec![]));
// new block is unreachable
new_block
.instructions
.push(Instruction::new(Op::Unreachable, None, None, vec![]));
// insert new block at the end
blocks.push(new_block);
id
}
pub fn insert_selection_merge_on_conditional_branch(
sess: Option<&Session>,
header: &mut ModuleHeader,
@ -441,7 +516,7 @@ pub fn insert_selection_merge_on_conditional_branch(
let mut branch_conditional_ops = Vec::new();
// Find conditional branches that are not loops
for block in &blocks.clone() {
for block in blocks.iter() {
if ends_in_branch_conditional(block)
&& !cf_info.id_is_loops_header(block.label_id().unwrap())
{
@ -450,7 +525,7 @@ pub fn insert_selection_merge_on_conditional_branch(
}
// Find convergence point.
for id in branch_conditional_ops.clone() {
for id in branch_conditional_ops {
let bi = find_block_index_from_id(blocks, &id);
let out = outgoing_edges(&blocks[bi]);
let id = &blocks[bi].label_id().unwrap();
@ -487,41 +562,37 @@ pub fn insert_selection_merge_on_conditional_branch(
let branch_a_breaks = block_leads_into_break(blocks, cf_info, a_first_id);
let branch_b_breaks = block_leads_into_break(blocks, cf_info, b_first_id);
let branch_a_continues = false;
let branch_b_continues = false;
let branch_a_continues = block_leads_into_continue(blocks, cf_info, a_first_id);
let branch_b_continues = block_leads_into_continue(blocks, cf_info, b_first_id);
let branch_a_returns = ends_in_return(&blocks[a_last_idx]);
let branch_b_returns = ends_in_return(&blocks[b_last_idx]);
if (branch_a_breaks || branch_a_continues) && (branch_b_breaks || branch_b_continues) {
if ((branch_a_breaks || branch_a_continues) && (branch_b_breaks || branch_b_continues))
|| branch_a_returns && branch_b_returns
{
// (fully unreachable) insert a rando block and mark as merge.
if let Some(sess) = sess {
sess.err("UNIMPLEMENTED, A fully unreachable case was detected.");
}
return;
} else if branch_a_breaks || branch_a_continues {
make_unreachable_block(header, blocks)
} else if branch_a_breaks || branch_a_continues || branch_a_returns {
// (partially unreachable) merge block becomes branch b immediatly
blocks[b_first_idx].label_id().unwrap()
} else if branch_b_breaks || branch_a_continues {
} else if branch_b_breaks || branch_b_continues || branch_b_returns {
// (partially unreachable) merge block becomes branch a immediatly
blocks[a_first_idx].label_id().unwrap()
} else if branch_a_returns {
// (partially unreachable) merge block becomes end/start of b.
if let Some(sess) = sess {
sess.err("UNIMPLEMENTED, A partially unreachable case was detected on a.");
}
return;
emit_compiler_error(
sess,
"UNIMPLEMENTED, A partially unreachable case was detected on a.",
);
} else if branch_b_returns {
// (partially unreachable) merge block becomes end/start of a.
if let Some(sess) = sess {
sess.err("UNIMPLEMENTED, A partially unreachable case was detected on b.");
}
return;
emit_compiler_error(
sess,
"UNIMPLEMENTED, A partially unreachable case was detected on b.",
);
} else {
// (fully unreachable) insert a rando block and mark as merge.
if let Some(sess) = sess {
sess.err("UNIMPLEMENTED, A fully unreachable case was detected.");
}
return;
// In theory this should never happen.
emit_compiler_error(sess, "UNEXPECTED, Unknown exit detected.");
}
};
@ -572,12 +643,12 @@ pub fn insert_loop_merge_on_conditional_branch(
}
// Figure out which branch loops and which branch should merge, also find any potential break ops.
for (bi, (looping_branch_idx, continue_block_idx)) in branch_conditional_ops {
for (bi, looping_branch_idx) in branch_conditional_ops {
let merge_branch_idx = (looping_branch_idx + 1) % 2;
let id = &blocks[bi].label_id().unwrap();
let out = outgoing_edges(&blocks[bi]);
let continue_block_id = blocks[continue_block_idx].label_id().unwrap();
let continue_block_id = eliminate_multiple_continue_blocks(blocks, *id);
let merge_block_id = out[merge_branch_idx];
if cf_info.used(continue_block_id) {

View File

@ -103,6 +103,143 @@ pub fn main(i: Input<i32>) {
"#);
}
#[test]
fn cf_while_if_break_if_break() {
val(r#"
#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(i: Input<i32>) {
while i.load() < 10 {
if i.load() == 0 {
break;
}
if i.load() == 1 {
break;
}
}
}
"#);
}
#[test]
fn cf_while_while_continue() {
val(r#"
#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(i: Input<i32>) {
while i.load() < 20 {
while i.load() < 10 {
continue;
}
}
}
"#);
}
#[test]
fn cf_while_while_if_continue() {
val(r#"
#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(i: Input<i32>) {
while i.load() < 20 {
while i.load() < 10 {
if i.load() > 5 {
continue;
}
}
}
}
"#);
}
#[test]
fn cf_while_continue() {
val(r#"
#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(i: Input<i32>) {
while i.load() < 10 {
continue;
}
}
"#);
}
#[test]
fn cf_while_if_continue() {
val(r#"
#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(i: Input<i32>) {
while i.load() < 10 {
if i.load() == 0 {
continue;
}
}
}
"#);
}
#[test]
fn cf_while_if_continue_else_continue() {
val(r#"
#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(i: Input<i32>) {
while i.load() < 10 {
if i.load() == 0 {
continue;
} else {
continue;
}
}
}
"#);
}
#[test]
fn cf_while_return() {
val(r#"
#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(i: Input<i32>) {
while i.load() < 10 {
return;
}
}
"#);
}
#[test]
fn cf_if_return_else() {
val(r#"
#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(i: Input<i32>) {
if i.load() < 10 {
return;
} else {
}
}
"#);
}
#[test]
fn cf_if_return_else_return() {
val(r#"
#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(i: Input<i32>) {
if i.load() < 10 {
return;
} else {
return;
}
}
"#);
}
#[test]
fn cf_if_while() {
val(r#"