builder: generalize the panic format_args! remover to handle runtime args.

This commit is contained in:
Eduard-Mihai Burtescu 2023-06-02 09:34:41 +03:00 committed by Eduard-Mihai Burtescu
parent e9cdb9666b
commit 54d98c882f

View File

@ -16,11 +16,13 @@ use rustc_codegen_ssa::traits::{
BackendTypes, BuilderMethods, ConstMethods, IntrinsicCallMethods, LayoutTypeMethods, OverflowOp,
};
use rustc_codegen_ssa::MemFlags;
use rustc_data_structures::fx::FxHashSet;
use rustc_middle::bug;
use rustc_middle::ty::Ty;
use rustc_span::Span;
use rustc_target::abi::call::FnAbi;
use rustc_target::abi::{Abi, Align, Scalar, Size, WrappingRange};
use smallvec::SmallVec;
use std::convert::TryInto;
use std::iter::{self, empty};
@ -2393,7 +2395,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// nor simplified in MIR (e.g. promoted to a constant) in any way,
// so we have to try and remove the `fmt::Arguments::new` call here.
// HACK(eddyb) this is basically a `try` block.
let remove_simple_format_args_if_possible = || -> Option<()> {
let remove_format_args_if_possible = || -> Option<()> {
let format_args_id = match args {
&[SpirvValue {
kind: SpirvValueKind::Def(format_args_id),
@ -2414,6 +2416,19 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
let func_idx = builder.selected_function().unwrap();
let block_idx = builder.selected_block().unwrap();
let func = &mut builder.module_mut().functions[func_idx];
// HACK(eddyb) this is used to check that all `Op{Store,Load}`s
// that may get removed, operate on local `OpVariable`s,
// i.e. are not externally observable.
let local_var_ids: FxHashSet<_> = func.blocks[0]
.instructions
.iter()
.take_while(|inst| inst.class.opcode == Op::Variable)
.map(|inst| inst.result_id.unwrap())
.collect();
let require_local_var =
|ptr_id| Some(()).filter(|()| local_var_ids.contains(&ptr_id));
let mut non_debug_insts = func.blocks[block_idx]
.instructions
.iter()
@ -2426,68 +2441,193 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.contains(&CustomOp::decode_from_ext_inst(inst));
!(is_standard_debug || is_custom_debug)
});
let mut relevant_insts_next_back = |expected_op| {
non_debug_insts
.next_back()
.filter(|(_, inst)| inst.class.opcode == expected_op)
.map(|(i, inst)| {
(
i,
inst.result_id,
inst.operands.iter().map(|operand| operand.unwrap_id_ref()),
)
})
};
let (_, load_src_id) = relevant_insts_next_back(Op::Load)
.map(|(_, result_id, mut operands)| {
(result_id.unwrap(), operands.next().unwrap())
})
.filter(|&(result_id, _)| result_id == format_args_id)?;
let (_, store_val_id) = relevant_insts_next_back(Op::Store)
.map(|(_, _, mut operands)| {
(operands.next().unwrap(), operands.next().unwrap())
})
.filter(|&(store_dst_id, _)| store_dst_id == load_src_id)?;
let call_fmt_args_new_idx = relevant_insts_next_back(Op::FunctionCall)
.filter(|&(_, result_id, _)| result_id == Some(store_val_id))
.map(|(i, _, mut operands)| (i, operands.next().unwrap(), operands))
.filter(|&(_, callee, _)| self.fmt_args_new_fn_ids.borrow().contains(&callee))
.and_then(|(i, _, mut call_args)| {
if call_args.len() == 4 {
// `<core::fmt::Arguments>::new_v1`
let mut arg = || call_args.next().unwrap();
let [_, _, _, fmt_args_len_id] = [arg(), arg(), arg(), arg()];
// Only ever remove `fmt::Arguments` with no runtime values.
Some(i).filter(|_| {
matches!(
self.builder.lookup_const_by_id(fmt_args_len_id),
Some(SpirvConst::U32(0))
)
})
} else {
// `<core::fmt::Arguments>::new_const`
assert_eq!(call_args.len(), 2);
Some(i)
}
})?;
// Lastly, ensure that the `Op{Store,Load}` pair operates on
// a local `OpVariable`, i.e. is not externally observable.
let store_load_local_var = func.blocks[0]
.instructions
// HACK(eddyb) to aid in pattern-matching, relevant instructions
// are decoded to values of this `enum`. For instructions that
// produce results, the result ID is the first `ID` value.
#[derive(Debug)]
enum Inst<'tcx, ID> {
Bitcast(ID, ID),
AccessChain(ID, ID, SpirvConst<'tcx>),
InBoundsAccessChain(ID, ID, SpirvConst<'tcx>),
Store(ID, ID),
Load(ID, ID),
Call(ID, ID, SmallVec<[ID; 4]>),
}
let mut taken_inst_idx_range = func.blocks[block_idx].instructions.len()..;
// Take `count` instructions, advancing backwards, but returning
// instructions in their original order (and decoded to `Inst`s).
let mut try_rev_take = |count| {
let maybe_rev_insts = (0..count).map(|_| {
let (i, inst) = non_debug_insts.next_back()?;
taken_inst_idx_range = i..;
// HACK(eddyb) all instructions accepted below
// are expected to take no more than 4 operands,
// and this is easier to use than an iterator.
let id_operands = inst
.operands
.iter()
.take_while(|inst| inst.class.opcode == Op::Variable)
.find(|inst| inst.result_id == Some(load_src_id));
if store_load_local_var.is_some() {
// Keep all instructions up to (but not including) the call.
.map(|operand| operand.id_ref_any())
.collect::<Option<SmallVec<[_; 4]>>>()?;
// Decode the instruction into one of our `Inst`s.
Some(
match (inst.class.opcode, inst.result_id, &id_operands[..]) {
(Op::Bitcast, Some(r), &[x]) => Inst::Bitcast(r, x),
(Op::AccessChain, Some(r), &[p, i]) => {
Inst::AccessChain(r, p, self.builder.lookup_const_by_id(i)?)
}
(Op::InBoundsAccessChain, Some(r), &[p, i]) => {
Inst::InBoundsAccessChain(
r,
p,
self.builder.lookup_const_by_id(i)?,
)
}
(Op::Store, None, &[p, v]) => Inst::Store(p, v),
(Op::Load, Some(r), &[p]) => Inst::Load(r, p),
(Op::FunctionCall, Some(r), [f, args @ ..]) => {
Inst::Call(r, *f, args.iter().copied().collect())
}
_ => return None,
},
)
});
let mut insts = maybe_rev_insts.collect::<Option<SmallVec<[_; 4]>>>()?;
insts.reverse();
Some(insts)
};
let (rt_args_slice_ptr_id, rt_args_count) = match try_rev_take(3)?[..] {
[
// HACK(eddyb) comment works around `rustfmt` array bugs.
Inst::Call(call_ret_id, callee_id, ref call_args),
Inst::Store(st_dst_id, st_val_id),
Inst::Load(ld_val_id, ld_src_id),
]
if self.fmt_args_new_fn_ids.borrow().contains(&callee_id)
&& call_ret_id == st_val_id
&& st_dst_id == ld_src_id
&& ld_val_id == format_args_id =>
{
require_local_var(st_dst_id)?;
match call_args[..] {
// `<core::fmt::Arguments>::new_v1`
[_, _, rt_args_slice_ptr_id, rt_args_len_id] => (
Some(rt_args_slice_ptr_id),
self.builder
.lookup_const_by_id(rt_args_len_id)
.and_then(|ct| match ct {
SpirvConst::U32(x) => Some(x as usize),
_ => None,
})?,
),
// `<core::fmt::Arguments>::new_const`
[_, _] => (None, 0),
_ => return None,
}
}
_ => return None,
};
// HACK(eddyb) this is the worst part: if we do have runtime
// arguments (from e.g. new `assert!`s being added to `core`),
// we have to confirm their many instructions for removal.
if rt_args_count > 0 {
let rt_args_slice_ptr_id = rt_args_slice_ptr_id.unwrap();
let rt_args_array_ptr_id = match try_rev_take(1)?[..] {
[Inst::Bitcast(out_id, in_id)] if out_id == rt_args_slice_ptr_id => in_id,
_ => return None,
};
require_local_var(rt_args_array_ptr_id);
// Each runtime argument has its own variable, 6 instructions
// to initialize it, and 9 instructions to copy it to the
// appropriate slot in the array. The groups of 6 and 9
// instructions, for all runtime args, are each separate.
let copies_from_rt_arg_vars_to_rt_args_array = try_rev_take(rt_args_count * 9)?;
let copies_from_rt_arg_vars_to_rt_args_array =
copies_from_rt_arg_vars_to_rt_args_array.chunks(9);
let inits_of_rt_arg_vars = try_rev_take(rt_args_count * 6)?;
let inits_of_rt_arg_vars = inits_of_rt_arg_vars.chunks(6);
for (
rt_arg_idx,
(init_of_rt_arg_var_insts, copy_from_rt_arg_var_to_rt_args_array_insts),
) in inits_of_rt_arg_vars
.zip(copies_from_rt_arg_vars_to_rt_args_array)
.enumerate()
{
let rt_arg_var_id = match init_of_rt_arg_var_insts[..] {
[
// HACK(eddyb) comment works around `rustfmt` array bugs.
Inst::Bitcast(b, _),
Inst::Bitcast(a, _),
Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
Inst::Store(a_st_dst, a_st_val),
Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
Inst::Store(b_st_dst, b_st_val),
] if a_base_ptr == b_base_ptr
&& (a, b) == (a_st_val, b_st_val)
&& (a_ptr, b_ptr) == (a_st_dst, b_st_dst) =>
{
require_local_var(a_base_ptr);
a_base_ptr
}
_ => return None,
};
// HACK(eddyb) this is only split to allow variable name reuse.
let (copy_loads, copy_stores) =
copy_from_rt_arg_var_to_rt_args_array_insts.split_at(4);
let (a, b) = match copy_loads[..] {
[
// HACK(eddyb) comment works around `rustfmt` array bugs.
Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
Inst::Load(a_ld_val, a_ld_src),
Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
Inst::Load(b_ld_val, b_ld_src),
] if [a_base_ptr, b_base_ptr] == [rt_arg_var_id; 2]
&& (a_ptr, b_ptr) == (a_ld_src, b_ld_src) =>
{
(a_ld_val, b_ld_val)
}
_ => return None,
};
match copy_stores[..] {
[
// HACK(eddyb) comment works around `rustfmt` array bugs.
Inst::InBoundsAccessChain(array_slot_ptr, array_base_ptr, SpirvConst::U32(array_idx)),
Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
Inst::Store(a_st_dst, a_st_val),
Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
Inst::Store(b_st_dst, b_st_val),
] if array_base_ptr == rt_args_array_ptr_id
&& array_idx as usize == rt_arg_idx
&& [a_base_ptr, b_base_ptr] == [array_slot_ptr; 2]
&& (a, b) == (a_st_val, b_st_val)
&& (a_ptr, b_ptr) == (a_st_dst, b_st_dst) =>
{
}
_ => return None,
}
}
}
// Keep all instructions up to (but not including) the last one
// confirmed above to be the first instruction of `format_args!`.
func.blocks[block_idx]
.instructions
.truncate(call_fmt_args_new_idx);
}
.truncate(taken_inst_idx_range.start);
None
};
remove_simple_format_args_if_possible();
remove_format_args_if_possible();
// HACK(eddyb) redirect any possible panic call to an abort, to avoid
// needing to materialize `&core::panic::Location` or `format_args!`.