Review notes (free-form preamble; patch tools skip everything before the first file header):
  * Line structure was restored from a collapsed copy of this patch. Indentation is
    reconstructed, so verify it against the target file (or apply ignoring whitespace).
  * Hunk 4 needs one blank context line that the collapsed copy could not show; it is
    placed before the removed "Lastly, ensure ..." comment - TODO confirm its position.
  * The two `collect::<Option<SmallVec<[_; 4]>>>()` annotations were restored from a
    mangled `collect::>>()` - TODO confirm against the original commit.
  * `require_local_var(rt_args_array_ptr_id)` and `require_local_var(a_base_ptr)` now use
    `?`, matching `require_local_var(st_dst_id)?`; previously their `Option` result was
    dropped, so the "is a local `OpVariable`" check had no effect at those two sites.

diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
index 701766c811..c6355be625 100644
--- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
+++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
@@ -16,11 +16,13 @@ use rustc_codegen_ssa::traits::{
     BackendTypes, BuilderMethods, ConstMethods, IntrinsicCallMethods, LayoutTypeMethods, OverflowOp,
 };
 use rustc_codegen_ssa::MemFlags;
+use rustc_data_structures::fx::FxHashSet;
 use rustc_middle::bug;
 use rustc_middle::ty::Ty;
 use rustc_span::Span;
 use rustc_target::abi::call::FnAbi;
 use rustc_target::abi::{Abi, Align, Scalar, Size, WrappingRange};
+use smallvec::SmallVec;
 use std::convert::TryInto;
 use std::iter::{self, empty};
 
@@ -2393,7 +2395,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
             // nor simplified in MIR (e.g. promoted to a constant) in any way,
             // so we have to try and remove the `fmt::Arguments::new` call here.
             // HACK(eddyb) this is basically a `try` block.
-            let remove_simple_format_args_if_possible = || -> Option<()> {
+            let remove_format_args_if_possible = || -> Option<()> {
                 let format_args_id = match args {
                     &[SpirvValue {
                         kind: SpirvValueKind::Def(format_args_id),
@@ -2414,6 +2416,19 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
                 let func_idx = builder.selected_function().unwrap();
                 let block_idx = builder.selected_block().unwrap();
                 let func = &mut builder.module_mut().functions[func_idx];
+
+                // HACK(eddyb) this is used to check that all `Op{Store,Load}`s
+                // that may get removed, operate on local `OpVariable`s,
+                // i.e. are not externally observable.
+                let local_var_ids: FxHashSet<_> = func.blocks[0]
+                    .instructions
+                    .iter()
+                    .take_while(|inst| inst.class.opcode == Op::Variable)
+                    .map(|inst| inst.result_id.unwrap())
+                    .collect();
+                let require_local_var =
+                    |ptr_id| Some(()).filter(|()| local_var_ids.contains(&ptr_id));
+
                 let mut non_debug_insts = func.blocks[block_idx]
                     .instructions
                     .iter()
@@ -2426,68 +2441,193 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
                                 .contains(&CustomOp::decode_from_ext_inst(inst));
                         !(is_standard_debug || is_custom_debug)
                     });
-                let mut relevant_insts_next_back = |expected_op| {
-                    non_debug_insts
-                        .next_back()
-                        .filter(|(_, inst)| inst.class.opcode == expected_op)
-                        .map(|(i, inst)| {
-                            (
-                                i,
-                                inst.result_id,
-                                inst.operands.iter().map(|operand| operand.unwrap_id_ref()),
-                            )
-                        })
-                };
-                let (_, load_src_id) = relevant_insts_next_back(Op::Load)
-                    .map(|(_, result_id, mut operands)| {
-                        (result_id.unwrap(), operands.next().unwrap())
-                    })
-                    .filter(|&(result_id, _)| result_id == format_args_id)?;
-                let (_, store_val_id) = relevant_insts_next_back(Op::Store)
-                    .map(|(_, _, mut operands)| {
-                        (operands.next().unwrap(), operands.next().unwrap())
-                    })
-                    .filter(|&(store_dst_id, _)| store_dst_id == load_src_id)?;
-                let call_fmt_args_new_idx = relevant_insts_next_back(Op::FunctionCall)
-                    .filter(|&(_, result_id, _)| result_id == Some(store_val_id))
-                    .map(|(i, _, mut operands)| (i, operands.next().unwrap(), operands))
-                    .filter(|&(_, callee, _)| self.fmt_args_new_fn_ids.borrow().contains(&callee))
-                    .and_then(|(i, _, mut call_args)| {
-                        if call_args.len() == 4 {
-                            // `::new_v1`
-                            let mut arg = || call_args.next().unwrap();
-                            let [_, _, _, fmt_args_len_id] = [arg(), arg(), arg(), arg()];
-                            // Only ever remove `fmt::Arguments` with no runtime values.
-                            Some(i).filter(|_| {
-                                matches!(
-                                    self.builder.lookup_const_by_id(fmt_args_len_id),
-                                    Some(SpirvConst::U32(0))
-                                )
-                            })
-                        } else {
-                            // `::new_const`
-                            assert_eq!(call_args.len(), 2);
-                            Some(i)
-                        }
-                    })?;

-                // Lastly, ensure that the `Op{Store,Load}` pair operates on
-                // a local `OpVariable`, i.e. is not externally observable.
-                let store_load_local_var = func.blocks[0]
-                    .instructions
-                    .iter()
-                    .take_while(|inst| inst.class.opcode == Op::Variable)
-                    .find(|inst| inst.result_id == Some(load_src_id));
-                if store_load_local_var.is_some() {
-                    // Keep all instructions up to (but not including) the call.
-                    func.blocks[block_idx]
-                        .instructions
-                        .truncate(call_fmt_args_new_idx);
+                // HACK(eddyb) to aid in pattern-matching, relevant instructions
+                // are decoded to values of this `enum`. For instructions that
+                // produce results, the result ID is the first `ID` value.
+                #[derive(Debug)]
+                enum Inst<'tcx, ID> {
+                    Bitcast(ID, ID),
+                    AccessChain(ID, ID, SpirvConst<'tcx>),
+                    InBoundsAccessChain(ID, ID, SpirvConst<'tcx>),
+                    Store(ID, ID),
+                    Load(ID, ID),
+                    Call(ID, ID, SmallVec<[ID; 4]>),
                 }

+                let mut taken_inst_idx_range = func.blocks[block_idx].instructions.len()..;
+
+                // Take `count` instructions, advancing backwards, but returning
+                // instructions in their original order (and decoded to `Inst`s).
+                let mut try_rev_take = |count| {
+                    let maybe_rev_insts = (0..count).map(|_| {
+                        let (i, inst) = non_debug_insts.next_back()?;
+                        taken_inst_idx_range = i..;
+
+                        // HACK(eddyb) all instructions accepted below
+                        // are expected to take no more than 4 operands,
+                        // and this is easier to use than an iterator.
+                        let id_operands = inst
+                            .operands
+                            .iter()
+                            .map(|operand| operand.id_ref_any())
+                            .collect::<Option<SmallVec<[_; 4]>>>()?;
+
+                        // Decode the instruction into one of our `Inst`s.
+                        Some(
+                            match (inst.class.opcode, inst.result_id, &id_operands[..]) {
+                                (Op::Bitcast, Some(r), &[x]) => Inst::Bitcast(r, x),
+                                (Op::AccessChain, Some(r), &[p, i]) => {
+                                    Inst::AccessChain(r, p, self.builder.lookup_const_by_id(i)?)
+                                }
+                                (Op::InBoundsAccessChain, Some(r), &[p, i]) => {
+                                    Inst::InBoundsAccessChain(
+                                        r,
+                                        p,
+                                        self.builder.lookup_const_by_id(i)?,
+                                    )
+                                }
+                                (Op::Store, None, &[p, v]) => Inst::Store(p, v),
+                                (Op::Load, Some(r), &[p]) => Inst::Load(r, p),
+                                (Op::FunctionCall, Some(r), [f, args @ ..]) => {
+                                    Inst::Call(r, *f, args.iter().copied().collect())
+                                }
+                                _ => return None,
+                            },
+                        )
+                    });
+                    let mut insts = maybe_rev_insts.collect::<Option<SmallVec<[_; 4]>>>()?;
+                    insts.reverse();
+                    Some(insts)
+                };
+
+                let (rt_args_slice_ptr_id, rt_args_count) = match try_rev_take(3)?[..] {
+                    [
+                        // HACK(eddyb) comment works around `rustfmt` array bugs.
+                        Inst::Call(call_ret_id, callee_id, ref call_args),
+                        Inst::Store(st_dst_id, st_val_id),
+                        Inst::Load(ld_val_id, ld_src_id),
+                    ]
+                        if self.fmt_args_new_fn_ids.borrow().contains(&callee_id)
+                            && call_ret_id == st_val_id
+                            && st_dst_id == ld_src_id
+                            && ld_val_id == format_args_id =>
+                    {
+                        require_local_var(st_dst_id)?;
+                        match call_args[..] {
+                            // `::new_v1`
+                            [_, _, rt_args_slice_ptr_id, rt_args_len_id] => (
+                                Some(rt_args_slice_ptr_id),
+                                self.builder
+                                    .lookup_const_by_id(rt_args_len_id)
+                                    .and_then(|ct| match ct {
+                                        SpirvConst::U32(x) => Some(x as usize),
+                                        _ => None,
+                                    })?,
+                            ),
+
+                            // `::new_const`
+                            [_, _] => (None, 0),
+
+                            _ => return None,
+                        }
+                    }
+                    _ => return None,
+                };
+
+                // HACK(eddyb) this is the worst part: if we do have runtime
+                // arguments (from e.g. new `assert!`s being added to `core`),
+                // we have to confirm their many instructions for removal.
+                if rt_args_count > 0 {
+                    let rt_args_slice_ptr_id = rt_args_slice_ptr_id.unwrap();
+                    let rt_args_array_ptr_id = match try_rev_take(1)?[..] {
+                        [Inst::Bitcast(out_id, in_id)] if out_id == rt_args_slice_ptr_id => in_id,
+                        _ => return None,
+                    };
+                    require_local_var(rt_args_array_ptr_id)?;
+
+                    // Each runtime argument has its own variable, 6 instructions
+                    // to initialize it, and 9 instructions to copy it to the
+                    // appropriate slot in the array. The groups of 6 and 9
+                    // instructions, for all runtime args, are each separate.
+                    let copies_from_rt_arg_vars_to_rt_args_array = try_rev_take(rt_args_count * 9)?;
+                    let copies_from_rt_arg_vars_to_rt_args_array =
+                        copies_from_rt_arg_vars_to_rt_args_array.chunks(9);
+                    let inits_of_rt_arg_vars = try_rev_take(rt_args_count * 6)?;
+                    let inits_of_rt_arg_vars = inits_of_rt_arg_vars.chunks(6);
+
+                    for (
+                        rt_arg_idx,
+                        (init_of_rt_arg_var_insts, copy_from_rt_arg_var_to_rt_args_array_insts),
+                    ) in inits_of_rt_arg_vars
+                        .zip(copies_from_rt_arg_vars_to_rt_args_array)
+                        .enumerate()
+                    {
+                        let rt_arg_var_id = match init_of_rt_arg_var_insts[..] {
+                            [
+                                // HACK(eddyb) comment works around `rustfmt` array bugs.
+                                Inst::Bitcast(b, _),
+                                Inst::Bitcast(a, _),
+                                Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
+                                Inst::Store(a_st_dst, a_st_val),
+                                Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
+                                Inst::Store(b_st_dst, b_st_val),
+                            ] if a_base_ptr == b_base_ptr
+                                && (a, b) == (a_st_val, b_st_val)
+                                && (a_ptr, b_ptr) == (a_st_dst, b_st_dst) =>
+                            {
+                                require_local_var(a_base_ptr)?;
+                                a_base_ptr
+                            }
+                            _ => return None,
+                        };
+
+                        // HACK(eddyb) this is only split to allow variable name reuse.
+                        let (copy_loads, copy_stores) =
+                            copy_from_rt_arg_var_to_rt_args_array_insts.split_at(4);
+                        let (a, b) = match copy_loads[..] {
+                            [
+                                // HACK(eddyb) comment works around `rustfmt` array bugs.
+                                Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
+                                Inst::Load(a_ld_val, a_ld_src),
+                                Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
+                                Inst::Load(b_ld_val, b_ld_src),
+                            ] if [a_base_ptr, b_base_ptr] == [rt_arg_var_id; 2]
+                                && (a_ptr, b_ptr) == (a_ld_src, b_ld_src) =>
+                            {
+                                (a_ld_val, b_ld_val)
+                            }
+                            _ => return None,
+                        };
+                        match copy_stores[..] {
+                            [
+                                // HACK(eddyb) comment works around `rustfmt` array bugs.
+                                Inst::InBoundsAccessChain(array_slot_ptr, array_base_ptr, SpirvConst::U32(array_idx)),
+                                Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
+                                Inst::Store(a_st_dst, a_st_val),
+                                Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
+                                Inst::Store(b_st_dst, b_st_val),
+                            ] if array_base_ptr == rt_args_array_ptr_id
+                                && array_idx as usize == rt_arg_idx
+                                && [a_base_ptr, b_base_ptr] == [array_slot_ptr; 2]
+                                && (a, b) == (a_st_val, b_st_val)
+                                && (a_ptr, b_ptr) == (a_st_dst, b_st_dst) =>
+                            {
+                            }
+                            _ => return None,
+                        }
+                    }
+                }
+
+                // Keep all instructions up to (but not including) the last one
+                // confirmed above to be the first instruction of `format_args!`.
+                func.blocks[block_idx]
+                    .instructions
+                    .truncate(taken_inst_idx_range.start);
+
                 None
             };
-            remove_simple_format_args_if_possible();
+            remove_format_args_if_possible();

             // HACK(eddyb) redirect any possible panic call to an abort, to avoid
             // needing to materialize `&core::panic::Location` or `format_args!`.