mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-25 08:14:12 +00:00
Structurizer fixes (#244)
* Structurizer fixes * reverted some unnessessary changes
This commit is contained in:
parent
2d75e0473f
commit
daa382368c
@ -1,18 +1,11 @@
|
||||
// This pass inserts merge instructions for structured control flow with the assumption the spir-v is reducible.
|
||||
|
||||
// TODO: Could i simplify break detection by just checking, hey does the start branch branch to a merge block?
|
||||
// TODO: Verify we are never splitting a block that is queued for structurization.
|
||||
// TODO: are there any cases where I need to retarget branches or conditional branches when splitting a block?
|
||||
|
||||
use super::id;
|
||||
use super::simple_passes::outgoing_edges;
|
||||
use rspirv::spirv::{Op, Word};
|
||||
use rspirv::{
|
||||
dr::{Block, Instruction, Module, ModuleHeader, Operand},
|
||||
spirv::SelectionControl,
|
||||
};
|
||||
use rspirv::dr::{Block, Instruction, Module, ModuleHeader, Operand};
|
||||
use rspirv::spirv::{Op, SelectionControl, Word};
|
||||
use rustc_session::Session;
|
||||
use std::collections::VecDeque;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
|
||||
pub struct LoopInfo {
|
||||
merge_id: Word,
|
||||
@ -153,6 +146,14 @@ pub fn structurize(sess: &Session, module: &mut Module) {
|
||||
&mut cf_info,
|
||||
);
|
||||
|
||||
defer_loop_internals(
|
||||
&mut module.header.as_mut().unwrap(),
|
||||
&mut func.blocks,
|
||||
&cf_info,
|
||||
&mut debug_names,
|
||||
&mut module.types_global_values,
|
||||
);
|
||||
|
||||
debug_names.extend(cf_info.get_debug_names());
|
||||
}
|
||||
|
||||
@ -228,6 +229,217 @@ fn retarget_loop_children_if_needed(blocks: &mut [Block], cf_info: &ControlFlowI
|
||||
}
|
||||
}
|
||||
|
||||
fn incoming_edges(id: Word, blocks: &[Block]) -> Vec<Word> {
|
||||
let mut incoming_edges = Vec::new();
|
||||
for block in blocks {
|
||||
let out = outgoing_edges(block);
|
||||
if out.contains(&id) {
|
||||
incoming_edges.push(block.label_id().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
incoming_edges
|
||||
}
|
||||
|
||||
fn num_incoming_edges(id: Word, blocks: &[Block]) -> usize {
|
||||
incoming_edges(id, blocks).len()
|
||||
}
|
||||
|
||||
// Turn a block into a conditional branch that either goes to yes or goes to merge.
|
||||
fn change_block_to_switch(
|
||||
blocks: &mut [Block],
|
||||
block_id: Word,
|
||||
cases: &[Word],
|
||||
merge: Word,
|
||||
condition: Word,
|
||||
) {
|
||||
let cb_idx = find_block_index_from_id(blocks, &block_id);
|
||||
let cb_block = &mut blocks[cb_idx];
|
||||
|
||||
// selection merge
|
||||
let merge_operands = vec![
|
||||
Operand::IdRef(merge),
|
||||
Operand::SelectionControl(SelectionControl::NONE),
|
||||
];
|
||||
*cb_block.instructions.last_mut().unwrap() =
|
||||
Instruction::new(Op::SelectionMerge, None, None, merge_operands);
|
||||
|
||||
// conditional branch
|
||||
let mut switch_operands = vec![Operand::IdRef(condition), Operand::IdRef(merge)];
|
||||
|
||||
for (i, label_id) in cases.iter().enumerate() {
|
||||
let literal = i as u32 + 1;
|
||||
switch_operands.push(Operand::LiteralInt32(literal));
|
||||
switch_operands.push(Operand::IdRef(*label_id));
|
||||
}
|
||||
|
||||
cb_block
|
||||
.instructions
|
||||
.push(Instruction::new(Op::Switch, None, None, switch_operands));
|
||||
}
|
||||
|
||||
fn find_or_create_int_constant(
|
||||
opcode: Op,
|
||||
constants: &mut Vec<Instruction>,
|
||||
header: &mut ModuleHeader,
|
||||
type_result_id: Word,
|
||||
value: u32,
|
||||
) -> Word {
|
||||
// create
|
||||
let result_id = id(header);
|
||||
let new_constant = Instruction::new(
|
||||
opcode,
|
||||
Some(type_result_id),
|
||||
Some(result_id),
|
||||
vec![Operand::LiteralInt32(value)],
|
||||
);
|
||||
constants.push(new_constant);
|
||||
|
||||
result_id
|
||||
}
|
||||
|
||||
fn find_or_create_type_constant(
|
||||
opcode: Op,
|
||||
constants: &mut Vec<Instruction>,
|
||||
header: &mut ModuleHeader,
|
||||
) -> Word {
|
||||
// find
|
||||
for constant in constants.iter() {
|
||||
if constant.class.opcode == opcode {
|
||||
return constant.result_id.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// create
|
||||
let result_id = id(header);
|
||||
let new_constant = Instruction::new(opcode, None, Some(result_id), vec![]);
|
||||
constants.push(new_constant);
|
||||
|
||||
result_id
|
||||
}
|
||||
|
||||
// detect the intermediate break block by checking whether a block that branches to a merge block has 2 parents.
|
||||
fn defer_loop_internals(
|
||||
header: &mut ModuleHeader,
|
||||
blocks: &mut Vec<Block>,
|
||||
cf_info: &ControlFlowInfo,
|
||||
debug_names: &mut Vec<(Word, String)>,
|
||||
constants: &mut Vec<Instruction>,
|
||||
) {
|
||||
for loop_info in &cf_info.loops {
|
||||
// find all blocks that branch to a merge block.
|
||||
let mut possible_intermediate_block_idexes = Vec::new();
|
||||
for (i, block) in blocks.iter().enumerate() {
|
||||
let out = outgoing_edges(block);
|
||||
if out.len() == 1 && out[0] == loop_info.merge_id {
|
||||
possible_intermediate_block_idexes.push(i)
|
||||
}
|
||||
}
|
||||
// check how many incoming edges the branch has and use that to collect a list of intermediate blocks.
|
||||
let mut intermediate_blocks = Vec::new();
|
||||
for i in possible_intermediate_block_idexes {
|
||||
let intermediate_block = &blocks[i];
|
||||
let num_incoming_edges =
|
||||
num_incoming_edges(intermediate_block.label_id().unwrap(), blocks);
|
||||
if num_incoming_edges > 1 {
|
||||
intermediate_blocks.push(i);
|
||||
}
|
||||
}
|
||||
|
||||
if !intermediate_blocks.is_empty() {
|
||||
let intermediate_block_ids: Vec<Word> = intermediate_blocks
|
||||
.iter()
|
||||
.map(|idx| blocks[*idx].label_id().unwrap())
|
||||
.collect();
|
||||
|
||||
// Create a new empty block.
|
||||
let old_merge_block_id = split_block(header, blocks, loop_info.merge_id, false);
|
||||
|
||||
// Create Phi
|
||||
let phi_result_id = id(header);
|
||||
let int_type_id = find_or_create_type_constant(Op::TypeInt, constants, header);
|
||||
let const_0_id =
|
||||
find_or_create_int_constant(Op::Constant, constants, header, int_type_id, 0);
|
||||
|
||||
let mut phi_operands = vec![];
|
||||
for (intermediate_i, intermediate_block_id) in intermediate_block_ids.iter().enumerate()
|
||||
{
|
||||
let intermediate_i = intermediate_i as u32 + 1;
|
||||
let intermediate_block_idx =
|
||||
find_block_index_from_id(blocks, intermediate_block_id);
|
||||
let intermediate_block_id = blocks[intermediate_block_idx].label_id().unwrap();
|
||||
let const_x_id = find_or_create_int_constant(
|
||||
Op::Constant,
|
||||
constants,
|
||||
header,
|
||||
int_type_id,
|
||||
intermediate_i,
|
||||
);
|
||||
let t = incoming_edges(intermediate_block_id, blocks);
|
||||
for blocks_that_go_to_intermediate in t {
|
||||
phi_operands.push(Operand::IdRef(const_x_id));
|
||||
phi_operands.push(Operand::IdRef(blocks_that_go_to_intermediate));
|
||||
}
|
||||
debug_names.push((intermediate_block_id, "deferred".to_string()));
|
||||
}
|
||||
phi_operands.push(Operand::IdRef(const_0_id));
|
||||
phi_operands.push(Operand::IdRef(loop_info.header_id));
|
||||
let phi_inst = Instruction::new(
|
||||
Op::Phi,
|
||||
Some(int_type_id),
|
||||
Some(phi_result_id),
|
||||
phi_operands,
|
||||
);
|
||||
|
||||
// add phi to the empty merge block.
|
||||
{
|
||||
let merge_block_idx = find_block_index_from_id(blocks, &loop_info.merge_id);
|
||||
let merge_block = &mut blocks[merge_block_idx];
|
||||
merge_block.instructions.insert(0, phi_inst);
|
||||
}
|
||||
|
||||
// point all intermediate blocks to the new empty merge block.
|
||||
for intermediate_block_id in intermediate_block_ids.iter() {
|
||||
for incoming_id in incoming_edges(*intermediate_block_id, blocks) {
|
||||
let incoming_idx = find_block_index_from_id(blocks, &incoming_id);
|
||||
let incoming_block = &mut blocks[incoming_idx];
|
||||
|
||||
for operand in &mut incoming_block.instructions.last_mut().unwrap().operands {
|
||||
if *operand == Operand::IdRef(*intermediate_block_id) {
|
||||
*operand = Operand::IdRef(loop_info.merge_id); // loop_info.merge_id is the same block as the new empty block from the last step.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a switch statement of all intermediate blocks.
|
||||
change_block_to_switch(
|
||||
blocks,
|
||||
loop_info.merge_id,
|
||||
&intermediate_block_ids,
|
||||
old_merge_block_id,
|
||||
phi_result_id,
|
||||
);
|
||||
|
||||
// point intermediate blocks to the old merge block.
|
||||
for intermediate_block_id in intermediate_block_ids.iter() {
|
||||
let intermediate_block_idx =
|
||||
find_block_index_from_id(blocks, intermediate_block_id);
|
||||
for operand in &mut blocks[intermediate_block_idx]
|
||||
.instructions
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.operands
|
||||
{
|
||||
if *operand == Operand::IdRef(loop_info.merge_id) {
|
||||
*operand = Operand::IdRef(old_merge_block_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// "Combines" all continue blocks into 1 and returns the ID of the continue block.
|
||||
fn eliminate_multiple_continue_blocks(blocks: &mut Vec<Block>, header: Word) -> Word {
|
||||
// Find all possible continue blocks.
|
||||
@ -310,27 +522,34 @@ fn block_leads_into_continue(blocks: &[Block], cf_info: &ControlFlowInfo, start:
|
||||
false
|
||||
}
|
||||
|
||||
fn get_possible_merge_positions(
|
||||
blocks: &[Block],
|
||||
cf_info: &ControlFlowInfo,
|
||||
start: Word,
|
||||
) -> Vec<usize> {
|
||||
let mut retval = Vec::new();
|
||||
// every branch from a reaches b.
|
||||
fn block_is_reverse_idom_of(blocks: &[Block], cf_info: &ControlFlowInfo, a: Word, b: Word) -> bool {
|
||||
let mut next: VecDeque<Word> = VecDeque::new();
|
||||
next.push_back(start);
|
||||
next.push_back(a);
|
||||
|
||||
let mut processed = Vec::new();
|
||||
processed.push(a); // ensures we are not looping.
|
||||
|
||||
while let Some(front) = next.pop_front() {
|
||||
let block_idx = find_block_index_from_id(blocks, &front);
|
||||
|
||||
if front == b {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut new_edges = outgoing_edges(&blocks[block_idx]);
|
||||
|
||||
// Don't queue the start block if its a edge
|
||||
if let Some(i) = new_edges.iter().position(|x| *x == start) {
|
||||
new_edges.remove(i);
|
||||
// Skip loop bodies by jumping to the merge block is we hit a header block.
|
||||
for loop_info in &cf_info.loops {
|
||||
if front == loop_info.header_id {
|
||||
// TODO: should only do this for children i guess.
|
||||
new_edges = vec![loop_info.merge_id];
|
||||
}
|
||||
}
|
||||
|
||||
for loop_info in &cf_info.loops {
|
||||
// Make sure we are not looping.
|
||||
if block_is_parent_of(loop_info.header_id, start, blocks)
|
||||
if block_is_parent_of(loop_info.header_id, a, blocks)
|
||||
&& new_edges.contains(&loop_info.header_id)
|
||||
{
|
||||
let index = new_edges
|
||||
@ -341,20 +560,39 @@ fn get_possible_merge_positions(
|
||||
}
|
||||
|
||||
// Make sure we are not continuing after a merge.
|
||||
if block_is_parent_of(loop_info.header_id, start, blocks) && front == loop_info.merge_id
|
||||
{
|
||||
if block_is_parent_of(loop_info.header_id, a, blocks) && front == loop_info.merge_id {
|
||||
new_edges.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// We found a possible merge position, make sure it isn't a merge of a loop because in that case we want to use break logic.
|
||||
if new_edges.len() == 1 && !cf_info.id_is_loops_merge(new_edges[0]) {
|
||||
retval.push(find_block_index_from_id(blocks, &new_edges[0]));
|
||||
if new_edges.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
for id in &processed {
|
||||
if let Some(i) = new_edges.iter().position(|x| x == id) {
|
||||
new_edges.remove(i);
|
||||
}
|
||||
}
|
||||
|
||||
processed.push(front);
|
||||
next.extend(new_edges);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
fn get_possible_merge_positions(
|
||||
blocks: &[Block],
|
||||
cf_info: &ControlFlowInfo,
|
||||
start: Word,
|
||||
) -> Vec<usize> {
|
||||
let mut retval = Vec::new();
|
||||
for (idx, block) in blocks.iter().enumerate() {
|
||||
if block_is_reverse_idom_of(blocks, cf_info, start, block.label_id().unwrap()) {
|
||||
retval.push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
retval
|
||||
}
|
||||
|
||||
@ -449,7 +687,12 @@ fn ends_in_return(block: &Block) -> bool {
|
||||
}
|
||||
|
||||
// Returns the new id assigned to the original block.
|
||||
fn split_block(header: &mut ModuleHeader, blocks: &mut Vec<Block>, block_to_split: Word) -> Word {
|
||||
fn split_block(
|
||||
header: &mut ModuleHeader,
|
||||
blocks: &mut Vec<Block>,
|
||||
block_to_split: Word,
|
||||
retarget: bool,
|
||||
) -> Word {
|
||||
// create new block with old id.
|
||||
let block_to_split_index = find_block_index_from_id(blocks, &block_to_split);
|
||||
let orignial_block = &mut blocks[block_to_split_index];
|
||||
@ -467,13 +710,15 @@ fn split_block(header: &mut ModuleHeader, blocks: &mut Vec<Block>, block_to_spli
|
||||
vec![Operand::IdRef(new_original_block_id)],
|
||||
);
|
||||
new_block.instructions.push(branch_inst);
|
||||
// update all merge ops to point the the old block with its new id.
|
||||
for block in blocks.iter_mut() {
|
||||
for inst in &mut block.instructions {
|
||||
if inst.class.opcode == Op::LoopMerge || inst.class.opcode == Op::SelectionMerge {
|
||||
for operand in &mut inst.operands {
|
||||
if *operand == Operand::IdRef(original_id) {
|
||||
*operand = Operand::IdRef(new_original_block_id);
|
||||
if retarget {
|
||||
// update all merge ops to point the the old block with its new id.
|
||||
for block in blocks.iter_mut() {
|
||||
for inst in &mut block.instructions {
|
||||
if inst.class.opcode == Op::LoopMerge || inst.class.opcode == Op::SelectionMerge {
|
||||
for operand in &mut inst.operands {
|
||||
if *operand == Operand::IdRef(original_id) {
|
||||
*operand = Operand::IdRef(new_original_block_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -516,9 +761,16 @@ pub fn insert_selection_merge_on_conditional_branch(
|
||||
}
|
||||
}
|
||||
|
||||
let mut modified_ids = HashMap::new();
|
||||
|
||||
// Find convergence point.
|
||||
for id in branch_conditional_ops {
|
||||
let bi = find_block_index_from_id(blocks, &id);
|
||||
for id in branch_conditional_ops.iter() {
|
||||
let id = match modified_ids.get_key_value(id) {
|
||||
Some((_, value)) => value,
|
||||
None => id,
|
||||
};
|
||||
|
||||
let bi = find_block_index_from_id(blocks, id);
|
||||
let out = outgoing_edges(&blocks[bi]);
|
||||
let id = &blocks[bi].label_id().unwrap();
|
||||
let a_nexts = get_possible_merge_positions(blocks, cf_info, out[0]);
|
||||
@ -583,8 +835,12 @@ pub fn insert_selection_merge_on_conditional_branch(
|
||||
};
|
||||
|
||||
if cf_info.used(merge_block_id) {
|
||||
let new_id = split_block(header, blocks, merge_block_id);
|
||||
let new_id = split_block(header, blocks, merge_block_id, true);
|
||||
cf_info.retarget(merge_block_id, new_id);
|
||||
|
||||
if branch_conditional_ops.contains(&merge_block_id) {
|
||||
modified_ids.insert(merge_block_id, new_id);
|
||||
}
|
||||
}
|
||||
|
||||
let merge_operands = vec![
|
||||
@ -612,13 +868,13 @@ pub fn insert_loop_merge_on_conditional_branch(
|
||||
let mut branch_conditional_ops = Vec::new();
|
||||
|
||||
// Find conditional branches that are loops, and find which branch is the one that loops.
|
||||
for (bi, block) in blocks.iter().enumerate() {
|
||||
for block in blocks.iter() {
|
||||
if ends_in_branch_conditional(block) {
|
||||
let block_id = block.label_id().unwrap();
|
||||
if let Some(looping_branch_idx_and_block_idx) =
|
||||
if let Some(looping_branch_idx) =
|
||||
get_looping_branch_from_block(blocks, cf_info, block_id)
|
||||
{
|
||||
branch_conditional_ops.push((bi, looping_branch_idx_and_block_idx));
|
||||
branch_conditional_ops.push((block_id, looping_branch_idx));
|
||||
cf_info.loops.push(LoopInfo {
|
||||
header_id: block_id,
|
||||
merge_id: 0,
|
||||
@ -628,25 +884,40 @@ pub fn insert_loop_merge_on_conditional_branch(
|
||||
}
|
||||
}
|
||||
|
||||
let mut modified_ids = HashMap::new();
|
||||
|
||||
// Figure out which branch loops and which branch should merge, also find any potential break ops.
|
||||
for (bi, looping_branch_idx) in branch_conditional_ops {
|
||||
for (id, looping_branch_idx) in branch_conditional_ops.iter() {
|
||||
let id = match modified_ids.get_key_value(id) {
|
||||
Some((_, value)) => *value,
|
||||
None => *id,
|
||||
};
|
||||
|
||||
let merge_branch_idx = (looping_branch_idx + 1) % 2;
|
||||
let id = &blocks[bi].label_id().unwrap();
|
||||
let bi = find_block_index_from_id(blocks, &id);
|
||||
let out = outgoing_edges(&blocks[bi]);
|
||||
|
||||
let continue_block_id = eliminate_multiple_continue_blocks(blocks, *id);
|
||||
let continue_block_id = eliminate_multiple_continue_blocks(blocks, id);
|
||||
let merge_block_id = out[merge_branch_idx];
|
||||
|
||||
if cf_info.used(continue_block_id) {
|
||||
let new_id = split_block(header, blocks, continue_block_id);
|
||||
let new_id = split_block(header, blocks, continue_block_id, true);
|
||||
cf_info.retarget(continue_block_id, new_id);
|
||||
|
||||
if branch_conditional_ops.contains(&(continue_block_id, *looping_branch_idx)) {
|
||||
modified_ids.insert(continue_block_id, new_id);
|
||||
}
|
||||
}
|
||||
if cf_info.used(merge_block_id) {
|
||||
let new_id = split_block(header, blocks, merge_block_id);
|
||||
let new_id = split_block(header, blocks, merge_block_id, true);
|
||||
cf_info.retarget(merge_block_id, new_id);
|
||||
|
||||
if branch_conditional_ops.contains(&(merge_block_id, *looping_branch_idx)) {
|
||||
modified_ids.insert(merge_block_id, new_id);
|
||||
}
|
||||
}
|
||||
|
||||
let bi = find_block_index_from_id(blocks, id); // after this we don't insert or remove blocks
|
||||
let bi = find_block_index_from_id(blocks, &id); // after this we don't insert or remove blocks
|
||||
let check_block = &mut blocks[bi];
|
||||
|
||||
let merge_operands = vec![
|
||||
@ -655,7 +926,7 @@ pub fn insert_loop_merge_on_conditional_branch(
|
||||
Operand::SelectionControl(SelectionControl::NONE),
|
||||
];
|
||||
|
||||
cf_info.set_loops_continue_and_merge(*id, merge_block_id, continue_block_id);
|
||||
cf_info.set_loops_continue_and_merge(id, merge_block_id, continue_block_id);
|
||||
|
||||
// Insert the merge instruction
|
||||
let merge_inst = Instruction::new(Op::LoopMerge, None, None, merge_operands);
|
||||
|
Loading…
Reference in New Issue
Block a user