Make mem2reg support OpAccessChain

This commit is contained in:
khyperia 2020-10-13 18:00:04 +02:00
parent b7aa6f310d
commit d9b64497fc
4 changed files with 326 additions and 181 deletions

View File

@ -1,3 +1,9 @@
//! This algorithm is not intended to be an optimization, it is rather for legalization.
//! Specifically, spir-v disallows things like a StorageClass::Function pointer to a
//! StorageClass::Input pointer. Our frontend definitely allows it, though, this is like taking a
//! `&Input<T>` in a function! So, we inline all functions that take these "illegal" pointers, then
//! run mem2reg (see mem2reg.rs) on the result to "unwrap" the Function pointer.
use crate::mem2reg::compute_preds;
use crate::{apply_rewrite_rules, operand_idref};
use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};

View File

@ -185,26 +185,46 @@ pub fn link<T>(
{
let _timer = timer("link_block_ordering_pass_and_mem2reg");
let pointer_to_pointee = if opts.mem2reg {
output
.types_global_values
.iter()
.filter(|inst| inst.class.opcode == Op::TypePointer)
.map(|inst| {
(
inst.result_id.unwrap(),
operand_idref(&inst.operands[1]).unwrap(),
)
})
.collect()
} else {
Default::default()
};
let mut pointer_to_pointee = HashMap::new();
let mut constants = HashMap::new();
if opts.mem2reg {
let mut u32 = None;
for inst in &output.types_global_values {
match inst.class.opcode {
Op::TypePointer => {
pointer_to_pointee.insert(
inst.result_id.unwrap(),
operand_idref(&inst.operands[1]).unwrap(),
);
}
Op::TypeInt
if inst.operands[0] == Operand::LiteralInt32(32)
&& inst.operands[1] == Operand::LiteralInt32(0) =>
{
assert!(u32.is_none());
u32 = Some(inst.result_id.unwrap());
}
Op::Constant if u32.is_some() && inst.result_type == u32 => {
let value = match inst.operands[0] {
Operand::LiteralInt32(value) => value,
_ => panic!(),
};
constants.insert(inst.result_id.unwrap(), value);
}
_ => {}
}
}
}
for func in &mut output.functions {
simple_passes::block_ordering_pass(func);
if opts.mem2reg {
// Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
mem2reg::mem2reg(output.header.as_mut().unwrap(), &pointer_to_pointee, func);
mem2reg::mem2reg(
output.header.as_mut().unwrap(),
&pointer_to_pointee,
&constants,
func,
);
}
}
}

View File

@ -1,13 +1,24 @@
//! This algorithm is not intended to be an optimization, it is rather for legalization.
//! Specifically, spir-v disallows things like a StorageClass::Function pointer to a
//! StorageClass::Input pointer. Our frontend definitely allows it, though, this is like taking a
//! `&Input<T>` in a function! So, we inline all functions (see inline.rs) that take these
//! "illegal" pointers, then run mem2reg on the result to "unwrap" the Function pointer.
//!
//! Because it's merely a legalization pass, this computes "minimal" SSA form, *not* "pruned" SSA
//! form. The difference is that "minimal" may include extra phi nodes that aren't actually used
//! anywhere - we assume that later optimization passes will take care of these (relying on what
//! wikipedia calls "treat pruning as a dead code elimination problem").
use crate::simple_passes::outgoing_edges;
use crate::{apply_rewrite_rules, id, label_of, operand_idref};
use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand};
use rspirv::spirv::{Op, Word};
use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::{hash_map, HashMap, HashSet};
pub fn mem2reg(
header: &mut ModuleHeader,
pointer_to_pointee: &HashMap<Word, Word>,
constants: &HashMap<Word, u32>,
func: &mut Function,
) {
let preds = compute_preds(&func.blocks);
@ -16,6 +27,7 @@ pub fn mem2reg(
insert_phis_all(
header,
pointer_to_pointee,
constants,
&mut func.blocks,
dominance_frontier,
);
@ -96,6 +108,7 @@ fn compute_dominance_frontier(preds: &[Vec<usize>], idom: &[usize]) -> Vec<HashS
fn insert_phis_all(
header: &mut ModuleHeader,
pointer_to_pointee: &HashMap<Word, Word>,
constants: &HashMap<Word, u32>,
blocks: &mut [Block],
dominance_frontier: Vec<HashSet<usize>>,
) {
@ -105,206 +118,281 @@ fn insert_phis_all(
.filter(|inst| inst.class.opcode == Op::Variable)
.filter_map(|inst| {
let var = inst.result_id.unwrap();
if is_promotable(blocks, var) {
let var_type = *pointer_to_pointee.get(&inst.result_type.unwrap()).unwrap();
Some((var, var_type))
} else {
None
}
let var_ty = *pointer_to_pointee.get(&inst.result_type.unwrap()).unwrap();
Some((
collect_access_chains(pointer_to_pointee, constants, blocks, var, var_ty)?,
var_ty,
))
})
.collect::<Vec<_>>();
for &(var, var_type) in &thing {
insert_phis(header, blocks, &dominance_frontier, var, var_type);
for &(ref var_map, base_var_type) in &thing {
let blocks_with_phi = insert_phis(blocks, &dominance_frontier, var_map);
let mut renamer = Renamer {
header,
blocks,
blocks_with_phi,
base_var_type,
var_map,
phi_defs: HashSet::new(),
visited: HashSet::new(),
stack: Vec::new(),
rewrite_rules: HashMap::new(),
};
renamer.rename(0, None);
apply_rewrite_rules(&renamer.rewrite_rules, blocks);
remove_nops(blocks);
}
blocks[0].instructions.retain(|inst| {
inst.class.opcode != Op::Variable || {
let result_id = inst.result_id.unwrap();
thing.iter().all(|&(var, _)| var != result_id)
}
});
remove_old_variables(blocks, &thing);
}
fn is_promotable(blocks: &[Block], var: Word) -> bool {
for block in blocks {
for inst in &block.instructions {
#[derive(Debug)]
struct VarInfo {
// Type of the *dereferenced* variable.
ty: Word,
// OpAccessChain indexes off the base variable
indices: Vec<u32>,
}
fn collect_access_chains(
pointer_to_pointee: &HashMap<Word, Word>,
constants: &HashMap<Word, u32>,
blocks: &[Block],
base_var: Word,
base_var_ty: Word,
) -> Option<HashMap<Word, VarInfo>> {
fn construct_access_chain_info(
pointer_to_pointee: &HashMap<Word, Word>,
constants: &HashMap<Word, u32>,
inst: &Instruction,
base: &VarInfo,
) -> Option<VarInfo> {
Some(VarInfo {
ty: *pointer_to_pointee.get(&inst.result_type.unwrap()).unwrap(),
indices: {
let mut base_indicies = base.indices.clone();
for op in inst.operands.iter().skip(1) {
base_indicies.push(*constants.get(&operand_idref(op).unwrap())?)
}
base_indicies
},
})
}
let mut variables = HashMap::new();
variables.insert(
base_var,
VarInfo {
ty: base_var_ty,
indices: vec![],
},
);
// Loop in case a previous block references a later AccessChain
loop {
let mut changed = false;
for inst in blocks.iter().flat_map(|b| &b.instructions) {
for op in &inst.operands {
if let Operand::IdRef(id) = *op {
if id == var {
if let Operand::IdRef(id) = op {
if variables.contains_key(id) {
match inst.class.opcode {
Op::Load | Op::Store => {}
_ => return false,
Op::Load | Op::Store | Op::AccessChain | Op::InBoundsAccessChain => {}
_ => return None,
}
}
}
}
if let Op::AccessChain | Op::InBoundsAccessChain = inst.class.opcode {
if let Some(base) = variables.get(&operand_idref(&inst.operands[0]).unwrap()) {
let info =
construct_access_chain_info(pointer_to_pointee, constants, inst, base)?;
match variables.entry(inst.result_id.unwrap()) {
hash_map::Entry::Vacant(entry) => {
entry.insert(info);
changed = true;
}
hash_map::Entry::Occupied(_) => {}
}
}
}
}
if !changed {
break;
}
}
true
Some(variables)
}
// Returns the value for the definition.
fn find_last_store(block: &Block, var: Word) -> Option<Word> {
block.instructions.iter().rev().find_map(|inst| {
if inst.class.opcode == Op::Store && inst.operands[0] == Operand::IdRef(var)
|| inst.class.opcode == Op::Variable
&& inst.result_id == Some(var)
&& inst.operands.len() > 1
{
Some(operand_idref(&inst.operands[1]).unwrap())
} else {
None
}
fn has_store(block: &Block, var_map: &HashMap<Word, VarInfo>) -> bool {
block.instructions.iter().any(|inst| {
let ptr = match inst.class.opcode {
Op::Store => operand_idref(&inst.operands[0]).unwrap(),
Op::Variable if inst.operands.len() < 2 => return false,
Op::Variable => inst.result_id.unwrap(),
_ => return false,
};
var_map.contains_key(&ptr)
})
}
fn insert_phis(
header: &mut ModuleHeader,
blocks: &mut [Block],
blocks: &[Block],
dominance_frontier: &[HashSet<usize>],
var: Word,
var_type: Word,
) {
var_map: &HashMap<Word, VarInfo>,
) -> HashSet<usize> {
// TODO: Some algorithms check if the var is trivial in some way, e.g. all loads and stores are
// in a single block. We should probably do that too.
let mut ever_on_work_list = HashSet::new();
let mut work_list = Vec::new();
let mut phi_defs = HashSet::new();
let mut blocks_with_phi = HashSet::new();
for (block_idx, block) in blocks.iter().enumerate() {
if let Some(def) = find_last_store(block, var) {
if has_store(block, var_map) {
ever_on_work_list.insert(block_idx);
work_list.push((block_idx, def));
work_list.push(block_idx);
}
}
while let Some((x, def)) = work_list.pop() {
while let Some(x) = work_list.pop() {
for &y in &dominance_frontier[x] {
if let Some(new_def) = insert_phi(header, blocks, y, &mut phi_defs, var_type, x, def) {
if ever_on_work_list.insert(y) {
work_list.push((y, new_def))
if blocks_with_phi.insert(y) && ever_on_work_list.insert(y) {
work_list.push(y)
}
}
}
blocks_with_phi
}
struct Renamer<'a> {
header: &'a mut ModuleHeader,
blocks: &'a mut [Block],
blocks_with_phi: HashSet<usize>,
base_var_type: Word,
var_map: &'a HashMap<Word, VarInfo>,
phi_defs: HashSet<Word>,
visited: HashSet<usize>,
stack: Vec<Word>,
rewrite_rules: HashMap<Word, Word>,
}
impl Renamer<'_> {
// Returns the phi definition.
fn insert_phi_value(&mut self, block: usize, from_block: usize) -> Word {
let from_block_label = label_of(&self.blocks[from_block]);
let phi_defs = &self.phi_defs;
let existing_phi = self.blocks[block].instructions.iter_mut().find(|inst| {
inst.class.opcode == Op::Phi && phi_defs.contains(&inst.result_id.unwrap())
});
let top_def = *self.stack.last().unwrap();
match existing_phi {
None => {
let new_id = id(self.header);
self.blocks[block].instructions.insert(
0,
Instruction::new(
Op::Phi,
Some(self.base_var_type),
Some(new_id),
vec![Operand::IdRef(top_def), Operand::IdRef(from_block_label)],
),
);
self.phi_defs.insert(new_id);
new_id
}
Some(existing_phi) => {
existing_phi.operands.extend_from_slice(&[
Operand::IdRef(top_def),
Operand::IdRef(from_block_label),
]);
existing_phi.result_id.unwrap()
}
}
}
fn rename(&mut self, block: usize, from_block: Option<usize>) {
let original_stack = self.stack.len();
if let Some(from_block) = from_block {
if self.blocks_with_phi.contains(&block) {
let new_top = self.insert_phi_value(block, from_block);
self.stack.push(new_top);
}
}
if !self.visited.insert(block) {
while self.stack.len() > original_stack {
self.stack.pop();
}
return;
}
for inst in &mut self.blocks[block].instructions {
if inst.class.opcode == Op::Variable && inst.operands.len() > 1 {
let ptr = inst.result_id.unwrap();
let val = operand_idref(&inst.operands[1]).unwrap();
if let Some(var_info) = self.var_map.get(&ptr) {
assert_eq!(var_info.indices, []);
self.stack.push(val);
}
} else if inst.class.opcode == Op::Store {
let ptr = operand_idref(&inst.operands[0]).unwrap();
let val = operand_idref(&inst.operands[1]).unwrap();
if let Some(var_info) = self.var_map.get(&ptr) {
if var_info.indices.is_empty() {
*inst = Instruction::new(Op::Nop, None, None, vec![]);
self.stack.push(val);
} else {
let new_id = id(self.header);
let prev_comp = *self.stack.last().unwrap();
let mut operands = vec![Operand::IdRef(val), Operand::IdRef(prev_comp)];
operands
.extend(var_info.indices.iter().copied().map(Operand::LiteralInt32));
*inst = Instruction::new(
Op::CompositeInsert,
Some(self.base_var_type),
Some(new_id),
operands,
);
self.stack.push(new_id);
}
}
} else if inst.class.opcode == Op::Load {
let ptr = operand_idref(&inst.operands[0]).unwrap();
if let Some(var_info) = self.var_map.get(&ptr) {
let loaded_val = inst.result_id.unwrap();
let current_obj = *self.stack.last().unwrap();
if var_info.indices.is_empty() {
*inst = Instruction::new(Op::Nop, None, None, vec![]);
self.rewrite_rules.insert(loaded_val, current_obj);
} else {
let new_id = id(self.header);
let mut operands = vec![Operand::IdRef(current_obj)];
operands
.extend(var_info.indices.iter().copied().map(Operand::LiteralInt32));
*inst = Instruction::new(
Op::CompositeExtract,
Some(var_info.ty),
Some(new_id),
operands,
);
self.rewrite_rules.insert(loaded_val, new_id);
}
}
}
}
}
let mut rewrite_rules = HashMap::new();
rename(
header,
blocks,
0,
&phi_defs,
var,
&mut HashSet::new(),
&mut Vec::new(),
&mut rewrite_rules,
);
apply_rewrite_rules(&rewrite_rules, blocks);
remove_nops(blocks);
}
// Returns the newly created phi definition.
fn insert_phi(
header: &mut ModuleHeader,
blocks: &mut [Block],
block: usize,
phi_defs: &mut HashSet<Word>,
var_type: Word,
from_block: usize,
def: Word,
) -> Option<Word> {
let from_block_label = label_of(&blocks[from_block]);
let existing_phi = blocks[block]
.instructions
.iter_mut()
.find(|inst| inst.class.opcode == Op::Phi && phi_defs.contains(&inst.result_id.unwrap()));
match existing_phi {
None => {
let new_id = id(header);
blocks[block].instructions.insert(
0,
Instruction::new(
Op::Phi,
Some(var_type),
Some(new_id),
vec![Operand::IdRef(def), Operand::IdRef(from_block_label)],
),
);
phi_defs.insert(new_id);
Some(new_id)
for dest_id in outgoing_edges(&self.blocks[block]) {
// TODO: Don't do this find
let dest_idx = self
.blocks
.iter()
.position(|b| label_of(b) == dest_id)
.unwrap();
self.rename(dest_idx, Some(block));
}
Some(existing_phi) => {
existing_phi
.operands
.extend_from_slice(&[Operand::IdRef(def), Operand::IdRef(from_block_label)]);
None
while self.stack.len() > original_stack {
self.stack.pop();
}
}
}
#[allow(clippy::too_many_arguments)]
fn rename(
header: &mut ModuleHeader,
blocks: &mut [Block],
block: usize,
phi_defs: &HashSet<Word>,
var: Word,
visited: &mut HashSet<usize>,
stack: &mut Vec<Word>,
rewrite_rules: &mut HashMap<Word, Word>,
) {
if !visited.insert(block) {
return;
}
let original_stack = stack.len();
for inst in &mut blocks[block].instructions {
if inst.class.opcode == Op::Phi {
let result_id = inst.result_id.unwrap();
if phi_defs.contains(&result_id) {
stack.push(result_id);
}
} else if inst.class.opcode == Op::Variable && inst.operands.len() > 1 {
let ptr = inst.result_id.unwrap();
let val = operand_idref(&inst.operands[1]).unwrap();
if ptr == var {
stack.push(val);
}
} else if inst.class.opcode == Op::Store {
let ptr = operand_idref(&inst.operands[0]).unwrap();
let val = operand_idref(&inst.operands[1]).unwrap();
if ptr == var {
stack.push(val);
*inst = Instruction::new(Op::Nop, None, None, vec![]);
}
} else if inst.class.opcode == Op::Load {
let ptr = operand_idref(&inst.operands[0]).unwrap();
let val = inst.result_id.unwrap();
if ptr == var {
rewrite_rules.insert(val, *stack.last().unwrap());
*inst = Instruction::new(Op::Nop, None, None, vec![]);
}
}
}
for dest_id in outgoing_edges(&blocks[block]) {
// TODO: Don't do this find
let dest_idx = blocks.iter().position(|b| label_of(b) == dest_id).unwrap();
rename(
header,
blocks,
dest_idx,
phi_defs,
var,
visited,
stack,
rewrite_rules,
);
}
while stack.len() > original_stack {
stack.pop();
}
}
fn remove_nops(blocks: &mut [Block]) {
for block in blocks {
block
@ -312,3 +400,24 @@ fn remove_nops(blocks: &mut [Block]) {
.retain(|inst| inst.class.opcode != Op::Nop);
}
}
fn remove_old_variables(blocks: &mut [Block], thing: &[(HashMap<u32, VarInfo>, u32)]) {
blocks[0].instructions.retain(|inst| {
inst.class.opcode != Op::Variable || {
let result_id = inst.result_id.unwrap();
thing
.iter()
.all(|(var_map, _)| !var_map.contains_key(&result_id))
}
});
for block in blocks {
block.instructions.retain(|inst| {
!matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
|| inst.operands.iter().all(|op| {
operand_idref(op).map_or(true, |id| {
thing.iter().all(|(var_map, _)| !var_map.contains_key(&id))
})
})
})
}
}

View File

@ -632,6 +632,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
pointee: ty,
}
.def(self);
let undef_ty = self.undef(ty);
// "All OpVariable instructions in a function must be the first instructions in the first block."
let mut builder = self.emit();
builder.select_block(Some(0)).unwrap();
@ -657,7 +658,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
Op::Variable,
Some(ptr_ty),
Some(result_id),
vec![Operand::StorageClass(StorageClass::Function)],
vec![
Operand::StorageClass(StorageClass::Function),
// TODO: Always include an undef initializer, because spir-v does not specify the
// value of an uninitialized variable. So, initialize it to undef.
// See #17 for tracking spec'ing this in spir-v:
// https://github.com/EmbarkStudios/rust-gpu/issues/17
// (This also helps out rspirv-linker/mem2reg.rs in some places - it gives it a
// handy source of a deduped OpUndef so it doesn't have to do that work itself)
Operand::IdRef(undef_ty.def),
],
);
builder.insert_into_block(index, inst).unwrap();
result_id.with_type(ptr_ty)