From dc4fe8e2953747605b085d297cf329824d499884 Mon Sep 17 00:00:00 2001 From: Camille GILLOT Date: Sun, 5 Feb 2023 09:31:27 +0000 Subject: [PATCH] Make SROA expand assignments. --- compiler/rustc_mir_transform/src/sroa.rs | 88 ++++++++++++++----- ....copies.ScalarReplacementOfAggregates.diff | 48 ++++++++++ ...scaping.ScalarReplacementOfAggregates.diff | 4 +- ...oa.flat.ScalarReplacementOfAggregates.diff | 2 +- ..._copies.ScalarReplacementOfAggregates.diff | 48 ++++++++++ tests/mir-opt/sroa.rs | 30 +++++-- 6 files changed, 188 insertions(+), 32 deletions(-) create mode 100644 tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff create mode 100644 tests/mir-opt/sroa.ref_copies.ScalarReplacementOfAggregates.diff diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 2118e3c5522..f6609704d25 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -78,10 +78,15 @@ fn escaping_locals(body: &Body<'_>) -> BitSet { rvalue: &Rvalue<'tcx>, location: Location, ) { - if lvalue.as_local().is_some() && let Rvalue::Aggregate(..) = rvalue { - // Aggregate assignments are expanded in run_pass. - self.visit_rvalue(rvalue, location); - return; + if lvalue.as_local().is_some() { + match rvalue { + // Aggregate assignments are expanded in run_pass. + Rvalue::Aggregate(..) | Rvalue::Use(..) => { + self.visit_rvalue(rvalue, location); + return; + } + _ => {} + } } self.super_assign(lvalue, rvalue, location) } @@ -195,10 +200,9 @@ fn replace_flattened_locals<'tcx>( return; } - let mut fragments = IndexVec::new(); + let mut fragments = IndexVec::<_, Option>>::from_elem(None, &body.local_decls); for (k, v) in &replacements.fields { - fragments.ensure_contains_elem(k.local, || Vec::new()); - fragments[k.local].push((k.projection, *v)); + fragments[k.local].get_or_insert_default().push((k.projection, *v)); } debug!(?fragments); @@ -235,7 +239,7 @@ struct ReplacementVisitor<'tcx, 'll> { all_dead_locals: BitSet, /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage /// and deinit statement and debuginfo. - fragments: IndexVec], Local)>>, + fragments: IndexVec], Local)>>>, patch: MirPatch<'tcx>, } @@ -243,9 +247,9 @@ impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> { fn gather_debug_info_fragments( &self, place: PlaceRef<'tcx>, - ) -> Vec> { + ) -> Option>> { let mut fragments = Vec::new(); - let parts = &self.fragments[place.local]; + let Some(parts) = &self.fragments[place.local] else { return None }; for (proj, replacement_local) in parts { if proj.starts_with(place.projection) { fragments.push(VarDebugInfoFragment { @@ -254,7 +258,7 @@ impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> { }); } } - fragments + Some(fragments) } fn replace_place(&self, place: PlaceRef<'tcx>) -> Option> { @@ -276,8 +280,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { match statement.kind { StatementKind::StorageLive(l) => { - if self.all_dead_locals.contains(l) { - let final_locals = &self.fragments[l]; + if let Some(final_locals) = &self.fragments[l] { for &(_, fl) in final_locals { self.patch.add_statement(location, StatementKind::StorageLive(fl)); } @@ -286,8 +289,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { return; } StatementKind::StorageDead(l) => { - if self.all_dead_locals.contains(l) { - let final_locals = &self.fragments[l]; + if let Some(final_locals) = &self.fragments[l] { for &(_, fl) in final_locals { self.patch.add_statement(location, StatementKind::StorageDead(fl)); } @@ -297,9 +299,8 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { } StatementKind::Deinit(box ref place) => { if let Some(local) = place.as_local() - && self.all_dead_locals.contains(local) + && let Some(final_locals) = &self.fragments[local] { - let final_locals = &self.fragments[local]; for &(_, fl) in final_locals { self.patch.add_statement( location, @@ -313,9 +314,8 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { StatementKind::Assign(box (ref place, Rvalue::Aggregate(_, ref operands))) => { if let Some(local) = place.as_local() - && self.all_dead_locals.contains(local) + && let Some(final_locals) = &self.fragments[local] { - let final_locals = &self.fragments[local]; for &(projection, fl) in final_locals { let &[PlaceElem::Field(index, _)] = projection else { bug!() }; let index = index.as_usize(); @@ -330,6 +330,48 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { } } + StatementKind::Assign(box (ref place, Rvalue::Use(Operand::Constant(_)))) => { + if let Some(local) = place.as_local() + && let Some(final_locals) = &self.fragments[local] + { + for &(projection, fl) in final_locals { + let rvalue = Rvalue::Use(Operand::Move(place.project_deeper(projection, self.tcx))); + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((fl.into(), rvalue))), + ); + } + self.all_dead_locals.remove(local); + return; + } + } + + StatementKind::Assign(box (ref lhs, Rvalue::Use(ref op))) => { + let (rplace, copy) = match op { + Operand::Copy(rplace) => (rplace, true), + Operand::Move(rplace) => (rplace, false), + Operand::Constant(_) => bug!(), + }; + if let Some(local) = lhs.as_local() + && let Some(final_locals) = &self.fragments[local] + { + for &(projection, fl) in final_locals { + let rplace = rplace.project_deeper(projection, self.tcx); + let rvalue = if copy { + Rvalue::Use(Operand::Copy(rplace)) + } else { + Rvalue::Use(Operand::Move(rplace)) + }; + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((fl.into(), rvalue))), + ); + } + statement.make_nop(); + return; + } + } + _ => {} } self.super_statement(statement, location) @@ -348,9 +390,8 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { VarDebugInfoContents::Place(ref mut place) => { if let Some(repl) = self.replace_place(place.as_ref()) { *place = repl; - } else if self.all_dead_locals.contains(place.local) { + } else if let Some(fragments) = self.gather_debug_info_fragments(place.as_ref()) { let ty = place.ty(self.local_decls, self.tcx).ty; - let fragments = self.gather_debug_info_fragments(place.as_ref()); var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments }; } } @@ -361,8 +402,9 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { if let Some(repl) = self.replace_place(fragment.contents.as_ref()) { fragment.contents = repl; true - } else if self.all_dead_locals.contains(fragment.contents.local) { - let frg = self.gather_debug_info_fragments(fragment.contents.as_ref()); + } else if let Some(frg) = + self.gather_debug_info_fragments(fragment.contents.as_ref()) + { new_fragments.extend(frg.into_iter().map(|mut f| { f.projection.splice(0..0, fragment.projection.iter().copied()); f diff --git a/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff b/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff new file mode 100644 index 00000000000..72610de8eaf --- /dev/null +++ b/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff @@ -0,0 +1,48 @@ +- // MIR for `copies` before ScalarReplacementOfAggregates ++ // MIR for `copies` after ScalarReplacementOfAggregates + + fn copies(_1: Foo) -> () { + debug x => _1; // in scope 0 at $DIR/sroa.rs:+0:11: +0:12 + let mut _0: (); // return place in scope 0 at $DIR/sroa.rs:+0:19: +0:19 + let _2: Foo; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ let _5: u8; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ let _6: &str; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 + scope 1 { +- debug y => _2; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10 ++ debug y => Foo{ .0 => _5, .2 => _6, }; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10 + let _3: u8; // in scope 1 at $DIR/sroa.rs:+2:9: +2:10 + scope 2 { + debug t => _3; // in scope 2 at $DIR/sroa.rs:+2:9: +2:10 + let _4: &str; // in scope 2 at $DIR/sroa.rs:+3:9: +3:10 + scope 3 { + debug u => _4; // in scope 3 at $DIR/sroa.rs:+3:9: +3:10 + } + } + } + + bb0: { +- StorageLive(_2); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 +- _2 = _1; // scope 0 at $DIR/sroa.rs:+1:13: +1:14 ++ StorageLive(_5); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ StorageLive(_6); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ nop; // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ _5 = (_1.0: u8); // scope 0 at $DIR/sroa.rs:+1:13: +1:14 ++ _6 = (_1.2: &str); // scope 0 at $DIR/sroa.rs:+1:13: +1:14 ++ nop; // scope 0 at $DIR/sroa.rs:+1:13: +1:14 + StorageLive(_3); // scope 1 at $DIR/sroa.rs:+2:9: +2:10 +- _3 = (_2.0: u8); // scope 1 at $DIR/sroa.rs:+2:13: +2:16 ++ _3 = _5; // scope 1 at $DIR/sroa.rs:+2:13: +2:16 + StorageLive(_4); // scope 2 at $DIR/sroa.rs:+3:9: +3:10 +- _4 = (_2.2: &str); // scope 2 at $DIR/sroa.rs:+3:13: +3:16 ++ _4 = _6; // scope 2 at $DIR/sroa.rs:+3:13: +3:16 + _0 = const (); // scope 0 at $DIR/sroa.rs:+0:19: +4:2 + StorageDead(_4); // scope 2 at $DIR/sroa.rs:+4:1: +4:2 + StorageDead(_3); // scope 1 at $DIR/sroa.rs:+4:1: +4:2 +- StorageDead(_2); // scope 0 at $DIR/sroa.rs:+4:1: +4:2 ++ StorageDead(_5); // scope 0 at $DIR/sroa.rs:+4:1: +4:2 ++ StorageDead(_6); // scope 0 at $DIR/sroa.rs:+4:1: +4:2 ++ nop; // scope 0 at $DIR/sroa.rs:+4:1: +4:2 + return; // scope 0 at $DIR/sroa.rs:+4:2: +4:2 + } + } + diff --git a/tests/mir-opt/sroa.escaping.ScalarReplacementOfAggregates.diff b/tests/mir-opt/sroa.escaping.ScalarReplacementOfAggregates.diff index b01fb6fc915..ea7f5007224 100644 --- a/tests/mir-opt/sroa.escaping.ScalarReplacementOfAggregates.diff +++ b/tests/mir-opt/sroa.escaping.ScalarReplacementOfAggregates.diff @@ -17,7 +17,7 @@ StorageLive(_5); // scope 0 at $DIR/sroa.rs:+2:34: +2:37 _5 = g() -> bb1; // scope 0 at $DIR/sroa.rs:+2:34: +2:37 // mir::Constant - // + span: $DIR/sroa.rs:78:34: 78:35 + // + span: $DIR/sroa.rs:73:34: 73:35 // + literal: Const { ty: fn() -> u32 {g}, val: Value() } } @@ -28,7 +28,7 @@ _2 = &raw const (*_3); // scope 0 at $DIR/sroa.rs:+2:7: +2:41 _1 = f(move _2) -> bb2; // scope 0 at $DIR/sroa.rs:+2:5: +2:42 // mir::Constant - // + span: $DIR/sroa.rs:78:5: 78:6 + // + span: $DIR/sroa.rs:73:5: 73:6 // + literal: Const { ty: fn(*const u32) {f}, val: Value() } } diff --git a/tests/mir-opt/sroa.flat.ScalarReplacementOfAggregates.diff b/tests/mir-opt/sroa.flat.ScalarReplacementOfAggregates.diff index 338ce262f1e..69631fc0213 100644 --- a/tests/mir-opt/sroa.flat.ScalarReplacementOfAggregates.diff +++ b/tests/mir-opt/sroa.flat.ScalarReplacementOfAggregates.diff @@ -45,7 +45,7 @@ + _9 = move _6; // scope 0 at $DIR/sroa.rs:+1:30: +1:70 + _10 = const "a"; // scope 0 at $DIR/sroa.rs:+1:30: +1:70 // mir::Constant - // + span: $DIR/sroa.rs:57:52: 57:55 + // + span: $DIR/sroa.rs:53:52: 53:55 // + literal: Const { ty: &str, val: Value(Slice(..)) } + _11 = move _7; // scope 0 at $DIR/sroa.rs:+1:30: +1:70 + nop; // scope 0 at $DIR/sroa.rs:+1:30: +1:70 diff --git a/tests/mir-opt/sroa.ref_copies.ScalarReplacementOfAggregates.diff b/tests/mir-opt/sroa.ref_copies.ScalarReplacementOfAggregates.diff new file mode 100644 index 00000000000..1a561a9edde --- /dev/null +++ b/tests/mir-opt/sroa.ref_copies.ScalarReplacementOfAggregates.diff @@ -0,0 +1,48 @@ +- // MIR for `ref_copies` before ScalarReplacementOfAggregates ++ // MIR for `ref_copies` after ScalarReplacementOfAggregates + + fn ref_copies(_1: &Foo) -> () { + debug x => _1; // in scope 0 at $DIR/sroa.rs:+0:15: +0:16 + let mut _0: (); // return place in scope 0 at $DIR/sroa.rs:+0:24: +0:24 + let _2: Foo; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ let _5: u8; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ let _6: &str; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 + scope 1 { +- debug y => _2; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10 ++ debug y => Foo{ .0 => _5, .2 => _6, }; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10 + let _3: u8; // in scope 1 at $DIR/sroa.rs:+2:9: +2:10 + scope 2 { + debug t => _3; // in scope 2 at $DIR/sroa.rs:+2:9: +2:10 + let _4: &str; // in scope 2 at $DIR/sroa.rs:+3:9: +3:10 + scope 3 { + debug u => _4; // in scope 3 at $DIR/sroa.rs:+3:9: +3:10 + } + } + } + + bb0: { +- StorageLive(_2); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 +- _2 = (*_1); // scope 0 at $DIR/sroa.rs:+1:13: +1:15 ++ StorageLive(_5); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ StorageLive(_6); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ nop; // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ _5 = ((*_1).0: u8); // scope 0 at $DIR/sroa.rs:+1:13: +1:15 ++ _6 = ((*_1).2: &str); // scope 0 at $DIR/sroa.rs:+1:13: +1:15 ++ nop; // scope 0 at $DIR/sroa.rs:+1:13: +1:15 + StorageLive(_3); // scope 1 at $DIR/sroa.rs:+2:9: +2:10 +- _3 = (_2.0: u8); // scope 1 at $DIR/sroa.rs:+2:13: +2:16 ++ _3 = _5; // scope 1 at $DIR/sroa.rs:+2:13: +2:16 + StorageLive(_4); // scope 2 at $DIR/sroa.rs:+3:9: +3:10 +- _4 = (_2.2: &str); // scope 2 at $DIR/sroa.rs:+3:13: +3:16 ++ _4 = _6; // scope 2 at $DIR/sroa.rs:+3:13: +3:16 + _0 = const (); // scope 0 at $DIR/sroa.rs:+0:24: +4:2 + StorageDead(_4); // scope 2 at $DIR/sroa.rs:+4:1: +4:2 + StorageDead(_3); // scope 1 at $DIR/sroa.rs:+4:1: +4:2 +- StorageDead(_2); // scope 0 at $DIR/sroa.rs:+4:1: +4:2 ++ StorageDead(_5); // scope 0 at $DIR/sroa.rs:+4:1: +4:2 ++ StorageDead(_6); // scope 0 at $DIR/sroa.rs:+4:1: +4:2 ++ nop; // scope 0 at $DIR/sroa.rs:+4:1: +4:2 + return; // scope 0 at $DIR/sroa.rs:+4:2: +4:2 + } + } + diff --git a/tests/mir-opt/sroa.rs b/tests/mir-opt/sroa.rs index ff8deb40d7d..b80f61600c2 100644 --- a/tests/mir-opt/sroa.rs +++ b/tests/mir-opt/sroa.rs @@ -12,17 +12,14 @@ impl Drop for Tag { fn drop(&mut self) {} } -// EMIT_MIR sroa.dropping.ScalarReplacementOfAggregates.diff pub fn dropping() { S(Tag(0), Tag(1), Tag(2)).1; } -// EMIT_MIR sroa.enums.ScalarReplacementOfAggregates.diff pub fn enums(a: usize) -> usize { if let Some(a) = Some(a) { a } else { 0 } } -// EMIT_MIR sroa.structs.ScalarReplacementOfAggregates.diff pub fn structs(a: f32) -> f32 { struct U { _foo: usize, @@ -32,7 +29,6 @@ pub fn structs(a: f32) -> f32 { U { _foo: 0, a }.a } -// EMIT_MIR sroa.unions.ScalarReplacementOfAggregates.diff pub fn unions(a: f32) -> u32 { union Repr { f: f32, @@ -41,6 +37,7 @@ pub fn unions(a: f32) -> u32 { unsafe { Repr { f: a }.u } } +#[derive(Copy, Clone)] struct Foo { a: u8, b: (), @@ -52,7 +49,6 @@ fn g() -> u32 { 3 } -// EMIT_MIR sroa.flat.ScalarReplacementOfAggregates.diff pub fn flat() { let Foo { a, b, c, d } = Foo { a: 5, b: (), c: "a", d: Some(-4) }; let _ = a; @@ -72,12 +68,23 @@ fn f(a: *const u32) { println!("{}", unsafe { *a.add(2) }); } -// EMIT_MIR sroa.escaping.ScalarReplacementOfAggregates.diff pub fn escaping() { // Verify this struct is not flattened. f(&Escaping { a: 1, b: 2, c: g() }.a); } +fn copies(x: Foo) { + let y = x; + let t = y.a; + let u = y.c; +} + +fn ref_copies(x: &Foo) { + let y = *x; + let t = y.a; + let u = y.c; +} + fn main() { dropping(); enums(5); @@ -85,4 +92,15 @@ fn main() { unions(5.); flat(); escaping(); + copies(Foo { a: 5, b: (), c: "a", d: Some(-4) }); + ref_copies(&Foo { a: 5, b: (), c: "a", d: Some(-4) }); } + +// EMIT_MIR sroa.dropping.ScalarReplacementOfAggregates.diff +// EMIT_MIR sroa.enums.ScalarReplacementOfAggregates.diff +// EMIT_MIR sroa.structs.ScalarReplacementOfAggregates.diff +// EMIT_MIR sroa.unions.ScalarReplacementOfAggregates.diff +// EMIT_MIR sroa.flat.ScalarReplacementOfAggregates.diff +// EMIT_MIR sroa.escaping.ScalarReplacementOfAggregates.diff +// EMIT_MIR sroa.copies.ScalarReplacementOfAggregates.diff +// EMIT_MIR sroa.ref_copies.ScalarReplacementOfAggregates.diff