Implement bool fusion pass (#776)

Fixes #677
This commit is contained in:
Ashley Hauck 2021-10-27 14:52:31 +02:00 committed by GitHub
parent f58c6f20af
commit c10a1ca756
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 219 additions and 1 deletions

View File

@ -266,6 +266,7 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
for func in &mut output.functions {
peephole_opts::composite_construct(&types, func);
peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
}
}

View File

@ -1,6 +1,7 @@
use super::id;
use rspirv::dr::{Function, Instruction, Module, ModuleHeader, Operand};
use rspirv::spirv::{Op, Word};
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_middle::bug;
pub fn collect_types(module: &Module) -> FxHashMap<Word, Instruction> {
@ -447,3 +448,168 @@ pub fn vector_ops(
}
}
}
fn can_fuse_bool(
types: &FxHashMap<Word, Instruction>,
defs: &FxHashMap<Word, (usize, Instruction)>,
inst: &Instruction,
) -> bool {
fn constant_value(types: &FxHashMap<Word, Instruction>, val: Word) -> Option<u32> {
let inst = match types.get(&val) {
None => return None,
Some(inst) => inst,
};
if inst.class.opcode != Op::Constant {
return None;
}
match inst.operands[0] {
Operand::LiteralInt32(v) => Some(v),
_ => None,
}
}
fn visit(
types: &FxHashMap<Word, Instruction>,
defs: &FxHashMap<Word, (usize, Instruction)>,
visited: &mut FxHashSet<Word>,
value: Word,
) -> bool {
if visited.insert(value) {
let inst = match defs.get(&value) {
Some((_, inst)) => inst,
None => return false,
};
match inst.class.opcode {
Op::Select => {
constant_value(types, inst.operands[1].unwrap_id_ref()) == Some(1)
&& constant_value(types, inst.operands[2].unwrap_id_ref()) == Some(0)
}
Op::Phi => inst
.operands
.iter()
.step_by(2)
.all(|op| visit(types, defs, visited, op.unwrap_id_ref())),
_ => false,
}
} else {
true
}
}
if inst.class.opcode != Op::INotEqual
|| constant_value(types, inst.operands[1].unwrap_id_ref()) != Some(0)
{
return false;
}
let int_value = inst.operands[0].unwrap_id_ref();
visit(types, defs, &mut FxHashSet::default(), int_value)
}
fn fuse_bool(
header: &mut ModuleHeader,
defs: &FxHashMap<Word, (usize, Instruction)>,
phis_to_insert: &mut Vec<(usize, Instruction)>,
already_mapped: &mut FxHashMap<Word, Word>,
bool_ty: Word,
int_value: Word,
) -> Word {
if let Some(&result) = already_mapped.get(&int_value) {
return result;
}
let (block_of_inst, inst) = defs.get(&int_value).unwrap();
match inst.class.opcode {
Op::Select => inst.operands[0].unwrap_id_ref(),
Op::Phi => {
let result_id = id(header);
already_mapped.insert(int_value, result_id);
let new_phi_args = inst
.operands
.chunks(2)
.flat_map(|arr| {
let phi_value = &arr[0];
let block = &arr[1];
[
Operand::IdRef(fuse_bool(
header,
defs,
phis_to_insert,
already_mapped,
bool_ty,
phi_value.unwrap_id_ref(),
)),
block.clone(),
]
})
.collect::<Vec<_>>();
let inst = Instruction::new(Op::Phi, Some(bool_ty), Some(result_id), new_phi_args);
phis_to_insert.push((*block_of_inst, inst));
result_id
}
_ => bug!("can_fuse_bool should have prevented this case"),
}
}
// The compiler generates a lot of code that looks like this:
// %v_int = OpSelect %int %v %const_1 %const_0
// %v2 = OpINotEqual %bool %v_int %const_0
// (This is due to rustc/spirv not supporting bools in memory, and needing to convert to u8, but
// then things get inlined/mem2reg'd)
//
// This pass fuses together those two instructions to strip out the intermediate integer variable.
// The purpose is to make simple code that doesn't actually do memory-stuff with bools not require
// the Int8 capability (and so we can't rely on spirv-opt to do this same pass).
//
// Unfortunately, things get complicated because of phis: the majority of actually useful cases to
// do this pass need to track pseudo-bool ints through phi instructions.
//
// The logic goes like:
// 1) Figure out what we *can* fuse. This means finding OpINotEqual instructions (converting back
// from int->bool) and tracing the value back recursively through any phis, and making sure each
// one terminates in either a loop back around to something we've already seen, or an OpSelect
// (converting from bool->int).
// 2) Do the fusion. Trace back through phis, generating a second bool-typed phi alongside the
// original int-typed phi, and when hitting an OpSelect, taking the bool value directly.
// 3) DCE the dead OpSelects/int-typed OpPhis (done in a later pass). We don't nuke them here,
// since they might be used elsewhere, and don't want to accidentally leave a dangling
// reference.
pub fn bool_fusion(
header: &mut ModuleHeader,
types: &FxHashMap<Word, Instruction>,
function: &mut Function,
) {
let defs: FxHashMap<Word, (usize, Instruction)> = function
.blocks
.iter()
.enumerate()
.flat_map(|(block_id, block)| {
block
.instructions
.iter()
.filter_map(move |inst| Some((inst.result_id?, (block_id, inst.clone()))))
})
.collect();
let mut rewrite_rules = FxHashMap::default();
let mut phis_to_insert = Default::default();
let mut already_mapped = Default::default();
for block in &mut function.blocks {
for inst in &mut block.instructions {
if can_fuse_bool(types, &defs, inst) {
let rewrite_to = fuse_bool(
header,
&defs,
&mut phis_to_insert,
&mut already_mapped,
inst.result_type.unwrap(),
inst.operands[0].unwrap_id_ref(),
);
rewrite_rules.insert(inst.result_id.unwrap(), rewrite_to);
*inst = Instruction::new(Op::Nop, None, None, Vec::new());
}
}
}
for (block, phi) in phis_to_insert {
function.blocks[block].instructions.insert(0, phi);
}
super::apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
}

View File

@ -0,0 +1,12 @@
// unwrap_or generates some memory-bools (as u8). Test to make sure they're fused away.
// OpINotEqual, as well as %bool, should not appear in the output.
// build-pass
// compile-flags: -C llvm-args=--disassemble-entry=main
use spirv_std as _;
#[spirv(fragment)]
pub fn main(out: &mut u32) {
*out = None.unwrap_or(15);
}

View File

@ -0,0 +1,39 @@
%1 = OpFunction %2 None %3
%4 = OpLabel
OpLine %5 11 11
%6 = OpCompositeInsert %7 %8 %9 0
OpLine %5 11 11
%10 = OpCompositeExtract %11 %6 1
OpLine %12 767 14
%13 = OpBitcast %14 %8
OpLine %12 767 8
OpSelectionMerge %15 None
OpSwitch %13 %16 0 %17 1 %18
%16 = OpLabel
OpLine %12 767 14
OpUnreachable
%17 = OpLabel
OpLine %12 769 20
OpBranch %15
%18 = OpLabel
OpLine %12 771 4
OpBranch %15
%15 = OpLabel
%19 = OpPhi %20 %21 %17 %22 %18
%23 = OpPhi %11 %24 %17 %10 %18
OpBranch %25
%25 = OpLabel
OpLine %12 771 4
OpSelectionMerge %26 None
OpBranchConditional %19 %27 %28
%27 = OpLabel
OpLine %12 771 4
OpBranch %26
%28 = OpLabel
OpBranch %26
%26 = OpLabel
OpLine %5 11 4
OpStore %29 %23
OpLine %5 12 1
OpReturn
OpFunctionEnd