mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-21 22:34:34 +00:00
builder: generalize the panic format_args!
remover to handle runtime args.
This commit is contained in:
parent
e9cdb9666b
commit
54d98c882f
@ -16,11 +16,13 @@ use rustc_codegen_ssa::traits::{
|
||||
BackendTypes, BuilderMethods, ConstMethods, IntrinsicCallMethods, LayoutTypeMethods, OverflowOp,
|
||||
};
|
||||
use rustc_codegen_ssa::MemFlags;
|
||||
use rustc_data_structures::fx::FxHashSet;
|
||||
use rustc_middle::bug;
|
||||
use rustc_middle::ty::Ty;
|
||||
use rustc_span::Span;
|
||||
use rustc_target::abi::call::FnAbi;
|
||||
use rustc_target::abi::{Abi, Align, Scalar, Size, WrappingRange};
|
||||
use smallvec::SmallVec;
|
||||
use std::convert::TryInto;
|
||||
use std::iter::{self, empty};
|
||||
|
||||
@ -2393,7 +2395,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
|
||||
// nor simplified in MIR (e.g. promoted to a constant) in any way,
|
||||
// so we have to try and remove the `fmt::Arguments::new` call here.
|
||||
// HACK(eddyb) this is basically a `try` block.
|
||||
let remove_simple_format_args_if_possible = || -> Option<()> {
|
||||
let remove_format_args_if_possible = || -> Option<()> {
|
||||
let format_args_id = match args {
|
||||
&[SpirvValue {
|
||||
kind: SpirvValueKind::Def(format_args_id),
|
||||
@ -2414,6 +2416,19 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
|
||||
let func_idx = builder.selected_function().unwrap();
|
||||
let block_idx = builder.selected_block().unwrap();
|
||||
let func = &mut builder.module_mut().functions[func_idx];
|
||||
|
||||
// HACK(eddyb) this is used to check that all `Op{Store,Load}`s
|
||||
// that may get removed, operate on local `OpVariable`s,
|
||||
// i.e. are not externally observable.
|
||||
let local_var_ids: FxHashSet<_> = func.blocks[0]
|
||||
.instructions
|
||||
.iter()
|
||||
.take_while(|inst| inst.class.opcode == Op::Variable)
|
||||
.map(|inst| inst.result_id.unwrap())
|
||||
.collect();
|
||||
let require_local_var =
|
||||
|ptr_id| Some(()).filter(|()| local_var_ids.contains(&ptr_id));
|
||||
|
||||
let mut non_debug_insts = func.blocks[block_idx]
|
||||
.instructions
|
||||
.iter()
|
||||
@ -2426,68 +2441,193 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
|
||||
.contains(&CustomOp::decode_from_ext_inst(inst));
|
||||
!(is_standard_debug || is_custom_debug)
|
||||
});
|
||||
let mut relevant_insts_next_back = |expected_op| {
|
||||
non_debug_insts
|
||||
.next_back()
|
||||
.filter(|(_, inst)| inst.class.opcode == expected_op)
|
||||
.map(|(i, inst)| {
|
||||
(
|
||||
i,
|
||||
inst.result_id,
|
||||
inst.operands.iter().map(|operand| operand.unwrap_id_ref()),
|
||||
)
|
||||
})
|
||||
};
|
||||
let (_, load_src_id) = relevant_insts_next_back(Op::Load)
|
||||
.map(|(_, result_id, mut operands)| {
|
||||
(result_id.unwrap(), operands.next().unwrap())
|
||||
})
|
||||
.filter(|&(result_id, _)| result_id == format_args_id)?;
|
||||
let (_, store_val_id) = relevant_insts_next_back(Op::Store)
|
||||
.map(|(_, _, mut operands)| {
|
||||
(operands.next().unwrap(), operands.next().unwrap())
|
||||
})
|
||||
.filter(|&(store_dst_id, _)| store_dst_id == load_src_id)?;
|
||||
let call_fmt_args_new_idx = relevant_insts_next_back(Op::FunctionCall)
|
||||
.filter(|&(_, result_id, _)| result_id == Some(store_val_id))
|
||||
.map(|(i, _, mut operands)| (i, operands.next().unwrap(), operands))
|
||||
.filter(|&(_, callee, _)| self.fmt_args_new_fn_ids.borrow().contains(&callee))
|
||||
.and_then(|(i, _, mut call_args)| {
|
||||
if call_args.len() == 4 {
|
||||
// `<core::fmt::Arguments>::new_v1`
|
||||
let mut arg = || call_args.next().unwrap();
|
||||
let [_, _, _, fmt_args_len_id] = [arg(), arg(), arg(), arg()];
|
||||
// Only ever remove `fmt::Arguments` with no runtime values.
|
||||
Some(i).filter(|_| {
|
||||
matches!(
|
||||
self.builder.lookup_const_by_id(fmt_args_len_id),
|
||||
Some(SpirvConst::U32(0))
|
||||
)
|
||||
})
|
||||
} else {
|
||||
// `<core::fmt::Arguments>::new_const`
|
||||
assert_eq!(call_args.len(), 2);
|
||||
Some(i)
|
||||
}
|
||||
})?;
|
||||
|
||||
// Lastly, ensure that the `Op{Store,Load}` pair operates on
|
||||
// a local `OpVariable`, i.e. is not externally observable.
|
||||
let store_load_local_var = func.blocks[0]
|
||||
.instructions
|
||||
// HACK(eddyb) to aid in pattern-matching, relevant instructions
|
||||
// are decoded to values of this `enum`. For instructions that
|
||||
// produce results, the result ID is the first `ID` value.
|
||||
#[derive(Debug)]
|
||||
enum Inst<'tcx, ID> {
|
||||
Bitcast(ID, ID),
|
||||
AccessChain(ID, ID, SpirvConst<'tcx>),
|
||||
InBoundsAccessChain(ID, ID, SpirvConst<'tcx>),
|
||||
Store(ID, ID),
|
||||
Load(ID, ID),
|
||||
Call(ID, ID, SmallVec<[ID; 4]>),
|
||||
}
|
||||
|
||||
let mut taken_inst_idx_range = func.blocks[block_idx].instructions.len()..;
|
||||
|
||||
// Take `count` instructions, advancing backwards, but returning
|
||||
// instructions in their original order (and decoded to `Inst`s).
|
||||
let mut try_rev_take = |count| {
|
||||
let maybe_rev_insts = (0..count).map(|_| {
|
||||
let (i, inst) = non_debug_insts.next_back()?;
|
||||
taken_inst_idx_range = i..;
|
||||
|
||||
// HACK(eddyb) all instructions accepted below
|
||||
// are expected to take no more than 4 operands,
|
||||
// and this is easier to use than an iterator.
|
||||
let id_operands = inst
|
||||
.operands
|
||||
.iter()
|
||||
.take_while(|inst| inst.class.opcode == Op::Variable)
|
||||
.find(|inst| inst.result_id == Some(load_src_id));
|
||||
if store_load_local_var.is_some() {
|
||||
// Keep all instructions up to (but not including) the call.
|
||||
.map(|operand| operand.id_ref_any())
|
||||
.collect::<Option<SmallVec<[_; 4]>>>()?;
|
||||
|
||||
// Decode the instruction into one of our `Inst`s.
|
||||
Some(
|
||||
match (inst.class.opcode, inst.result_id, &id_operands[..]) {
|
||||
(Op::Bitcast, Some(r), &[x]) => Inst::Bitcast(r, x),
|
||||
(Op::AccessChain, Some(r), &[p, i]) => {
|
||||
Inst::AccessChain(r, p, self.builder.lookup_const_by_id(i)?)
|
||||
}
|
||||
(Op::InBoundsAccessChain, Some(r), &[p, i]) => {
|
||||
Inst::InBoundsAccessChain(
|
||||
r,
|
||||
p,
|
||||
self.builder.lookup_const_by_id(i)?,
|
||||
)
|
||||
}
|
||||
(Op::Store, None, &[p, v]) => Inst::Store(p, v),
|
||||
(Op::Load, Some(r), &[p]) => Inst::Load(r, p),
|
||||
(Op::FunctionCall, Some(r), [f, args @ ..]) => {
|
||||
Inst::Call(r, *f, args.iter().copied().collect())
|
||||
}
|
||||
_ => return None,
|
||||
},
|
||||
)
|
||||
});
|
||||
let mut insts = maybe_rev_insts.collect::<Option<SmallVec<[_; 4]>>>()?;
|
||||
insts.reverse();
|
||||
Some(insts)
|
||||
};
|
||||
|
||||
let (rt_args_slice_ptr_id, rt_args_count) = match try_rev_take(3)?[..] {
|
||||
[
|
||||
// HACK(eddyb) comment works around `rustfmt` array bugs.
|
||||
Inst::Call(call_ret_id, callee_id, ref call_args),
|
||||
Inst::Store(st_dst_id, st_val_id),
|
||||
Inst::Load(ld_val_id, ld_src_id),
|
||||
]
|
||||
if self.fmt_args_new_fn_ids.borrow().contains(&callee_id)
|
||||
&& call_ret_id == st_val_id
|
||||
&& st_dst_id == ld_src_id
|
||||
&& ld_val_id == format_args_id =>
|
||||
{
|
||||
require_local_var(st_dst_id)?;
|
||||
match call_args[..] {
|
||||
// `<core::fmt::Arguments>::new_v1`
|
||||
[_, _, rt_args_slice_ptr_id, rt_args_len_id] => (
|
||||
Some(rt_args_slice_ptr_id),
|
||||
self.builder
|
||||
.lookup_const_by_id(rt_args_len_id)
|
||||
.and_then(|ct| match ct {
|
||||
SpirvConst::U32(x) => Some(x as usize),
|
||||
_ => None,
|
||||
})?,
|
||||
),
|
||||
|
||||
// `<core::fmt::Arguments>::new_const`
|
||||
[_, _] => (None, 0),
|
||||
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
// HACK(eddyb) this is the worst part: if we do have runtime
|
||||
// arguments (from e.g. new `assert!`s being added to `core`),
|
||||
// we have to confirm their many instructions for removal.
|
||||
if rt_args_count > 0 {
|
||||
let rt_args_slice_ptr_id = rt_args_slice_ptr_id.unwrap();
|
||||
let rt_args_array_ptr_id = match try_rev_take(1)?[..] {
|
||||
[Inst::Bitcast(out_id, in_id)] if out_id == rt_args_slice_ptr_id => in_id,
|
||||
_ => return None,
|
||||
};
|
||||
require_local_var(rt_args_array_ptr_id);
|
||||
|
||||
// Each runtime argument has its own variable, 6 instructions
|
||||
// to initialize it, and 9 instructions to copy it to the
|
||||
// appropriate slot in the array. The groups of 6 and 9
|
||||
// instructions, for all runtime args, are each separate.
|
||||
let copies_from_rt_arg_vars_to_rt_args_array = try_rev_take(rt_args_count * 9)?;
|
||||
let copies_from_rt_arg_vars_to_rt_args_array =
|
||||
copies_from_rt_arg_vars_to_rt_args_array.chunks(9);
|
||||
let inits_of_rt_arg_vars = try_rev_take(rt_args_count * 6)?;
|
||||
let inits_of_rt_arg_vars = inits_of_rt_arg_vars.chunks(6);
|
||||
|
||||
for (
|
||||
rt_arg_idx,
|
||||
(init_of_rt_arg_var_insts, copy_from_rt_arg_var_to_rt_args_array_insts),
|
||||
) in inits_of_rt_arg_vars
|
||||
.zip(copies_from_rt_arg_vars_to_rt_args_array)
|
||||
.enumerate()
|
||||
{
|
||||
let rt_arg_var_id = match init_of_rt_arg_var_insts[..] {
|
||||
[
|
||||
// HACK(eddyb) comment works around `rustfmt` array bugs.
|
||||
Inst::Bitcast(b, _),
|
||||
Inst::Bitcast(a, _),
|
||||
Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
|
||||
Inst::Store(a_st_dst, a_st_val),
|
||||
Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
|
||||
Inst::Store(b_st_dst, b_st_val),
|
||||
] if a_base_ptr == b_base_ptr
|
||||
&& (a, b) == (a_st_val, b_st_val)
|
||||
&& (a_ptr, b_ptr) == (a_st_dst, b_st_dst) =>
|
||||
{
|
||||
require_local_var(a_base_ptr);
|
||||
a_base_ptr
|
||||
}
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
// HACK(eddyb) this is only split to allow variable name reuse.
|
||||
let (copy_loads, copy_stores) =
|
||||
copy_from_rt_arg_var_to_rt_args_array_insts.split_at(4);
|
||||
let (a, b) = match copy_loads[..] {
|
||||
[
|
||||
// HACK(eddyb) comment works around `rustfmt` array bugs.
|
||||
Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
|
||||
Inst::Load(a_ld_val, a_ld_src),
|
||||
Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
|
||||
Inst::Load(b_ld_val, b_ld_src),
|
||||
] if [a_base_ptr, b_base_ptr] == [rt_arg_var_id; 2]
|
||||
&& (a_ptr, b_ptr) == (a_ld_src, b_ld_src) =>
|
||||
{
|
||||
(a_ld_val, b_ld_val)
|
||||
}
|
||||
_ => return None,
|
||||
};
|
||||
match copy_stores[..] {
|
||||
[
|
||||
// HACK(eddyb) comment works around `rustfmt` array bugs.
|
||||
Inst::InBoundsAccessChain(array_slot_ptr, array_base_ptr, SpirvConst::U32(array_idx)),
|
||||
Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
|
||||
Inst::Store(a_st_dst, a_st_val),
|
||||
Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
|
||||
Inst::Store(b_st_dst, b_st_val),
|
||||
] if array_base_ptr == rt_args_array_ptr_id
|
||||
&& array_idx as usize == rt_arg_idx
|
||||
&& [a_base_ptr, b_base_ptr] == [array_slot_ptr; 2]
|
||||
&& (a, b) == (a_st_val, b_st_val)
|
||||
&& (a_ptr, b_ptr) == (a_st_dst, b_st_dst) =>
|
||||
{
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keep all instructions up to (but not including) the last one
|
||||
// confirmed above to be the first instruction of `format_args!`.
|
||||
func.blocks[block_idx]
|
||||
.instructions
|
||||
.truncate(call_fmt_args_new_idx);
|
||||
}
|
||||
.truncate(taken_inst_idx_range.start);
|
||||
|
||||
None
|
||||
};
|
||||
remove_simple_format_args_if_possible();
|
||||
remove_format_args_if_possible();
|
||||
|
||||
// HACK(eddyb) redirect any possible panic call to an abort, to avoid
|
||||
// needing to materialize `&core::panic::Location` or `format_args!`.
|
||||
|
Loading…
Reference in New Issue
Block a user