mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 06:45:13 +00:00
Don't treat OpSelectionMerge as an edge in mem2reg (#571)
* Don't treat OpSelectionMerge as an edge in mem2reg * Make outgoing_edges not allocate * Update test
This commit is contained in:
parent
4fa73bddb4
commit
f0f0f318ec
@ -5,7 +5,7 @@
|
||||
//! run mem2reg (see mem2reg.rs) on the result to "unwrap" the Function pointer.
|
||||
|
||||
use super::apply_rewrite_rules;
|
||||
use super::mem2reg::compute_preds;
|
||||
use super::simple_passes::outgoing_edges;
|
||||
use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};
|
||||
use rspirv::spirv::{FunctionControl, Op, StorageClass, Word};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@ -413,3 +413,17 @@ fn fuse_trivial_branches(function: &mut Function) {
|
||||
}
|
||||
function.blocks.retain(|b| !b.instructions.is_empty());
|
||||
}
|
||||
|
||||
fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
|
||||
let mut result = vec![vec![]; blocks.len()];
|
||||
for (source_idx, source) in blocks.iter().enumerate() {
|
||||
for dest_id in outgoing_edges(source) {
|
||||
let dest_idx = blocks
|
||||
.iter()
|
||||
.position(|b| b.label_id().unwrap() == dest_id)
|
||||
.unwrap();
|
||||
result[dest_idx].push(source_idx);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
@ -22,8 +22,9 @@ pub fn mem2reg(
|
||||
constants: &HashMap<Word, u32>,
|
||||
func: &mut Function,
|
||||
) {
|
||||
let preds = compute_preds(&func.blocks);
|
||||
let idom = compute_idom(&preds);
|
||||
let reachable = compute_reachable(&func.blocks);
|
||||
let preds = compute_preds(&func.blocks, &reachable);
|
||||
let idom = compute_idom(&preds, &reachable);
|
||||
let dominance_frontier = compute_dominance_frontier(&preds, &idom);
|
||||
insert_phis_all(
|
||||
header,
|
||||
@ -35,24 +36,38 @@ pub fn mem2reg(
|
||||
);
|
||||
}
|
||||
|
||||
pub fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
|
||||
let mut result = vec![vec![]; blocks.len()];
|
||||
for (source_idx, source) in blocks.iter().enumerate() {
|
||||
let mut edges = outgoing_edges(source);
|
||||
// HACK(eddyb) treat `OpSelectionMerge` as an edge, in case it points
|
||||
// to an otherwise-unreachable block.
|
||||
if let Some(before_last_idx) = source.instructions.len().checked_sub(2) {
|
||||
if let Some(before_last) = source.instructions.get(before_last_idx) {
|
||||
if before_last.class.opcode == Op::SelectionMerge {
|
||||
edges.push(before_last.operands[0].unwrap_id_ref());
|
||||
}
|
||||
fn label_to_index(blocks: &[Block], id: Word) -> usize {
|
||||
blocks
|
||||
.iter()
|
||||
.position(|b| b.label_id().unwrap() == id)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
|
||||
fn recurse(blocks: &[Block], reachable: &mut [bool], block: usize) {
|
||||
if !reachable[block] {
|
||||
reachable[block] = true;
|
||||
for dest_id in outgoing_edges(&blocks[block]) {
|
||||
let dest_idx = label_to_index(blocks, dest_id);
|
||||
recurse(blocks, reachable, dest_idx);
|
||||
}
|
||||
}
|
||||
for dest_id in edges {
|
||||
let dest_idx = blocks
|
||||
.iter()
|
||||
.position(|b| b.label_id().unwrap() == dest_id)
|
||||
.unwrap();
|
||||
}
|
||||
let mut reachable = vec![false; blocks.len()];
|
||||
recurse(blocks, &mut reachable, 0);
|
||||
reachable
|
||||
}
|
||||
|
||||
fn compute_preds(blocks: &[Block], reachable_blocks: &[bool]) -> Vec<Vec<usize>> {
|
||||
let mut result = vec![vec![]; blocks.len()];
|
||||
// Do not count unreachable blocks as valid preds of blocks
|
||||
for (source_idx, source) in blocks
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(b, _)| reachable_blocks[b])
|
||||
{
|
||||
for dest_id in outgoing_edges(source) {
|
||||
let dest_idx = label_to_index(blocks, dest_id);
|
||||
result[dest_idx].push(source_idx);
|
||||
}
|
||||
}
|
||||
@ -62,7 +77,8 @@ pub fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
|
||||
// Paper: A Simple, Fast Dominance Algorithm
|
||||
// https://www.cs.rice.edu/~keith/EMBED/dom.pdf
|
||||
// Note: requires nodes in reverse postorder
|
||||
fn compute_idom(preds: &[Vec<usize>]) -> Vec<usize> {
|
||||
// If a result is None, that means the block is unreachable, and therefore has no idom.
|
||||
fn compute_idom(preds: &[Vec<usize>], reachable_blocks: &[bool]) -> Vec<Option<usize>> {
|
||||
fn intersect(doms: &[Option<usize>], mut finger1: usize, mut finger2: usize) -> usize {
|
||||
// TODO: This may return an optional result?
|
||||
while finger1 != finger2 {
|
||||
@ -83,7 +99,8 @@ fn compute_idom(preds: &[Vec<usize>]) -> Vec<usize> {
|
||||
let mut changed = true;
|
||||
while changed {
|
||||
changed = false;
|
||||
for node in 1..(preds.len()) {
|
||||
// Unreachable blocks have no preds, and therefore no idom
|
||||
for node in (1..(preds.len())).filter(|&i| reachable_blocks[i]) {
|
||||
let mut new_idom: Option<usize> = None;
|
||||
for &pred in &preds[node] {
|
||||
if idom[pred].is_some() {
|
||||
@ -99,20 +116,25 @@ fn compute_idom(preds: &[Vec<usize>]) -> Vec<usize> {
|
||||
}
|
||||
}
|
||||
}
|
||||
idom.iter().map(|x| x.unwrap()).collect()
|
||||
assert!(idom
|
||||
.iter()
|
||||
.enumerate()
|
||||
.all(|(i, x)| x.is_some() == reachable_blocks[i]));
|
||||
idom
|
||||
}
|
||||
|
||||
// Same paper as above
|
||||
fn compute_dominance_frontier(preds: &[Vec<usize>], idom: &[usize]) -> Vec<HashSet<usize>> {
|
||||
fn compute_dominance_frontier(preds: &[Vec<usize>], idom: &[Option<usize>]) -> Vec<HashSet<usize>> {
|
||||
assert_eq!(preds.len(), idom.len());
|
||||
let mut dominance_frontier = vec![HashSet::new(); preds.len()];
|
||||
for node in 0..preds.len() {
|
||||
if preds[node].len() >= 2 {
|
||||
let node_idom = idom[node].unwrap();
|
||||
for &pred in &preds[node] {
|
||||
let mut runner = pred;
|
||||
while runner != idom[node] {
|
||||
while runner != node_idom {
|
||||
dominance_frontier[runner].insert(node);
|
||||
runner = idom[runner];
|
||||
runner = idom[runner].unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -442,13 +464,9 @@ impl Renamer<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
for dest_id in outgoing_edges(&self.blocks[block]) {
|
||||
for dest_id in outgoing_edges(&self.blocks[block]).collect::<Vec<_>>() {
|
||||
// TODO: Don't do this find
|
||||
let dest_idx = self
|
||||
.blocks
|
||||
.iter()
|
||||
.position(|b| b.label_id().unwrap() == dest_id)
|
||||
.unwrap();
|
||||
let dest_idx = label_to_index(self.blocks, dest_id);
|
||||
self.rename(dest_idx, Some(block));
|
||||
}
|
||||
|
||||
|
@ -561,7 +561,9 @@ impl Structurizer<'_> {
|
||||
}
|
||||
visited[block] = true;
|
||||
|
||||
for target in super::simple_passes::outgoing_edges(&self.func.blocks()[block]) {
|
||||
for target in
|
||||
super::simple_passes::outgoing_edges(&self.func.blocks()[block]).collect::<Vec<_>>()
|
||||
{
|
||||
self.post_order_step(self.block_id_to_idx[&target], visited, post_order)
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
use rspirv::dr::{Block, Function, Module};
|
||||
use rspirv::spirv::{Op, Word};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::iter::once;
|
||||
use std::mem::replace;
|
||||
|
||||
pub fn shift_ids(module: &mut Module, add: u32) {
|
||||
@ -44,7 +43,7 @@ pub fn block_ordering_pass(func: &mut Function) {
|
||||
.iter()
|
||||
.find(|b| b.label_id().unwrap() == current)
|
||||
.unwrap();
|
||||
let mut edges = outgoing_edges(current_block);
|
||||
let mut edges = outgoing_edges(current_block).collect::<Vec<_>>();
|
||||
// HACK(eddyb) treat `OpSelectionMerge` as an edge, in case it points
|
||||
// to an otherwise-unreachable block.
|
||||
if let Some(before_last_idx) = current_block.instructions.len().checked_sub(2) {
|
||||
@ -80,27 +79,17 @@ pub fn block_ordering_pass(func: &mut Function) {
|
||||
assert_eq!(func.blocks[0].label_id().unwrap(), entry_label);
|
||||
}
|
||||
|
||||
// FIXME(eddyb) use `Either`, `Cow`, and/or `SmallVec`.
|
||||
pub fn outgoing_edges(block: &Block) -> Vec<Word> {
|
||||
pub fn outgoing_edges(block: &Block) -> impl Iterator<Item = Word> + '_ {
|
||||
let terminator = block.instructions.last().unwrap();
|
||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Termination
|
||||
match terminator.class.opcode {
|
||||
Op::Branch => vec![terminator.operands[0].unwrap_id_ref()],
|
||||
Op::BranchConditional => vec![
|
||||
terminator.operands[1].unwrap_id_ref(),
|
||||
terminator.operands[2].unwrap_id_ref(),
|
||||
],
|
||||
Op::Switch => once(terminator.operands[1].unwrap_id_ref())
|
||||
.chain(
|
||||
terminator.operands[3..]
|
||||
.iter()
|
||||
.step_by(2)
|
||||
.map(|op| op.unwrap_id_ref()),
|
||||
)
|
||||
.collect(),
|
||||
Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => Vec::new(),
|
||||
let operand_indices = match terminator.class.opcode {
|
||||
Op::Branch => (0..1).step_by(1),
|
||||
Op::BranchConditional => (1..3).step_by(1),
|
||||
Op::Switch => (1..terminator.operands.len()).step_by(2),
|
||||
Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => (0..0).step_by(1),
|
||||
_ => panic!("Invalid block terminator: {:?}", terminator),
|
||||
}
|
||||
};
|
||||
operand_indices.map(move |i| terminator.operands[i].unwrap_id_ref())
|
||||
}
|
||||
|
||||
pub fn compact_ids(module: &mut Module) -> u32 {
|
||||
|
@ -208,7 +208,8 @@ fn retarget_loop_children_if_needed(builder: &mut Builder, cf_info: &ControlFlow
|
||||
|
||||
while let Some(front) = next.pop_front() {
|
||||
let block_idx = find_block_index_from_id(builder, &front);
|
||||
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx));
|
||||
let mut new_edges =
|
||||
outgoing_edges(get_block_ref!(builder, block_idx)).collect::<Vec<_>>();
|
||||
|
||||
// Make sure we are not looping or going into child loops.
|
||||
for loop_info in &cf_info.loops {
|
||||
@ -249,8 +250,7 @@ fn retarget_loop_children_if_needed(builder: &mut Builder, cf_info: &ControlFlow
|
||||
fn incoming_edges(id: Word, builder: &mut Builder) -> Vec<Word> {
|
||||
let mut incoming_edges = Vec::new();
|
||||
for block in get_blocks_ref(builder) {
|
||||
let out = outgoing_edges(block);
|
||||
if out.contains(&id) {
|
||||
if outgoing_edges(block).any(|x| x == id) {
|
||||
incoming_edges.push(block.label_id().unwrap());
|
||||
}
|
||||
}
|
||||
@ -292,8 +292,8 @@ fn defer_loop_internals(builder: &mut Builder, cf_info: &ControlFlowInfo) {
|
||||
// find all blocks that branch to a merge block.
|
||||
let mut possible_intermediate_block_idexes = Vec::new();
|
||||
for (i, block) in get_blocks_ref(builder).iter().enumerate() {
|
||||
let out = outgoing_edges(block);
|
||||
if out.len() == 1 && out[0] == loop_info.merge_id {
|
||||
let mut out = outgoing_edges(block);
|
||||
if out.next() == Some(loop_info.merge_id) && out.next() == None {
|
||||
possible_intermediate_block_idexes.push(i)
|
||||
}
|
||||
}
|
||||
@ -390,7 +390,7 @@ fn eliminate_multiple_continue_blocks(builder: &mut Builder, header: Word) -> Wo
|
||||
for block in get_blocks_ref(builder) {
|
||||
let block_id = block.label_id().unwrap();
|
||||
if ends_in_branch(block) {
|
||||
let edge = outgoing_edges(block)[0];
|
||||
let edge = outgoing_edges(block).next().unwrap();
|
||||
if edge == header && block_is_parent_of(builder, header, block_id) {
|
||||
continue_blocks.push(block_id);
|
||||
}
|
||||
@ -421,7 +421,7 @@ fn block_leads_into_break(builder: &Builder, cf_info: &ControlFlowInfo, start: W
|
||||
|
||||
while let Some(front) = next.pop_front() {
|
||||
let block_idx = find_block_index_from_id(builder, &front);
|
||||
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx));
|
||||
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::<Vec<_>>();
|
||||
|
||||
// Make sure we are not looping.
|
||||
for loop_info in &cf_info.loops {
|
||||
@ -456,7 +456,7 @@ fn block_leads_into_break(builder: &Builder, cf_info: &ControlFlowInfo, start: W
|
||||
|
||||
fn block_leads_into_continue(builder: &Builder, cf_info: &ControlFlowInfo, start: Word) -> bool {
|
||||
let start_idx = find_block_index_from_id(builder, &start);
|
||||
let new_edges = outgoing_edges(get_block_ref!(builder, start_idx));
|
||||
let new_edges = outgoing_edges(get_block_ref!(builder, start_idx)).collect::<Vec<_>>();
|
||||
for loop_info in &cf_info.loops {
|
||||
if new_edges.len() == 1 && loop_info.continue_id == new_edges[0] {
|
||||
return true;
|
||||
@ -485,7 +485,7 @@ fn block_is_reverse_idom_of(
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx));
|
||||
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::<Vec<_>>();
|
||||
|
||||
// Skip loop bodies by jumping to the merge block is we hit a header block.
|
||||
for loop_info in &cf_info.loops {
|
||||
@ -552,7 +552,7 @@ fn block_is_parent_of(builder: &Builder, parent: Word, child: Word) -> bool {
|
||||
|
||||
while let Some(front) = next.pop_front() {
|
||||
let block_idx = find_block_index_from_id(builder, &front);
|
||||
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx));
|
||||
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::<Vec<_>>();
|
||||
|
||||
for id in &processed {
|
||||
if let Some(i) = new_edges.iter().position(|x| x == id) {
|
||||
@ -589,7 +589,7 @@ fn get_looping_branch_from_block(
|
||||
}
|
||||
|
||||
let block_idx = find_block_index_from_id(builder, &front);
|
||||
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx));
|
||||
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::<Vec<_>>();
|
||||
|
||||
let edge_it = new_edges.iter().find(|&x| x == &start); // Check if the new_edges contain the start
|
||||
if new_edges.len() == 1 {
|
||||
@ -598,8 +598,8 @@ fn get_looping_branch_from_block(
|
||||
let start_idx = find_block_index_from_id(builder, &front);
|
||||
let start_edges = outgoing_edges(get_block_ref!(builder, start_idx));
|
||||
|
||||
for (i, start_edge) in start_edges.iter().enumerate() {
|
||||
if start_edge == edge_it || block_is_parent_of(builder, *start_edge, *edge_it) {
|
||||
for (i, start_edge) in start_edges.enumerate() {
|
||||
if start_edge == *edge_it || block_is_parent_of(builder, start_edge, *edge_it) {
|
||||
return Some(i);
|
||||
}
|
||||
}
|
||||
@ -696,7 +696,7 @@ pub fn insert_selection_merge_on_conditional_branch(
|
||||
};
|
||||
|
||||
let bi = find_block_index_from_id(builder, id);
|
||||
let out = outgoing_edges(&get_blocks_ref(builder)[bi]);
|
||||
let out = outgoing_edges(&get_blocks_ref(builder)[bi]).collect::<Vec<_>>();
|
||||
let id = idx_to_id(builder, bi);
|
||||
let a_nexts = get_possible_merge_positions(builder, cf_info, out[0]);
|
||||
let b_nexts = get_possible_merge_positions(builder, cf_info, out[1]);
|
||||
@ -809,7 +809,7 @@ pub fn insert_loop_merge_on_conditional_branch(
|
||||
|
||||
let merge_branch_idx = (looping_branch_idx + 1) % 2;
|
||||
let bi = find_block_index_from_id(builder, &id);
|
||||
let out = outgoing_edges(&get_blocks_ref(builder)[bi]);
|
||||
let out = outgoing_edges(&get_blocks_ref(builder)[bi]).collect::<Vec<_>>();
|
||||
|
||||
let continue_block_id = eliminate_multiple_continue_blocks(builder, id);
|
||||
let merge_block_id = out[merge_branch_idx];
|
||||
|
@ -242,14 +242,12 @@ OpSelectionMerge %24 None
|
||||
OpBranchConditional %22 %25 %26
|
||||
%25 = OpLabel
|
||||
%27 = OpIMul %2 %28 %14
|
||||
%29 = OpIAdd %2 %27 %5
|
||||
%30 = OpIAdd %10 %9 %31
|
||||
%15 = OpIAdd %2 %27 %5
|
||||
%12 = OpIAdd %10 %9 %29
|
||||
OpBranch %24
|
||||
%26 = OpLabel
|
||||
OpReturnValue %14
|
||||
%24 = OpLabel
|
||||
%12 = OpPhi %10 %30 %25
|
||||
%15 = OpPhi %2 %29 %25
|
||||
%19 = OpPhi %17 %18 %25
|
||||
OpBranch %13
|
||||
%13 = OpLabel
|
||||
|
Loading…
Reference in New Issue
Block a user