Handle discriminants in dataflow-const-prop.

This commit is contained in:
Camille GILLOT 2023-01-21 22:28:54 +00:00
parent cd3649b2a5
commit 9a6c04f5d0
6 changed files with 311 additions and 65 deletions

View File

@ -1,6 +1,7 @@
#![feature(associated_type_defaults)]
#![feature(box_patterns)]
#![feature(exact_size_is_empty)]
#![feature(let_chains)]
#![feature(min_specialization)]
#![feature(once_cell)]
#![feature(stmt_expr_attributes)]

View File

@ -65,10 +65,8 @@ pub trait ValueAnalysis<'tcx> {
StatementKind::Assign(box (place, rvalue)) => {
self.handle_assign(*place, rvalue, state);
}
StatementKind::SetDiscriminant { .. } => {
// Could treat this as writing a constant to a pseudo-place.
// But discriminants are currently not tracked, so we do nothing.
// Related: https://github.com/rust-lang/unsafe-code-guidelines/issues/84
StatementKind::SetDiscriminant { box ref place, .. } => {
state.flood_discr(place.as_ref(), self.map());
}
StatementKind::Intrinsic(box intrinsic) => {
self.handle_intrinsic(intrinsic, state);
@ -447,26 +445,29 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
}
pub fn flood_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
if let Some(root) = map.find(place) {
self.flood_idx_with(root, map, value);
}
}
pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map) {
self.flood_with(place, map, V::top())
}
pub fn flood_idx_with(&mut self, place: PlaceIndex, map: &Map, value: V) {
let StateData::Reachable(values) = &mut self.0 else { return };
map.preorder_invoke(place, &mut |place| {
map.for_each_aliasing_place(place, None, &mut |place| {
if let Some(vi) = map.places[place].value_index {
values[vi] = value.clone();
}
});
}
pub fn flood_idx(&mut self, place: PlaceIndex, map: &Map) {
self.flood_idx_with(place, map, V::top())
pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map) {
self.flood_with(place, map, V::top())
}
pub fn flood_discr_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
let StateData::Reachable(values) = &mut self.0 else { return };
map.for_each_aliasing_place(place, Some(TrackElem::Discriminant), &mut |place| {
if let Some(vi) = map.places[place].value_index {
values[vi] = value.clone();
}
});
}
pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map) {
self.flood_discr_with(place, map, V::top())
}
/// Copies `source` to `target`, including all tracked places beneath.
@ -474,7 +475,9 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
/// If `target` contains a place that is not contained in `source`, it will be overwritten with
/// Top. Also, because this will copy all entries one after another, it may only be used for
/// places that are non-overlapping or identical.
pub fn assign_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
///
/// The target place must have been flooded before calling this method.
fn assign_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
let StateData::Reachable(values) = &mut self.0 else { return };
// If both places are tracked, we copy the value to the target. If the target is tracked,
@ -492,26 +495,28 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
let projection = map.places[target_child].proj_elem.unwrap();
if let Some(source_child) = map.projections.get(&(source, projection)) {
self.assign_place_idx(target_child, *source_child, map);
} else {
self.flood_idx(target_child, map);
}
}
}
pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
self.flood(target, map);
if let Some(target) = map.find(target) {
self.assign_idx(target, result, map);
} else {
// We don't track this place nor any projections, assignment can be ignored.
}
}
pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
self.flood_discr(target, map);
if let Some(target) = map.find_discr(target) {
self.assign_idx(target, result, map);
}
}
/// The target place must have been flooded before calling this method.
pub fn assign_idx(&mut self, target: PlaceIndex, result: ValueOrPlace<V>, map: &Map) {
match result {
ValueOrPlace::Value(value) => {
// First flood the target place in case we also track any projections (although
// this scenario is currently not well-supported by the API).
self.flood_idx(target, map);
let StateData::Reachable(values) = &mut self.0 else { return };
if let Some(value_index) = map.places[target].value_index {
values[value_index] = value;
@ -526,6 +531,14 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
map.find(place).map(|place| self.get_idx(place, map)).unwrap_or(V::top())
}
/// Retrieve the value stored for a place, or if it is not tracked.
pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V {
match map.find_discr(place) {
Some(place) => self.get_idx(place, map),
None => V::top(),
}
}
/// Retrieve the value stored for a place index, or if it is not tracked.
pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V {
match &self.0 {
@ -582,7 +595,6 @@ impl Map {
/// This is currently the only way to create a [`Map`]. The way in which the tracked places are
/// chosen is an implementation detail and may not be relied upon (other than that their type
/// passes the filter).
#[instrument(skip_all, level = "debug")]
pub fn from_filter<'tcx>(
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
@ -614,7 +626,7 @@ impl Map {
/// Potentially register the (local, projection) place and its fields, recursively.
///
/// Invariant: The projection must only contain fields.
/// Invariant: The projection must only contain trackable elements.
fn register_with_filter_rec<'tcx>(
&mut self,
tcx: TyCtxt<'tcx>,
@ -623,21 +635,46 @@ impl Map {
ty: Ty<'tcx>,
filter: &mut impl FnMut(Ty<'tcx>) -> bool,
) {
if filter(ty) {
// We know that the projection only contains trackable elements.
let place = self.make_place(local, projection).unwrap();
// We know that the projection only contains trackable elements.
let place = self.make_place(local, projection).unwrap();
// Allocate a value slot if it doesn't have one.
if self.places[place].value_index.is_none() {
self.places[place].value_index = Some(self.value_count.into());
self.value_count += 1;
// Allocate a value slot if it doesn't have one, and the user requested one.
if self.places[place].value_index.is_none() && filter(ty) {
self.places[place].value_index = Some(self.value_count.into());
self.value_count += 1;
}
if ty.is_enum() {
let discr_ty = ty.discriminant_ty(tcx);
if filter(discr_ty) {
let discr = *self
.projections
.entry((place, TrackElem::Discriminant))
.or_insert_with(|| {
// Prepend new child to the linked list.
let next = self.places.push(PlaceInfo::new(Some(TrackElem::Discriminant)));
self.places[next].next_sibling = self.places[place].first_child;
self.places[place].first_child = Some(next);
next
});
// Allocate a value slot if it doesn't have one.
if self.places[discr].value_index.is_none() {
self.places[discr].value_index = Some(self.value_count.into());
self.value_count += 1;
}
}
}
// Recurse with all fields of this place.
iter_fields(ty, tcx, |variant, field, ty| {
if variant.is_some() {
// Downcasts are currently not supported.
if let Some(variant) = variant {
projection.push(PlaceElem::Downcast(None, variant));
let _ = self.make_place(local, projection);
projection.push(PlaceElem::Field(field, ty));
self.register_with_filter_rec(tcx, local, projection, ty, filter);
projection.pop();
projection.pop();
return;
}
projection.push(PlaceElem::Field(field, ty));
@ -694,13 +731,77 @@ impl Map {
Some(index)
}
/// Locates the given place, if it exists in the tree.
pub fn find_discr(&self, place: PlaceRef<'_>) -> Option<PlaceIndex> {
let index = self.find(place)?;
self.apply(index, TrackElem::Discriminant)
}
/// Iterate over all direct children.
pub fn children(&self, parent: PlaceIndex) -> impl Iterator<Item = PlaceIndex> + '_ {
Children::new(self, parent)
}
/// Invoke a function on the given place and all places that may alias it.
///
/// In particular, when the given place has a variant downcast, we invoke the function on all
/// the other variants.
///
/// `tail_elem` allows to support discriminants that are not a place in MIR, but that we track
/// as such.
fn for_each_aliasing_place(
&self,
place: PlaceRef<'_>,
tail_elem: Option<TrackElem>,
f: &mut impl FnMut(PlaceIndex),
) {
let Some(&Some(mut index)) = self.locals.get(place.local) else {
// The local is not tracked at all, nothing to invalidate.
return;
};
let elems = place
.projection
.iter()
.map(|&elem| elem.try_into())
.chain(tail_elem.map(Ok).into_iter());
for elem in elems {
let Ok(elem) = elem else { return };
let sub = self.apply(index, elem);
if let TrackElem::Variant(..) | TrackElem::Discriminant = elem {
// Writing to an enum variant field invalidates the other variants and the discriminant.
self.for_each_variant_sibling(index, sub, f);
}
if let Some(sub) = sub {
index = sub
} else {
return;
}
}
self.preorder_invoke(index, f);
}
/// Invoke the given function on all the descendants of the given place, except one branch.
pub fn for_each_variant_sibling(
&self,
parent: PlaceIndex,
preserved_child: Option<PlaceIndex>,
f: &mut impl FnMut(PlaceIndex),
) {
for sibling in self.children(parent) {
let elem = self.places[sibling].proj_elem;
// Only invalidate variants and discriminant. Fields (for generators) are not
// invalidated by assignment to a variant.
if let Some(TrackElem::Variant(..) | TrackElem::Discriminant) = elem
// Only invalidate the other variants, the current one is fine.
&& Some(sibling) != preserved_child
{
self.preorder_invoke(sibling, f);
}
}
}
/// Invoke a function on the given place and all descendants.
pub fn preorder_invoke(&self, root: PlaceIndex, f: &mut impl FnMut(PlaceIndex)) {
fn preorder_invoke(&self, root: PlaceIndex, f: &mut impl FnMut(PlaceIndex)) {
f(root);
for child in self.children(root) {
self.preorder_invoke(child, f);
@ -759,6 +860,7 @@ impl<'a> Iterator for Children<'a> {
}
/// Used as the result of an operand or r-value.
#[derive(Debug)]
pub enum ValueOrPlace<V> {
Value(V),
Place(PlaceIndex),
@ -776,6 +878,8 @@ impl<V: HasTop> ValueOrPlace<V> {
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum TrackElem {
Field(Field),
Variant(VariantIdx),
Discriminant,
}
impl<V, T> TryFrom<ProjectionElem<V, T>> for TrackElem {
@ -784,6 +888,7 @@ impl<V, T> TryFrom<ProjectionElem<V, T>> for TrackElem {
fn try_from(value: ProjectionElem<V, T>) -> Result<Self, Self::Error> {
match value {
ProjectionElem::Field(field, _) => Ok(TrackElem::Field(field)),
ProjectionElem::Downcast(_, idx) => Ok(TrackElem::Variant(idx)),
_ => Err(()),
}
}
@ -900,6 +1005,12 @@ fn debug_with_context_rec<V: Debug + Eq>(
for child in map.children(place) {
let info_elem = map.places[child].proj_elem.unwrap();
let child_place_str = match info_elem {
TrackElem::Discriminant => {
format!("discriminant({})", place_str)
}
TrackElem::Variant(idx) => {
format!("({} as {:?})", place_str, idx)
}
TrackElem::Field(field) => {
if place_str.starts_with('*') {
format!("({}).{}", place_str, field.index())

View File

@ -13,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects};
use rustc_span::DUMMY_SP;
use rustc_target::abi::Align;
use rustc_target::abi::VariantIdx;
use crate::MirPass;
@ -30,6 +31,7 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
#[instrument(skip_all level = "debug")]
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
debug!(def_id = ?body.source.def_id());
if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT {
debug!("aborted dataflow const prop due too many basic blocks");
return;
@ -63,14 +65,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
}
}
struct ConstAnalysis<'tcx> {
struct ConstAnalysis<'a, 'tcx> {
map: Map,
tcx: TyCtxt<'tcx>,
local_decls: &'a LocalDecls<'tcx>,
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
param_env: ty::ParamEnv<'tcx>,
}
impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
impl<'tcx> ConstAnalysis<'_, 'tcx> {
fn eval_disciminant(
&self,
enum_ty: Ty<'tcx>,
variant_index: VariantIdx,
) -> Option<ScalarTy<'tcx>> {
if !enum_ty.is_enum() {
return None;
}
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
let discr_value = Scalar::try_from_uint(discr.val, discr_layout.size)?;
Some(ScalarTy(discr_value, discr.ty))
}
}
impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
type Value = FlatSet<ScalarTy<'tcx>>;
const NAME: &'static str = "ConstAnalysis";
@ -79,6 +98,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
&self.map
}
fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<Self::Value>) {
match statement.kind {
StatementKind::SetDiscriminant { box ref place, variant_index } => {
state.flood_discr(place.as_ref(), &self.map);
if self.map.find_discr(place.as_ref()).is_some() {
let enum_ty = place.ty(self.local_decls, self.tcx).ty;
if let Some(discr) = self.eval_disciminant(enum_ty, variant_index) {
state.assign_discr(
place.as_ref(),
ValueOrPlace::Value(FlatSet::Elem(discr)),
&self.map,
);
}
}
}
_ => self.super_statement(statement, state),
}
}
fn handle_assign(
&self,
target: Place<'tcx>,
@ -87,17 +125,22 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
) {
match rvalue {
Rvalue::Aggregate(kind, operands) => {
let target = self.map().find(target.as_ref());
if let Some(target) = target {
state.flood_idx_with(target, self.map(), FlatSet::Bottom);
let field_based = match **kind {
AggregateKind::Tuple | AggregateKind::Closure(..) => true,
AggregateKind::Adt(def_id, ..) => {
matches!(self.tcx.def_kind(def_id), DefKind::Struct)
state.flood_with(target.as_ref(), self.map(), FlatSet::Bottom);
if let Some(target_idx) = self.map().find(target.as_ref()) {
let (variant_target, variant_index) = match **kind {
AggregateKind::Tuple | AggregateKind::Closure(..) => {
(Some(target_idx), None)
}
_ => false,
AggregateKind::Adt(def_id, variant_index, ..) => {
match self.tcx.def_kind(def_id) {
DefKind::Struct => (Some(target_idx), None),
DefKind::Enum => (Some(target_idx), Some(variant_index)),
_ => (None, None),
}
}
_ => (None, None),
};
if field_based {
if let Some(target) = variant_target {
for (field_index, operand) in operands.iter().enumerate() {
if let Some(field) = self
.map()
@ -108,15 +151,20 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
}
}
}
if let Some(variant_index) = variant_index
&& let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
{
let enum_ty = target.ty(self.local_decls, self.tcx).ty;
if let Some(discr_val) = self.eval_disciminant(enum_ty, variant_index) {
state.assign_idx(discr_idx, ValueOrPlace::Value(FlatSet::Elem(discr_val)), &self.map);
}
}
}
}
Rvalue::CheckedBinaryOp(op, box (left, right)) => {
state.flood(target.as_ref(), self.map());
let target = self.map().find(target.as_ref());
if let Some(target) = target {
// We should not track any projections other than
// what is overwritten below, but just in case...
state.flood_idx(target, self.map());
}
let value_target = target
.and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into())));
@ -195,6 +243,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom),
FlatSet::Top => ValueOrPlace::Value(FlatSet::Top),
},
Rvalue::Discriminant(place) => {
ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map()))
}
_ => self.super_rvalue(rvalue, state),
}
}
@ -268,12 +319,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> {
}
}
impl<'tcx> ConstAnalysis<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self {
impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self {
let param_env = tcx.param_env(body.source.def_id());
Self {
map,
tcx,
local_decls: &body.local_decls,
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
param_env: param_env,
}
@ -466,6 +518,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> {
_ => (),
}
}
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
match rvalue {
Rvalue::Discriminant(place) => {
match self.state.get_discr(place.as_ref(), self.visitor.map) {
FlatSet::Top => (),
FlatSet::Elem(value) => {
self.visitor.before_effect.insert((location, *place), value);
}
FlatSet::Bottom => (),
}
}
_ => self.super_rvalue(rvalue, location),
}
}
}
struct DummyMachine;

View File

@ -0,0 +1,26 @@
- // MIR for `mutate_discriminant` before DataflowConstProp
+ // MIR for `mutate_discriminant` after DataflowConstProp
fn mutate_discriminant() -> u8 {
let mut _0: u8; // return place in scope 0 at $DIR/enum.rs:+0:29: +0:31
let mut _1: std::option::Option<NonZeroUsize>; // in scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
let mut _2: isize; // in scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
bb0: {
discriminant(_1) = 1; // scope 0 at $DIR/enum.rs:+4:13: +4:34
(((_1 as variant#1).0: NonZeroUsize).0: usize) = const 0_usize; // scope 0 at $DIR/enum.rs:+6:13: +6:64
_2 = discriminant(_1); // scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
switchInt(_2) -> [0: bb1, otherwise: bb2]; // scope 0 at $DIR/enum.rs:+9:13: +12:14
}
bb1: {
_0 = const 1_u8; // scope 0 at $DIR/enum.rs:+15:13: +15:20
return; // scope 0 at $DIR/enum.rs:+16:13: +16:21
}
bb2: {
_0 = const 2_u8; // scope 0 at $DIR/enum.rs:+19:13: +19:20
unreachable; // scope 0 at $DIR/enum.rs:+20:13: +20:26
}
}

View File

@ -1,13 +1,52 @@
// unit-test: DataflowConstProp
// Not trackable, because variants could be aliased.
#![feature(custom_mir, core_intrinsics, rustc_attrs)]
use std::intrinsics::mir::*;
enum E {
V1(i32),
V2(i32)
}
// EMIT_MIR enum.main.DataflowConstProp.diff
fn main() {
// EMIT_MIR enum.simple.DataflowConstProp.diff
fn simple() {
let e = E::V1(0);
let x = match e { E::V1(x) => x, E::V2(x) => x };
}
#[rustc_layout_scalar_valid_range_start(1)]
#[rustc_nonnull_optimization_guaranteed]
struct NonZeroUsize(usize);
// EMIT_MIR enum.mutate_discriminant.DataflowConstProp.diff
#[custom_mir(dialect = "runtime", phase = "post-cleanup")]
fn mutate_discriminant() -> u8 {
mir!(
let x: Option<NonZeroUsize>;
{
SetDiscriminant(x, 1);
// This assignment overwrites the niche in which the discriminant is stored.
place!(Field(Field(Variant(x, 1), 0), 0)) = 0_usize;
// So we cannot know the value of this discriminant.
let a = Discriminant(x);
match a {
0 => bb1,
_ => bad,
}
}
bb1 = {
RET = 1;
Return()
}
bad = {
RET = 2;
Unreachable()
}
)
}
fn main() {
simple();
mutate_discriminant();
}

View File

@ -1,8 +1,8 @@
- // MIR for `main` before DataflowConstProp
+ // MIR for `main` after DataflowConstProp
- // MIR for `simple` before DataflowConstProp
+ // MIR for `simple` after DataflowConstProp
fn main() -> () {
let mut _0: (); // return place in scope 0 at $DIR/enum.rs:+0:11: +0:11
fn simple() -> () {
let mut _0: (); // return place in scope 0 at $DIR/enum.rs:+0:13: +0:13
let _1: E; // in scope 0 at $DIR/enum.rs:+1:9: +1:10
let mut _3: isize; // in scope 0 at $DIR/enum.rs:+2:23: +2:31
scope 1 {
@ -25,8 +25,10 @@
StorageLive(_1); // scope 0 at $DIR/enum.rs:+1:9: +1:10
_1 = E::V1(const 0_i32); // scope 0 at $DIR/enum.rs:+1:13: +1:21
StorageLive(_2); // scope 1 at $DIR/enum.rs:+2:9: +2:10
_3 = discriminant(_1); // scope 1 at $DIR/enum.rs:+2:19: +2:20
switchInt(move _3) -> [0: bb3, 1: bb1, otherwise: bb2]; // scope 1 at $DIR/enum.rs:+2:13: +2:20
- _3 = discriminant(_1); // scope 1 at $DIR/enum.rs:+2:19: +2:20
- switchInt(move _3) -> [0: bb3, 1: bb1, otherwise: bb2]; // scope 1 at $DIR/enum.rs:+2:13: +2:20
+ _3 = const 0_isize; // scope 1 at $DIR/enum.rs:+2:19: +2:20
+ switchInt(const 0_isize) -> [0: bb3, 1: bb1, otherwise: bb2]; // scope 1 at $DIR/enum.rs:+2:13: +2:20
}
bb1: {
@ -50,7 +52,7 @@
}
bb4: {
_0 = const (); // scope 0 at $DIR/enum.rs:+0:11: +3:2
_0 = const (); // scope 0 at $DIR/enum.rs:+0:13: +3:2
StorageDead(_2); // scope 1 at $DIR/enum.rs:+3:1: +3:2
StorageDead(_1); // scope 0 at $DIR/enum.rs:+3:1: +3:2
return; // scope 0 at $DIR/enum.rs:+3:2: +3:2