diff --git a/crates/rustc_codegen_spirv/src/linker/dce.rs b/crates/rustc_codegen_spirv/src/linker/dce.rs index 472a96b079..40eefa993e 100644 --- a/crates/rustc_codegen_spirv/src/linker/dce.rs +++ b/crates/rustc_codegen_spirv/src/linker/dce.rs @@ -7,8 +7,8 @@ //! *references* a rooted thing is also rooted, not the other way around - but that's the basic //! concept. -use rspirv::dr::{Function, Instruction, Module}; -use rspirv::spirv::{Op, Word}; +use rspirv::dr::{Function, Instruction, Module, Operand}; +use rspirv::spirv::{Op, StorageClass, Word}; use rustc_data_structures::fx::FxHashSet; pub fn dce(module: &mut Module) { @@ -36,8 +36,29 @@ fn spread_roots(module: &Module, rooted: &mut FxHashSet) -> bool { } for func in &module.functions { if rooted.contains(&func.def_id().unwrap()) { - for inst in func.all_inst_iter() { - any |= root(inst, rooted); + // NB (Mobius 2021) - since later insts are much more likely to reference + // earlier insts, by reversing the iteration order, we're more likely to root the + // entire relevant function at once. + // See https://github.com/EmbarkStudios/rust-gpu/pull/691#discussion_r681477091 + for inst in func + .end + .iter() + .chain( + func.blocks + .iter() + .rev() + .flat_map(|b| b.instructions.iter().rev().chain(b.label.iter())), + ) + .chain(func.parameters.iter().rev()) + .chain(func.def.iter()) + { + if !instruction_is_pure(inst) { + any |= root(inst, rooted); + } else if let Some(id) = inst.result_id { + if rooted.contains(&id) { + any |= root(inst, rooted); + } + } } } } @@ -90,6 +111,13 @@ fn kill_unrooted(module: &mut Module, rooted: &FxHashSet) { module .functions .retain(|f| is_rooted(f.def.as_ref().unwrap(), rooted)); + module.functions.iter_mut().for_each(|fun| { + fun.blocks.iter_mut().for_each(|block| { + block + .instructions + .retain(|inst| !instruction_is_pure(inst) || is_rooted(inst, rooted)); + }); + }); } pub fn dce_phi(func: &mut Function) { @@ -115,3 +143,127 @@ pub fn dce_phi(func: &mut Function) { .retain(|inst| inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap())); } } + +fn instruction_is_pure(inst: &Instruction) -> bool { + use Op::*; + match inst.class.opcode { + Nop + | Undef + | ConstantTrue + | ConstantFalse + | Constant + | ConstantComposite + | ConstantSampler + | ConstantNull + | AccessChain + | InBoundsAccessChain + | PtrAccessChain + | ArrayLength + | InBoundsPtrAccessChain + | CompositeConstruct + | CompositeExtract + | CopyObject + | Transpose + | ConvertFToU + | ConvertFToS + | ConvertSToF + | ConvertUToF + | UConvert + | SConvert + | FConvert + | QuantizeToF16 + | ConvertPtrToU + | SatConvertSToU + | SatConvertUToS + | ConvertUToPtr + | PtrCastToGeneric + | GenericCastToPtr + | GenericCastToPtrExplicit + | Bitcast + | SNegate + | FNegate + | IAdd + | FAdd + | ISub + | FSub + | IMul + | FMul + | UDiv + | SDiv + | FDiv + | UMod + | SRem + | SMod + | FRem + | FMod + | VectorTimesScalar + | MatrixTimesScalar + | VectorTimesMatrix + | MatrixTimesVector + | MatrixTimesMatrix + | OuterProduct + | Dot + | IAddCarry + | ISubBorrow + | UMulExtended + | SMulExtended + | Any + | All + | IsNan + | IsInf + | IsFinite + | IsNormal + | SignBitSet + | LessOrGreater + | Ordered + | Unordered + | LogicalEqual + | LogicalNotEqual + | LogicalOr + | LogicalAnd + | LogicalNot + | Select + | IEqual + | INotEqual + | UGreaterThan + | SGreaterThan + | UGreaterThanEqual + | SGreaterThanEqual + | ULessThan + | SLessThan + | ULessThanEqual + | SLessThanEqual + | FOrdEqual + | FUnordEqual + | FOrdNotEqual + | FUnordNotEqual + | FOrdLessThan + | FUnordLessThan + | FOrdGreaterThan + | FUnordGreaterThan + | FOrdLessThanEqual + | FUnordLessThanEqual + | FOrdGreaterThanEqual + | FUnordGreaterThanEqual + | ShiftRightLogical + | ShiftRightArithmetic + | ShiftLeftLogical + | BitwiseOr + | BitwiseXor + | BitwiseAnd + | Not + | BitFieldInsert + | BitFieldSExtract + | BitFieldUExtract + | BitReverse + | BitCount + | Phi + | SizeOf + | CopyLogical + | PtrEqual + | PtrNotEqual + | PtrDiff => true, + Variable => inst.operands.get(0) == Some(&Operand::StorageClass(StorageClass::Function)), + _ => false, + } +} diff --git a/tests/ui/dis/index_user_dst.stderr b/tests/ui/dis/index_user_dst.stderr index 2cc20726c0..2335846a19 100644 --- a/tests/ui/dis/index_user_dst.stderr +++ b/tests/ui/dis/index_user_dst.stderr @@ -5,33 +5,32 @@ OpLine %5 7 12 %10 = OpArrayLength %11 %8 0 OpLine %5 7 0 %12 = OpCompositeInsert %13 %6 %14 0 -%15 = OpCompositeConstruct %13 %6 %10 OpLine %5 8 21 -%16 = OpULessThan %17 %9 %10 +%15 = OpULessThan %16 %9 %10 OpLine %5 8 21 -OpSelectionMerge %18 None -OpBranchConditional %16 %19 %20 -%19 = OpLabel +OpSelectionMerge %17 None +OpBranchConditional %15 %18 %19 +%18 = OpLabel OpLine %5 8 21 -%21 = OpInBoundsAccessChain %22 %6 %9 -%23 = OpLoad %24 %21 +%20 = OpInBoundsAccessChain %21 %6 %9 +%22 = OpLoad %23 %20 OpLine %5 10 1 OpReturn -%20 = OpLabel +%19 = OpLabel OpLine %5 8 21 +OpBranch %24 +%24 = OpLabel OpBranch %25 %25 = OpLabel -OpBranch %26 -%26 = OpLabel -%27 = OpPhi %17 %28 %25 %28 %29 -OpLoopMerge %30 %29 None -OpBranchConditional %27 %31 %30 -%31 = OpLabel -OpBranch %29 -%29 = OpLabel -OpBranch %26 +%26 = OpPhi %16 %27 %24 %27 %28 +OpLoopMerge %29 %28 None +OpBranchConditional %26 %30 %29 %30 = OpLabel +OpBranch %28 +%28 = OpLabel +OpBranch %25 +%29 = OpLabel OpUnreachable -%18 = OpLabel +%17 = OpLabel OpUnreachable OpFunctionEnd