Aggressively prune no-side-effect instructions during DCE. (#691)

* Aggressively prune no-side-effect instructions during DCE.

Since we're walking all the instructions anyway, it's practically
zero-cost.

* Reverse iteration order within a function.

This allows rooting more instructions per `spread_roots`
invocation, becoming zero-cost in the absence of loops.

* Manually iterate over function instructions in reverse order.
This commit is contained in:
Alex Es 2021-08-04 11:03:38 +03:00 committed by GitHub
parent cccb9737d7
commit d548268140
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 173 additions and 22 deletions

View File

@ -7,8 +7,8 @@
//! *references* a rooted thing is also rooted, not the other way around - but that's the basic
//! concept.
use rspirv::dr::{Function, Instruction, Module};
use rspirv::spirv::{Op, Word};
use rspirv::dr::{Function, Instruction, Module, Operand};
use rspirv::spirv::{Op, StorageClass, Word};
use rustc_data_structures::fx::FxHashSet;
pub fn dce(module: &mut Module) {
@ -36,8 +36,29 @@ fn spread_roots(module: &Module, rooted: &mut FxHashSet<Word>) -> bool {
}
for func in &module.functions {
if rooted.contains(&func.def_id().unwrap()) {
for inst in func.all_inst_iter() {
any |= root(inst, rooted);
// NB (Mobius 2021) - since later insts are much more likely to reference
// earlier insts, by reversing the iteration order, we're more likely to root the
// entire relevant function at once.
// See https://github.com/EmbarkStudios/rust-gpu/pull/691#discussion_r681477091
for inst in func
.end
.iter()
.chain(
func.blocks
.iter()
.rev()
.flat_map(|b| b.instructions.iter().rev().chain(b.label.iter())),
)
.chain(func.parameters.iter().rev())
.chain(func.def.iter())
{
if !instruction_is_pure(inst) {
any |= root(inst, rooted);
} else if let Some(id) = inst.result_id {
if rooted.contains(&id) {
any |= root(inst, rooted);
}
}
}
}
}
@ -90,6 +111,13 @@ fn kill_unrooted(module: &mut Module, rooted: &FxHashSet<Word>) {
module
.functions
.retain(|f| is_rooted(f.def.as_ref().unwrap(), rooted));
module.functions.iter_mut().for_each(|fun| {
fun.blocks.iter_mut().for_each(|block| {
block
.instructions
.retain(|inst| !instruction_is_pure(inst) || is_rooted(inst, rooted));
});
});
}
pub fn dce_phi(func: &mut Function) {
@ -115,3 +143,127 @@ pub fn dce_phi(func: &mut Function) {
.retain(|inst| inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()));
}
}
/// Reports whether `inst` is on the allow-list of opcodes that this pass
/// treats as having no side effects.
///
/// `OpVariable` is the one conditional entry: it counts as pure only when its
/// first operand is `StorageClass::Function`. Every opcode that is not listed
/// here is reported as impure.
fn instruction_is_pure(inst: &Instruction) -> bool {
    use Op::*;

    // Variables are decided by their storage-class operand rather than by
    // opcode alone, so settle them before consulting the allow-list.
    if inst.class.opcode == Variable {
        return inst.operands.first() == Some(&Operand::StorageClass(StorageClass::Function));
    }

    matches!(
        inst.class.opcode,
        Nop
            | Undef
            // Constants.
            | ConstantTrue
            | ConstantFalse
            | Constant
            | ConstantComposite
            | ConstantSampler
            | ConstantNull
            // Access chains and composites.
            | AccessChain
            | InBoundsAccessChain
            | PtrAccessChain
            | ArrayLength
            | InBoundsPtrAccessChain
            | CompositeConstruct
            | CompositeExtract
            | CopyObject
            | Transpose
            // Conversions and casts.
            | ConvertFToU
            | ConvertFToS
            | ConvertSToF
            | ConvertUToF
            | UConvert
            | SConvert
            | FConvert
            | QuantizeToF16
            | ConvertPtrToU
            | SatConvertSToU
            | SatConvertUToS
            | ConvertUToPtr
            | PtrCastToGeneric
            | GenericCastToPtr
            | GenericCastToPtrExplicit
            | Bitcast
            // Arithmetic.
            | SNegate
            | FNegate
            | IAdd
            | FAdd
            | ISub
            | FSub
            | IMul
            | FMul
            | UDiv
            | SDiv
            | FDiv
            | UMod
            | SRem
            | SMod
            | FRem
            | FMod
            | VectorTimesScalar
            | MatrixTimesScalar
            | VectorTimesMatrix
            | MatrixTimesVector
            | MatrixTimesMatrix
            | OuterProduct
            | Dot
            | IAddCarry
            | ISubBorrow
            | UMulExtended
            | SMulExtended
            // Relational and logical.
            | Any
            | All
            | IsNan
            | IsInf
            | IsFinite
            | IsNormal
            | SignBitSet
            | LessOrGreater
            | Ordered
            | Unordered
            | LogicalEqual
            | LogicalNotEqual
            | LogicalOr
            | LogicalAnd
            | LogicalNot
            | Select
            | IEqual
            | INotEqual
            | UGreaterThan
            | SGreaterThan
            | UGreaterThanEqual
            | SGreaterThanEqual
            | ULessThan
            | SLessThan
            | ULessThanEqual
            | SLessThanEqual
            | FOrdEqual
            | FUnordEqual
            | FOrdNotEqual
            | FUnordNotEqual
            | FOrdLessThan
            | FUnordLessThan
            | FOrdGreaterThan
            | FUnordGreaterThan
            | FOrdLessThanEqual
            | FUnordLessThanEqual
            | FOrdGreaterThanEqual
            | FUnordGreaterThanEqual
            // Shifts and bit manipulation.
            | ShiftRightLogical
            | ShiftRightArithmetic
            | ShiftLeftLogical
            | BitwiseOr
            | BitwiseXor
            | BitwiseAnd
            | Not
            | BitFieldInsert
            | BitFieldSExtract
            | BitFieldUExtract
            | BitReverse
            | BitCount
            // Remaining entries.
            | Phi
            | SizeOf
            | CopyLogical
            | PtrEqual
            | PtrNotEqual
            | PtrDiff
    )
}

View File

@ -5,33 +5,32 @@ OpLine %5 7 12
%10 = OpArrayLength %11 %8 0
OpLine %5 7 0
%12 = OpCompositeInsert %13 %6 %14 0
%15 = OpCompositeConstruct %13 %6 %10
OpLine %5 8 21
%16 = OpULessThan %17 %9 %10
%15 = OpULessThan %16 %9 %10
OpLine %5 8 21
OpSelectionMerge %18 None
OpBranchConditional %16 %19 %20
%19 = OpLabel
OpSelectionMerge %17 None
OpBranchConditional %15 %18 %19
%18 = OpLabel
OpLine %5 8 21
%21 = OpInBoundsAccessChain %22 %6 %9
%23 = OpLoad %24 %21
%20 = OpInBoundsAccessChain %21 %6 %9
%22 = OpLoad %23 %20
OpLine %5 10 1
OpReturn
%20 = OpLabel
%19 = OpLabel
OpLine %5 8 21
OpBranch %24
%24 = OpLabel
OpBranch %25
%25 = OpLabel
OpBranch %26
%26 = OpLabel
%27 = OpPhi %17 %28 %25 %28 %29
OpLoopMerge %30 %29 None
OpBranchConditional %27 %31 %30
%31 = OpLabel
OpBranch %29
%29 = OpLabel
OpBranch %26
%26 = OpPhi %16 %27 %24 %27 %28
OpLoopMerge %29 %28 None
OpBranchConditional %26 %30 %29
%30 = OpLabel
OpBranch %28
%28 = OpLabel
OpBranch %25
%29 = OpLabel
OpUnreachable
%18 = OpLabel
%17 = OpLabel
OpUnreachable
OpFunctionEnd