mirror of
https://github.com/rust-lang/rust.git
synced 2024-10-30 14:01:51 +00:00
Rollup merge of #119461 - cjgillot:jump-threading-interp, r=tmiasko
Use an interpreter in MIR jump threading This allows to understand assignments of aggregate constants. This case appears more frequently with GVN promoting aggregates to constants.
This commit is contained in:
commit
203cc6930e
@ -36,16 +36,21 @@
|
||||
//! cost by `MAX_COST`.
|
||||
|
||||
use rustc_arena::DroplessArena;
|
||||
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
|
||||
use rustc_data_structures::fx::FxHashSet;
|
||||
use rustc_index::bit_set::BitSet;
|
||||
use rustc_index::IndexVec;
|
||||
use rustc_middle::mir::interpret::Scalar;
|
||||
use rustc_middle::mir::visit::Visitor;
|
||||
use rustc_middle::mir::*;
|
||||
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
|
||||
use rustc_middle::ty::layout::LayoutOf;
|
||||
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
|
||||
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
|
||||
use rustc_span::DUMMY_SP;
|
||||
use rustc_target::abi::{TagEncoding, Variants};
|
||||
|
||||
use crate::cost_checker::CostChecker;
|
||||
use crate::dataflow_const_prop::DummyMachine;
|
||||
|
||||
pub struct JumpThreading;
|
||||
|
||||
@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
|
||||
let mut finder = TOFinder {
|
||||
tcx,
|
||||
param_env,
|
||||
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
|
||||
body,
|
||||
arena: &arena,
|
||||
map: &map,
|
||||
@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
|
||||
debug!(?discr, ?bb);
|
||||
|
||||
let discr_ty = discr.ty(body, tcx).ty;
|
||||
let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
|
||||
let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue };
|
||||
|
||||
let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
|
||||
debug!(?discr);
|
||||
@ -142,6 +148,7 @@ struct ThreadingOpportunity {
|
||||
struct TOFinder<'tcx, 'a> {
|
||||
tcx: TyCtxt<'tcx>,
|
||||
param_env: ty::ParamEnv<'tcx>,
|
||||
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
|
||||
body: &'a Body<'tcx>,
|
||||
map: &'a Map,
|
||||
loop_headers: &'a BitSet<BasicBlock>,
|
||||
@ -329,11 +336,11 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(self))]
|
||||
fn process_operand(
|
||||
fn process_immediate(
|
||||
&mut self,
|
||||
bb: BasicBlock,
|
||||
lhs: PlaceIndex,
|
||||
rhs: &Operand<'tcx>,
|
||||
rhs: ImmTy<'tcx>,
|
||||
state: &mut State<ConditionSet<'a>>,
|
||||
) -> Option<!> {
|
||||
let register_opportunity = |c: Condition| {
|
||||
@ -341,13 +348,70 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
|
||||
};
|
||||
|
||||
let conditions = state.try_get_idx(lhs, self.map)?;
|
||||
if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
|
||||
conditions.iter_matches(int).for_each(register_opportunity);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
|
||||
#[instrument(level = "trace", skip(self))]
|
||||
fn process_constant(
|
||||
&mut self,
|
||||
bb: BasicBlock,
|
||||
lhs: PlaceIndex,
|
||||
constant: OpTy<'tcx>,
|
||||
state: &mut State<ConditionSet<'a>>,
|
||||
) {
|
||||
self.map.for_each_projection_value(
|
||||
lhs,
|
||||
constant,
|
||||
&mut |elem, op| match elem {
|
||||
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
|
||||
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
|
||||
TrackElem::Discriminant => {
|
||||
let variant = self.ecx.read_discriminant(op).ok()?;
|
||||
let discr_value =
|
||||
self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
|
||||
Some(discr_value.into())
|
||||
}
|
||||
TrackElem::DerefLen => {
|
||||
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
|
||||
let len_usize = op.len(&self.ecx).ok()?;
|
||||
let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
|
||||
Some(ImmTy::from_uint(len_usize, layout).into())
|
||||
}
|
||||
},
|
||||
&mut |place, op| {
|
||||
if let Some(conditions) = state.try_get_idx(place, self.map)
|
||||
&& let Ok(imm) = self.ecx.read_immediate_raw(op)
|
||||
&& let Some(imm) = imm.right()
|
||||
&& let Immediate::Scalar(Scalar::Int(int)) = *imm
|
||||
{
|
||||
conditions.iter_matches(int).for_each(|c: Condition| {
|
||||
self.opportunities
|
||||
.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
|
||||
})
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(self))]
|
||||
fn process_operand(
|
||||
&mut self,
|
||||
bb: BasicBlock,
|
||||
lhs: PlaceIndex,
|
||||
rhs: &Operand<'tcx>,
|
||||
state: &mut State<ConditionSet<'a>>,
|
||||
) -> Option<!> {
|
||||
match rhs {
|
||||
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
|
||||
Operand::Constant(constant) => {
|
||||
let conditions = state.try_get_idx(lhs, self.map)?;
|
||||
let constant =
|
||||
constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
|
||||
conditions.iter_matches(constant).for_each(register_opportunity);
|
||||
let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
|
||||
self.process_constant(bb, lhs, constant, state);
|
||||
}
|
||||
// Transfer the conditions on the copied rhs.
|
||||
Operand::Move(rhs) | Operand::Copy(rhs) => {
|
||||
@ -359,6 +423,84 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||
None
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(self))]
|
||||
fn process_assign(
|
||||
&mut self,
|
||||
bb: BasicBlock,
|
||||
lhs_place: &Place<'tcx>,
|
||||
rhs: &Rvalue<'tcx>,
|
||||
state: &mut State<ConditionSet<'a>>,
|
||||
) -> Option<!> {
|
||||
let lhs = self.map.find(lhs_place.as_ref())?;
|
||||
match rhs {
|
||||
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
|
||||
// Transfer the conditions on the copy rhs.
|
||||
Rvalue::CopyForDeref(rhs) => {
|
||||
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
|
||||
}
|
||||
Rvalue::Discriminant(rhs) => {
|
||||
let rhs = self.map.find_discr(rhs.as_ref())?;
|
||||
state.insert_place_idx(rhs, lhs, self.map);
|
||||
}
|
||||
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
|
||||
Rvalue::Aggregate(box ref kind, ref operands) => {
|
||||
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
|
||||
let lhs = match kind {
|
||||
// Do not support unions.
|
||||
AggregateKind::Adt(.., Some(_)) => return None,
|
||||
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
|
||||
if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
|
||||
&& let Ok(discr_value) =
|
||||
self.ecx.discriminant_for_variant(agg_ty, *variant_index)
|
||||
{
|
||||
self.process_immediate(bb, discr_target, discr_value, state);
|
||||
}
|
||||
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
|
||||
}
|
||||
_ => lhs,
|
||||
};
|
||||
for (field_index, operand) in operands.iter_enumerated() {
|
||||
if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
|
||||
self.process_operand(bb, field, operand, state);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Transfer the conditions on the copy rhs, after inversing polarity.
|
||||
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
|
||||
let conditions = state.try_get_idx(lhs, self.map)?;
|
||||
let place = self.map.find(place.as_ref())?;
|
||||
let conds = conditions.map(self.arena, Condition::inv);
|
||||
state.insert_value_idx(place, conds, self.map);
|
||||
}
|
||||
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
|
||||
// Create a condition on `rhs ?= B`.
|
||||
Rvalue::BinaryOp(
|
||||
op,
|
||||
box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
|
||||
| box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
|
||||
) => {
|
||||
let conditions = state.try_get_idx(lhs, self.map)?;
|
||||
let place = self.map.find(place.as_ref())?;
|
||||
let equals = match op {
|
||||
BinOp::Eq => ScalarInt::TRUE,
|
||||
BinOp::Ne => ScalarInt::FALSE,
|
||||
_ => return None,
|
||||
};
|
||||
let value = value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
|
||||
let conds = conditions.map(self.arena, |c| Condition {
|
||||
value,
|
||||
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
|
||||
..c
|
||||
});
|
||||
state.insert_value_idx(place, conds, self.map);
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(self))]
|
||||
fn process_statement(
|
||||
&mut self,
|
||||
@ -374,18 +516,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||
// Below, `lhs` is the return value of `mutated_statement`,
|
||||
// the place to which `conditions` apply.
|
||||
|
||||
let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
|
||||
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
|
||||
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
|
||||
let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
|
||||
Some(Operand::const_from_scalar(
|
||||
self.tcx,
|
||||
discr.ty,
|
||||
scalar.into(),
|
||||
rustc_span::DUMMY_SP,
|
||||
))
|
||||
};
|
||||
|
||||
match &stmt.kind {
|
||||
// If we expect `discriminant(place) ?= A`,
|
||||
// we have an opportunity if `variant_index ?= A`.
|
||||
@ -395,7 +525,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
|
||||
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
|
||||
// nothing.
|
||||
let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
|
||||
let enum_layout = self.ecx.layout_of(enum_ty).ok()?;
|
||||
let writes_discriminant = match enum_layout.variants {
|
||||
Variants::Single { index } => {
|
||||
assert_eq!(index, *variant_index);
|
||||
@ -408,8 +538,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||
} => *variant_index != untagged_variant,
|
||||
};
|
||||
if writes_discriminant {
|
||||
let discr = discriminant_for_variant(enum_ty, *variant_index)?;
|
||||
self.process_operand(bb, discr_target, &discr, state)?;
|
||||
let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
|
||||
self.process_immediate(bb, discr_target, discr, state)?;
|
||||
}
|
||||
}
|
||||
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
|
||||
@ -420,89 +550,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||
conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
|
||||
}
|
||||
StatementKind::Assign(box (lhs_place, rhs)) => {
|
||||
if let Some(lhs) = self.map.find(lhs_place.as_ref()) {
|
||||
match rhs {
|
||||
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
|
||||
// Transfer the conditions on the copy rhs.
|
||||
Rvalue::CopyForDeref(rhs) => {
|
||||
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
|
||||
}
|
||||
Rvalue::Discriminant(rhs) => {
|
||||
let rhs = self.map.find_discr(rhs.as_ref())?;
|
||||
state.insert_place_idx(rhs, lhs, self.map);
|
||||
}
|
||||
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
|
||||
Rvalue::Aggregate(box ref kind, ref operands) => {
|
||||
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
|
||||
let lhs = match kind {
|
||||
// Do not support unions.
|
||||
AggregateKind::Adt(.., Some(_)) => return None,
|
||||
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
|
||||
if let Some(discr_target) =
|
||||
self.map.apply(lhs, TrackElem::Discriminant)
|
||||
&& let Some(discr_value) =
|
||||
discriminant_for_variant(agg_ty, *variant_index)
|
||||
{
|
||||
self.process_operand(bb, discr_target, &discr_value, state);
|
||||
}
|
||||
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
|
||||
}
|
||||
_ => lhs,
|
||||
};
|
||||
for (field_index, operand) in operands.iter_enumerated() {
|
||||
if let Some(field) =
|
||||
self.map.apply(lhs, TrackElem::Field(field_index))
|
||||
{
|
||||
self.process_operand(bb, field, operand, state);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Transfer the conditions on the copy rhs, after inversing polarity.
|
||||
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
|
||||
let conditions = state.try_get_idx(lhs, self.map)?;
|
||||
let place = self.map.find(place.as_ref())?;
|
||||
let conds = conditions.map(self.arena, Condition::inv);
|
||||
state.insert_value_idx(place, conds, self.map);
|
||||
}
|
||||
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
|
||||
// Create a condition on `rhs ?= B`.
|
||||
Rvalue::BinaryOp(
|
||||
op,
|
||||
box (
|
||||
Operand::Move(place) | Operand::Copy(place),
|
||||
Operand::Constant(value),
|
||||
)
|
||||
| box (
|
||||
Operand::Constant(value),
|
||||
Operand::Move(place) | Operand::Copy(place),
|
||||
),
|
||||
) => {
|
||||
let conditions = state.try_get_idx(lhs, self.map)?;
|
||||
let place = self.map.find(place.as_ref())?;
|
||||
let equals = match op {
|
||||
BinOp::Eq => ScalarInt::TRUE,
|
||||
BinOp::Ne => ScalarInt::FALSE,
|
||||
_ => return None,
|
||||
};
|
||||
let value = value
|
||||
.const_
|
||||
.normalize(self.tcx, self.param_env)
|
||||
.try_to_scalar_int()?;
|
||||
let conds = conditions.map(self.arena, |c| Condition {
|
||||
value,
|
||||
polarity: if c.matches(equals) {
|
||||
Polarity::Eq
|
||||
} else {
|
||||
Polarity::Ne
|
||||
},
|
||||
..c
|
||||
});
|
||||
state.insert_value_idx(place, conds, self.map);
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
self.process_assign(bb, lhs_place, rhs, state)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
@ -577,7 +625,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||
|
||||
let discr = discr.place()?;
|
||||
let discr_ty = discr.ty(self.body, self.tcx).ty;
|
||||
let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
|
||||
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
|
||||
let conditions = state.try_get(discr.as_ref(), self.map)?;
|
||||
|
||||
if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
|
||||
|
@ -0,0 +1,52 @@
|
||||
- // MIR for `aggregate` before JumpThreading
|
||||
+ // MIR for `aggregate` after JumpThreading
|
||||
|
||||
fn aggregate(_1: u8) -> u8 {
|
||||
debug x => _1;
|
||||
let mut _0: u8;
|
||||
let _2: u8;
|
||||
let _3: u8;
|
||||
let mut _4: (u8, u8);
|
||||
let mut _5: bool;
|
||||
let mut _6: u8;
|
||||
scope 1 {
|
||||
debug a => _2;
|
||||
debug b => _3;
|
||||
}
|
||||
|
||||
bb0: {
|
||||
StorageLive(_4);
|
||||
_4 = const _;
|
||||
StorageLive(_2);
|
||||
_2 = (_4.0: u8);
|
||||
StorageLive(_3);
|
||||
_3 = (_4.1: u8);
|
||||
StorageDead(_4);
|
||||
StorageLive(_5);
|
||||
StorageLive(_6);
|
||||
_6 = _2;
|
||||
_5 = Eq(move _6, const 7_u8);
|
||||
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
|
||||
+ goto -> bb2;
|
||||
}
|
||||
|
||||
bb1: {
|
||||
StorageDead(_6);
|
||||
_0 = _3;
|
||||
goto -> bb3;
|
||||
}
|
||||
|
||||
bb2: {
|
||||
StorageDead(_6);
|
||||
_0 = _2;
|
||||
goto -> bb3;
|
||||
}
|
||||
|
||||
bb3: {
|
||||
StorageDead(_5);
|
||||
StorageDead(_3);
|
||||
StorageDead(_2);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,52 @@
|
||||
- // MIR for `aggregate` before JumpThreading
|
||||
+ // MIR for `aggregate` after JumpThreading
|
||||
|
||||
fn aggregate(_1: u8) -> u8 {
|
||||
debug x => _1;
|
||||
let mut _0: u8;
|
||||
let _2: u8;
|
||||
let _3: u8;
|
||||
let mut _4: (u8, u8);
|
||||
let mut _5: bool;
|
||||
let mut _6: u8;
|
||||
scope 1 {
|
||||
debug a => _2;
|
||||
debug b => _3;
|
||||
}
|
||||
|
||||
bb0: {
|
||||
StorageLive(_4);
|
||||
_4 = const _;
|
||||
StorageLive(_2);
|
||||
_2 = (_4.0: u8);
|
||||
StorageLive(_3);
|
||||
_3 = (_4.1: u8);
|
||||
StorageDead(_4);
|
||||
StorageLive(_5);
|
||||
StorageLive(_6);
|
||||
_6 = _2;
|
||||
_5 = Eq(move _6, const 7_u8);
|
||||
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
|
||||
+ goto -> bb2;
|
||||
}
|
||||
|
||||
bb1: {
|
||||
StorageDead(_6);
|
||||
_0 = _3;
|
||||
goto -> bb3;
|
||||
}
|
||||
|
||||
bb2: {
|
||||
StorageDead(_6);
|
||||
_0 = _2;
|
||||
goto -> bb3;
|
||||
}
|
||||
|
||||
bb3: {
|
||||
StorageDead(_5);
|
||||
StorageDead(_3);
|
||||
StorageDead(_2);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -453,7 +453,23 @@ fn disappearing_bb(x: u8) -> u8 {
|
||||
)
|
||||
}
|
||||
|
||||
/// Verify that we can thread jumps when we assign from an aggregate constant.
|
||||
fn aggregate(x: u8) -> u8 {
|
||||
// CHECK-LABEL: fn aggregate(
|
||||
// CHECK-NOT: switchInt(
|
||||
|
||||
const FOO: (u8, u8) = (5, 13);
|
||||
|
||||
let (a, b) = FOO;
|
||||
if a == 7 {
|
||||
b
|
||||
} else {
|
||||
a
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// CHECK-LABEL: fn main(
|
||||
too_complex(Ok(0));
|
||||
identity(Ok(0));
|
||||
custom_discr(false);
|
||||
@ -464,6 +480,7 @@ fn main() {
|
||||
mutable_ref();
|
||||
renumbered_bb(true);
|
||||
disappearing_bb(7);
|
||||
aggregate(7);
|
||||
}
|
||||
|
||||
// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
|
||||
@ -476,3 +493,4 @@ fn main() {
|
||||
// EMIT_MIR jump_threading.mutable_ref.JumpThreading.diff
|
||||
// EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff
|
||||
// EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff
|
||||
// EMIT_MIR jump_threading.aggregate.JumpThreading.diff
|
||||
|
Loading…
Reference in New Issue
Block a user