mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-25 00:04:11 +00:00
parent
f58c6f20af
commit
c10a1ca756
@ -266,6 +266,7 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
|
||||
for func in &mut output.functions {
|
||||
peephole_opts::composite_construct(&types, func);
|
||||
peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
|
||||
peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
use super::id;
|
||||
use rspirv::dr::{Function, Instruction, Module, ModuleHeader, Operand};
|
||||
use rspirv::spirv::{Op, Word};
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
|
||||
use rustc_middle::bug;
|
||||
|
||||
pub fn collect_types(module: &Module) -> FxHashMap<Word, Instruction> {
|
||||
@ -447,3 +448,168 @@ pub fn vector_ops(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn can_fuse_bool(
|
||||
types: &FxHashMap<Word, Instruction>,
|
||||
defs: &FxHashMap<Word, (usize, Instruction)>,
|
||||
inst: &Instruction,
|
||||
) -> bool {
|
||||
fn constant_value(types: &FxHashMap<Word, Instruction>, val: Word) -> Option<u32> {
|
||||
let inst = match types.get(&val) {
|
||||
None => return None,
|
||||
Some(inst) => inst,
|
||||
};
|
||||
if inst.class.opcode != Op::Constant {
|
||||
return None;
|
||||
}
|
||||
match inst.operands[0] {
|
||||
Operand::LiteralInt32(v) => Some(v),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn visit(
|
||||
types: &FxHashMap<Word, Instruction>,
|
||||
defs: &FxHashMap<Word, (usize, Instruction)>,
|
||||
visited: &mut FxHashSet<Word>,
|
||||
value: Word,
|
||||
) -> bool {
|
||||
if visited.insert(value) {
|
||||
let inst = match defs.get(&value) {
|
||||
Some((_, inst)) => inst,
|
||||
None => return false,
|
||||
};
|
||||
match inst.class.opcode {
|
||||
Op::Select => {
|
||||
constant_value(types, inst.operands[1].unwrap_id_ref()) == Some(1)
|
||||
&& constant_value(types, inst.operands[2].unwrap_id_ref()) == Some(0)
|
||||
}
|
||||
Op::Phi => inst
|
||||
.operands
|
||||
.iter()
|
||||
.step_by(2)
|
||||
.all(|op| visit(types, defs, visited, op.unwrap_id_ref())),
|
||||
_ => false,
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
if inst.class.opcode != Op::INotEqual
|
||||
|| constant_value(types, inst.operands[1].unwrap_id_ref()) != Some(0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
let int_value = inst.operands[0].unwrap_id_ref();
|
||||
|
||||
visit(types, defs, &mut FxHashSet::default(), int_value)
|
||||
}
|
||||
|
||||
fn fuse_bool(
|
||||
header: &mut ModuleHeader,
|
||||
defs: &FxHashMap<Word, (usize, Instruction)>,
|
||||
phis_to_insert: &mut Vec<(usize, Instruction)>,
|
||||
already_mapped: &mut FxHashMap<Word, Word>,
|
||||
bool_ty: Word,
|
||||
int_value: Word,
|
||||
) -> Word {
|
||||
if let Some(&result) = already_mapped.get(&int_value) {
|
||||
return result;
|
||||
}
|
||||
let (block_of_inst, inst) = defs.get(&int_value).unwrap();
|
||||
match inst.class.opcode {
|
||||
Op::Select => inst.operands[0].unwrap_id_ref(),
|
||||
Op::Phi => {
|
||||
let result_id = id(header);
|
||||
already_mapped.insert(int_value, result_id);
|
||||
let new_phi_args = inst
|
||||
.operands
|
||||
.chunks(2)
|
||||
.flat_map(|arr| {
|
||||
let phi_value = &arr[0];
|
||||
let block = &arr[1];
|
||||
[
|
||||
Operand::IdRef(fuse_bool(
|
||||
header,
|
||||
defs,
|
||||
phis_to_insert,
|
||||
already_mapped,
|
||||
bool_ty,
|
||||
phi_value.unwrap_id_ref(),
|
||||
)),
|
||||
block.clone(),
|
||||
]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let inst = Instruction::new(Op::Phi, Some(bool_ty), Some(result_id), new_phi_args);
|
||||
phis_to_insert.push((*block_of_inst, inst));
|
||||
result_id
|
||||
}
|
||||
_ => bug!("can_fuse_bool should have prevented this case"),
|
||||
}
|
||||
}
|
||||
|
||||
// The compiler generates a lot of code that looks like this:
|
||||
// %v_int = OpSelect %int %v %const_1 %const_0
|
||||
// %v2 = OpINotEqual %bool %v_int %const_0
|
||||
// (This is due to rustc/spirv not supporting bools in memory, and needing to convert to u8, but
|
||||
// then things get inlined/mem2reg'd)
|
||||
//
|
||||
// This pass fuses together those two instructions to strip out the intermediate integer variable.
|
||||
// The purpose is to make simple code that doesn't actually do memory-stuff with bools not require
|
||||
// the Int8 capability (and so we can't rely on spirv-opt to do this same pass).
|
||||
//
|
||||
// Unfortunately, things get complicated because of phis: the majority of actually useful cases to
|
||||
// do this pass need to track pseudo-bool ints through phi instructions.
|
||||
//
|
||||
// The logic goes like:
|
||||
// 1) Figure out what we *can* fuse. This means finding OpINotEqual instructions (converting back
|
||||
// from int->bool) and tracing the value back recursively through any phis, and making sure each
|
||||
// one terminates in either a loop back around to something we've already seen, or an OpSelect
|
||||
// (converting from bool->int).
|
||||
// 2) Do the fusion. Trace back through phis, generating a second bool-typed phi alongside the
|
||||
// original int-typed phi, and when hitting an OpSelect, taking the bool value directly.
|
||||
// 3) DCE the dead OpSelects/int-typed OpPhis (done in a later pass). We don't nuke them here,
|
||||
// since they might be used elsewhere, and don't want to accidentally leave a dangling
|
||||
// reference.
|
||||
pub fn bool_fusion(
|
||||
header: &mut ModuleHeader,
|
||||
types: &FxHashMap<Word, Instruction>,
|
||||
function: &mut Function,
|
||||
) {
|
||||
let defs: FxHashMap<Word, (usize, Instruction)> = function
|
||||
.blocks
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(block_id, block)| {
|
||||
block
|
||||
.instructions
|
||||
.iter()
|
||||
.filter_map(move |inst| Some((inst.result_id?, (block_id, inst.clone()))))
|
||||
})
|
||||
.collect();
|
||||
let mut rewrite_rules = FxHashMap::default();
|
||||
let mut phis_to_insert = Default::default();
|
||||
let mut already_mapped = Default::default();
|
||||
for block in &mut function.blocks {
|
||||
for inst in &mut block.instructions {
|
||||
if can_fuse_bool(types, &defs, inst) {
|
||||
let rewrite_to = fuse_bool(
|
||||
header,
|
||||
&defs,
|
||||
&mut phis_to_insert,
|
||||
&mut already_mapped,
|
||||
inst.result_type.unwrap(),
|
||||
inst.operands[0].unwrap_id_ref(),
|
||||
);
|
||||
rewrite_rules.insert(inst.result_id.unwrap(), rewrite_to);
|
||||
*inst = Instruction::new(Op::Nop, None, None, Vec::new());
|
||||
}
|
||||
}
|
||||
}
|
||||
for (block, phi) in phis_to_insert {
|
||||
function.blocks[block].instructions.insert(0, phi);
|
||||
}
|
||||
super::apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
|
||||
}
|
||||
|
12
tests/ui/lang/core/unwrap_or.rs
Normal file
12
tests/ui/lang/core/unwrap_or.rs
Normal file
@ -0,0 +1,12 @@
|
||||
// unwrap_or generates some memory-bools (as u8). Test to make sure they're fused away.
|
||||
// OpINotEqual, as well as %bool, should not appear in the output.
|
||||
|
||||
// build-pass
|
||||
// compile-flags: -C llvm-args=--disassemble-entry=main
|
||||
|
||||
use spirv_std as _;
|
||||
|
||||
#[spirv(fragment)] // Marks `main` as a fragment-shader entry point; it is the entry the compile-flags above disassemble.
|
||||
pub fn main(out: &mut u32) { // NOTE(review): unwrap_or.stderr appears to pin this file's line numbers via OpLine, so keep comments on existing lines - confirm.
|
||||
*out = None.unwrap_or(15); // `None` holds no value, so the fallback 15 is what gets written through `out`.
|
||||
}
|
39
tests/ui/lang/core/unwrap_or.stderr
Normal file
39
tests/ui/lang/core/unwrap_or.stderr
Normal file
@ -0,0 +1,39 @@
|
||||
%1 = OpFunction %2 None %3
|
||||
%4 = OpLabel
|
||||
OpLine %5 11 11
|
||||
%6 = OpCompositeInsert %7 %8 %9 0
|
||||
OpLine %5 11 11
|
||||
%10 = OpCompositeExtract %11 %6 1
|
||||
OpLine %12 767 14
|
||||
%13 = OpBitcast %14 %8
|
||||
OpLine %12 767 8
|
||||
OpSelectionMerge %15 None
|
||||
OpSwitch %13 %16 0 %17 1 %18
|
||||
%16 = OpLabel
|
||||
OpLine %12 767 14
|
||||
OpUnreachable
|
||||
%17 = OpLabel
|
||||
OpLine %12 769 20
|
||||
OpBranch %15
|
||||
%18 = OpLabel
|
||||
OpLine %12 771 4
|
||||
OpBranch %15
|
||||
%15 = OpLabel
|
||||
%19 = OpPhi %20 %21 %17 %22 %18
|
||||
%23 = OpPhi %11 %24 %17 %10 %18
|
||||
OpBranch %25
|
||||
%25 = OpLabel
|
||||
OpLine %12 771 4
|
||||
OpSelectionMerge %26 None
|
||||
OpBranchConditional %19 %27 %28
|
||||
%27 = OpLabel
|
||||
OpLine %12 771 4
|
||||
OpBranch %26
|
||||
%28 = OpLabel
|
||||
OpBranch %26
|
||||
%26 = OpLabel
|
||||
OpLine %5 11 4
|
||||
OpStore %29 %23
|
||||
OpLine %5 12 1
|
||||
OpReturn
|
||||
OpFunctionEnd
|
Loading…
Reference in New Issue
Block a user