Rollup merge of #119461 - cjgillot:jump-threading-interp, r=tmiasko

Use an interpreter in MIR jump threading

This allows to understand assignments of aggregate constants. This case appears more frequently with GVN promoting aggregates to constants.
This commit is contained in:
Nadrieril 2024-01-21 06:38:36 +01:00 committed by GitHub
commit 203cc6930e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 277 additions and 107 deletions

View File

@ -36,16 +36,21 @@
//! cost by `MAX_COST`.
use rustc_arena::DroplessArena;
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashSet;
use rustc_index::bit_set::BitSet;
use rustc_index::IndexVec;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP;
use rustc_target::abi::{TagEncoding, Variants};
use crate::cost_checker::CostChecker;
use crate::dataflow_const_prop::DummyMachine;
pub struct JumpThreading;
@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
let mut finder = TOFinder {
tcx,
param_env,
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
body,
arena: &arena,
map: &map,
@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
debug!(?discr, ?bb);
let discr_ty = discr.ty(body, tcx).ty;
let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue };
let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
debug!(?discr);
@ -142,6 +148,7 @@ struct ThreadingOpportunity {
struct TOFinder<'tcx, 'a> {
tcx: TyCtxt<'tcx>,
param_env: ty::ParamEnv<'tcx>,
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
body: &'a Body<'tcx>,
map: &'a Map,
loop_headers: &'a BitSet<BasicBlock>,
@ -329,11 +336,11 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
}
#[instrument(level = "trace", skip(self))]
fn process_operand(
fn process_immediate(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
rhs: ImmTy<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
let register_opportunity = |c: Condition| {
@ -341,13 +348,70 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
};
let conditions = state.try_get_idx(lhs, self.map)?;
if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
conditions.iter_matches(int).for_each(register_opportunity);
}
None
}
/// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
#[instrument(level = "trace", skip(self))]
fn process_constant(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
constant: OpTy<'tcx>,
state: &mut State<ConditionSet<'a>>,
) {
self.map.for_each_projection_value(
lhs,
constant,
&mut |elem, op| match elem {
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
TrackElem::Discriminant => {
let variant = self.ecx.read_discriminant(op).ok()?;
let discr_value =
self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
Some(discr_value.into())
}
TrackElem::DerefLen => {
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
let len_usize = op.len(&self.ecx).ok()?;
let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
Some(ImmTy::from_uint(len_usize, layout).into())
}
},
&mut |place, op| {
if let Some(conditions) = state.try_get_idx(place, self.map)
&& let Ok(imm) = self.ecx.read_immediate_raw(op)
&& let Some(imm) = imm.right()
&& let Immediate::Scalar(Scalar::Int(int)) = *imm
{
conditions.iter_matches(int).for_each(|c: Condition| {
self.opportunities
.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
})
}
},
);
}
#[instrument(level = "trace", skip(self))]
fn process_operand(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
match rhs {
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Operand::Constant(constant) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let constant =
constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
conditions.iter_matches(constant).for_each(register_opportunity);
let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
self.process_constant(bb, lhs, constant, state);
}
// Transfer the conditions on the copied rhs.
Operand::Move(rhs) | Operand::Copy(rhs) => {
@ -359,6 +423,84 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
None
}
#[instrument(level = "trace", skip(self))]
fn process_assign(
&mut self,
bb: BasicBlock,
lhs_place: &Place<'tcx>,
rhs: &Rvalue<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
let lhs = self.map.find(lhs_place.as_ref())?;
match rhs {
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
// Transfer the conditions on the copy rhs.
Rvalue::CopyForDeref(rhs) => {
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
}
Rvalue::Discriminant(rhs) => {
let rhs = self.map.find_discr(rhs.as_ref())?;
state.insert_place_idx(rhs, lhs, self.map);
}
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Rvalue::Aggregate(box ref kind, ref operands) => {
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
let lhs = match kind {
// Do not support unions.
AggregateKind::Adt(.., Some(_)) => return None,
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
&& let Ok(discr_value) =
self.ecx.discriminant_for_variant(agg_ty, *variant_index)
{
self.process_immediate(bb, discr_target, discr_value, state);
}
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
}
_ => lhs,
};
for (field_index, operand) in operands.iter_enumerated() {
if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
self.process_operand(bb, field, operand, state);
}
}
}
// Transfer the conditions on the copy rhs, after inversing polarity.
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let conds = conditions.map(self.arena, Condition::inv);
state.insert_value_idx(place, conds, self.map);
}
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
// Create a condition on `rhs ?= B`.
Rvalue::BinaryOp(
op,
box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
| box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let equals = match op {
BinOp::Eq => ScalarInt::TRUE,
BinOp::Ne => ScalarInt::FALSE,
_ => return None,
};
let value = value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
let conds = conditions.map(self.arena, |c| Condition {
value,
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
..c
});
state.insert_value_idx(place, conds, self.map);
}
_ => {}
}
None
}
#[instrument(level = "trace", skip(self))]
fn process_statement(
&mut self,
@ -374,18 +516,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
// Below, `lhs` is the return value of `mutated_statement`,
// the place to which `conditions` apply.
let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
Some(Operand::const_from_scalar(
self.tcx,
discr.ty,
scalar.into(),
rustc_span::DUMMY_SP,
))
};
match &stmt.kind {
// If we expect `discriminant(place) ?= A`,
// we have an opportunity if `variant_index ?= A`.
@ -395,7 +525,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
// nothing.
let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
let enum_layout = self.ecx.layout_of(enum_ty).ok()?;
let writes_discriminant = match enum_layout.variants {
Variants::Single { index } => {
assert_eq!(index, *variant_index);
@ -408,8 +538,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
} => *variant_index != untagged_variant,
};
if writes_discriminant {
let discr = discriminant_for_variant(enum_ty, *variant_index)?;
self.process_operand(bb, discr_target, &discr, state)?;
let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
self.process_immediate(bb, discr_target, discr, state)?;
}
}
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
@ -420,89 +550,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
}
StatementKind::Assign(box (lhs_place, rhs)) => {
if let Some(lhs) = self.map.find(lhs_place.as_ref()) {
match rhs {
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
// Transfer the conditions on the copy rhs.
Rvalue::CopyForDeref(rhs) => {
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
}
Rvalue::Discriminant(rhs) => {
let rhs = self.map.find_discr(rhs.as_ref())?;
state.insert_place_idx(rhs, lhs, self.map);
}
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Rvalue::Aggregate(box ref kind, ref operands) => {
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
let lhs = match kind {
// Do not support unions.
AggregateKind::Adt(.., Some(_)) => return None,
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
if let Some(discr_target) =
self.map.apply(lhs, TrackElem::Discriminant)
&& let Some(discr_value) =
discriminant_for_variant(agg_ty, *variant_index)
{
self.process_operand(bb, discr_target, &discr_value, state);
}
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
}
_ => lhs,
};
for (field_index, operand) in operands.iter_enumerated() {
if let Some(field) =
self.map.apply(lhs, TrackElem::Field(field_index))
{
self.process_operand(bb, field, operand, state);
}
}
}
// Transfer the conditions on the copy rhs, after inversing polarity.
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let conds = conditions.map(self.arena, Condition::inv);
state.insert_value_idx(place, conds, self.map);
}
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
// Create a condition on `rhs ?= B`.
Rvalue::BinaryOp(
op,
box (
Operand::Move(place) | Operand::Copy(place),
Operand::Constant(value),
)
| box (
Operand::Constant(value),
Operand::Move(place) | Operand::Copy(place),
),
) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let equals = match op {
BinOp::Eq => ScalarInt::TRUE,
BinOp::Ne => ScalarInt::FALSE,
_ => return None,
};
let value = value
.const_
.normalize(self.tcx, self.param_env)
.try_to_scalar_int()?;
let conds = conditions.map(self.arena, |c| Condition {
value,
polarity: if c.matches(equals) {
Polarity::Eq
} else {
Polarity::Ne
},
..c
});
state.insert_value_idx(place, conds, self.map);
}
_ => {}
}
}
self.process_assign(bb, lhs_place, rhs, state)?;
}
_ => {}
}
@ -577,7 +625,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
let discr = discr.place()?;
let discr_ty = discr.ty(self.body, self.tcx).ty;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
let conditions = state.try_get(discr.as_ref(), self.map)?;
if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {

View File

@ -0,0 +1,52 @@
- // MIR for `aggregate` before JumpThreading
+ // MIR for `aggregate` after JumpThreading
fn aggregate(_1: u8) -> u8 {
debug x => _1;
let mut _0: u8;
let _2: u8;
let _3: u8;
let mut _4: (u8, u8);
let mut _5: bool;
let mut _6: u8;
scope 1 {
debug a => _2;
debug b => _3;
}
bb0: {
StorageLive(_4);
_4 = const _;
StorageLive(_2);
_2 = (_4.0: u8);
StorageLive(_3);
_3 = (_4.1: u8);
StorageDead(_4);
StorageLive(_5);
StorageLive(_6);
_6 = _2;
_5 = Eq(move _6, const 7_u8);
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
+ goto -> bb2;
}
bb1: {
StorageDead(_6);
_0 = _3;
goto -> bb3;
}
bb2: {
StorageDead(_6);
_0 = _2;
goto -> bb3;
}
bb3: {
StorageDead(_5);
StorageDead(_3);
StorageDead(_2);
return;
}
}

View File

@ -0,0 +1,52 @@
- // MIR for `aggregate` before JumpThreading
+ // MIR for `aggregate` after JumpThreading
fn aggregate(_1: u8) -> u8 {
debug x => _1;
let mut _0: u8;
let _2: u8;
let _3: u8;
let mut _4: (u8, u8);
let mut _5: bool;
let mut _6: u8;
scope 1 {
debug a => _2;
debug b => _3;
}
bb0: {
StorageLive(_4);
_4 = const _;
StorageLive(_2);
_2 = (_4.0: u8);
StorageLive(_3);
_3 = (_4.1: u8);
StorageDead(_4);
StorageLive(_5);
StorageLive(_6);
_6 = _2;
_5 = Eq(move _6, const 7_u8);
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
+ goto -> bb2;
}
bb1: {
StorageDead(_6);
_0 = _3;
goto -> bb3;
}
bb2: {
StorageDead(_6);
_0 = _2;
goto -> bb3;
}
bb3: {
StorageDead(_5);
StorageDead(_3);
StorageDead(_2);
return;
}
}

View File

@ -453,7 +453,23 @@ fn disappearing_bb(x: u8) -> u8 {
)
}
/// Verify that we can thread jumps when we assign from an aggregate constant.
fn aggregate(x: u8) -> u8 {
// CHECK-LABEL: fn aggregate(
// CHECK-NOT: switchInt(
const FOO: (u8, u8) = (5, 13);
let (a, b) = FOO;
if a == 7 {
b
} else {
a
}
}
fn main() {
// CHECK-LABEL: fn main(
too_complex(Ok(0));
identity(Ok(0));
custom_discr(false);
@ -464,6 +480,7 @@ fn main() {
mutable_ref();
renumbered_bb(true);
disappearing_bb(7);
aggregate(7);
}
// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
@ -476,3 +493,4 @@ fn main() {
// EMIT_MIR jump_threading.mutable_ref.JumpThreading.diff
// EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff
// EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff
// EMIT_MIR jump_threading.aggregate.JumpThreading.diff