From bc7cd2f351ab35e0830563858e827a2d397d176d Mon Sep 17 00:00:00 2001 From: Jakob Degen Date: Sun, 5 Jun 2022 16:22:54 -0700 Subject: [PATCH] `BitSet` perf improvements This commit makes two changes: 1. Changes `MaybeLiveLocals` to use `ChunkedBitSet` 2. Overrides the `fold` method for the iterator for `ChunkedBitSet` --- compiler/rustc_index/src/bit_set.rs | 83 +++++++++++++++++-- compiler/rustc_index/src/bit_set/tests.rs | 70 +++++++++++++--- .../rustc_mir_dataflow/src/impls/liveness.rs | 4 +- compiler/rustc_mir_dataflow/src/rustc_peek.rs | 4 +- compiler/rustc_mir_transform/src/generator.rs | 3 +- 5 files changed, 141 insertions(+), 23 deletions(-) diff --git a/compiler/rustc_index/src/bit_set.rs b/compiler/rustc_index/src/bit_set.rs index 059755a743b..f2eaef1e149 100644 --- a/compiler/rustc_index/src/bit_set.rs +++ b/compiler/rustc_index/src/bit_set.rs @@ -681,6 +681,48 @@ impl BitRelations> for ChunkedBitSet { } } +impl BitRelations> for BitSet { + fn union(&mut self, other: &ChunkedBitSet) -> bool { + sequential_update(|elem| self.insert(elem), other.iter()) + } + + fn subtract(&mut self, _other: &ChunkedBitSet) -> bool { + unimplemented!("implement if/when necessary"); + } + + fn intersect(&mut self, other: &ChunkedBitSet) -> bool { + assert_eq!(self.domain_size(), other.domain_size); + let mut changed = false; + for (i, chunk) in other.chunks.iter().enumerate() { + let mut words = &mut self.words[i * CHUNK_WORDS..]; + if words.len() > CHUNK_WORDS { + words = &mut words[..CHUNK_WORDS]; + } + match chunk { + Chunk::Zeros(..) => { + for word in words { + if *word != 0 { + changed = true; + *word = 0; + } + } + } + Chunk::Ones(..) => (), + Chunk::Mixed(_, _, data) => { + for (i, word) in words.iter_mut().enumerate() { + let new_val = *word & data[i]; + if new_val != *word { + changed = true; + *word = new_val; + } + } + } + } + } + changed + } +} + impl Clone for ChunkedBitSet { fn clone(&self) -> Self { ChunkedBitSet { @@ -743,6 +785,41 @@ impl<'a, T: Idx> Iterator for ChunkedBitIter<'a, T> { } None } + + fn fold(mut self, mut init: B, mut f: F) -> B + where + F: FnMut(B, Self::Item) -> B, + { + // If `next` has already been called, we may not be at the start of a chunk, so we first + // advance the iterator to the start of the next chunk, before proceeding in chunk sized + // steps. + while self.index % CHUNK_BITS != 0 { + let Some(item) = self.next() else { + return init + }; + init = f(init, item); + } + let start_chunk = self.index / CHUNK_BITS; + let chunks = &self.bitset.chunks[start_chunk..]; + for (i, chunk) in chunks.iter().enumerate() { + let base = (start_chunk + i) * CHUNK_BITS; + match chunk { + Chunk::Zeros(_) => (), + Chunk::Ones(limit) => { + for j in 0..(*limit as usize) { + init = f(init, T::new(base + j)); + } + } + Chunk::Mixed(_, _, words) => { + init = BitIter::new(&**words).fold(init, |val, mut item: T| { + item.increment_by(base); + f(val, item) + }); + } + } + } + init + } } impl Chunk { @@ -799,11 +876,7 @@ fn sequential_update( mut self_update: impl FnMut(T) -> bool, it: impl Iterator, ) -> bool { - let mut changed = false; - for elem in it { - changed |= self_update(elem); - } - changed + it.fold(false, |changed, elem| self_update(elem) | changed) } // Optimization of intersection for SparseBitSet that's generic diff --git a/compiler/rustc_index/src/bit_set/tests.rs b/compiler/rustc_index/src/bit_set/tests.rs index cfc891e97a3..a58133e4aed 100644 --- a/compiler/rustc_index/src/bit_set/tests.rs +++ b/compiler/rustc_index/src/bit_set/tests.rs @@ -342,38 +342,82 @@ fn chunked_bitset() { b10000b.assert_valid(); } +fn with_elements_chunked(elements: &[usize], domain_size: usize) -> ChunkedBitSet { + let mut s = ChunkedBitSet::new_empty(domain_size); + for &e in elements { + assert!(s.insert(e)); + } + s +} + +fn with_elements_standard(elements: &[usize], domain_size: usize) -> BitSet { + let mut s = BitSet::new_empty(domain_size); + for &e in elements { + assert!(s.insert(e)); + } + s +} + +#[test] +fn chunked_bitset_into_bitset_operations() { + let a = vec![1, 5, 7, 11, 15, 2000, 3000]; + let b = vec![3, 4, 11, 3000, 4000]; + let aub = vec![1, 3, 4, 5, 7, 11, 15, 2000, 3000, 4000]; + let aib = vec![11, 3000]; + + let b = with_elements_chunked(&b, 9876); + + let mut union = with_elements_standard(&a, 9876); + assert!(union.union(&b)); + assert!(!union.union(&b)); + assert!(union.iter().eq(aub.iter().copied())); + + let mut intersection = with_elements_standard(&a, 9876); + assert!(intersection.intersect(&b)); + assert!(!intersection.intersect(&b)); + assert!(intersection.iter().eq(aib.iter().copied())); +} + #[test] fn chunked_bitset_iter() { - fn with_elements(elements: &[usize], domain_size: usize) -> ChunkedBitSet { - let mut s = ChunkedBitSet::new_empty(domain_size); - for &e in elements { - s.insert(e); + fn check_iter(bit: &ChunkedBitSet, vec: &Vec) { + // Test collecting via both `.next()` and `.fold()` calls, to make sure both are correct + let mut collect_next = Vec::new(); + let mut bit_iter = bit.iter(); + while let Some(item) = bit_iter.next() { + collect_next.push(item); } - s + assert_eq!(vec, &collect_next); + + let collect_fold = bit.iter().fold(Vec::new(), |mut v, item| { + v.push(item); + v + }); + assert_eq!(vec, &collect_fold); } // Empty let vec: Vec = Vec::new(); - let bit = with_elements(&vec, 9000); - assert_eq!(vec, bit.iter().collect::>()); + let bit = with_elements_chunked(&vec, 9000); + check_iter(&bit, &vec); // Filled let n = 10000; let vec: Vec = (0..n).collect(); - let bit = with_elements(&vec, n); - assert_eq!(vec, bit.iter().collect::>()); + let bit = with_elements_chunked(&vec, n); + check_iter(&bit, &vec); // Filled with trailing zeros let n = 10000; let vec: Vec = (0..n).collect(); - let bit = with_elements(&vec, 2 * n); - assert_eq!(vec, bit.iter().collect::>()); + let bit = with_elements_chunked(&vec, 2 * n); + check_iter(&bit, &vec); // Mixed let n = 12345; let vec: Vec = vec![0, 1, 2, 2010, 2047, 2099, 6000, 6002, 6004]; - let bit = with_elements(&vec, n); - assert_eq!(vec, bit.iter().collect::>()); + let bit = with_elements_chunked(&vec, n); + check_iter(&bit, &vec); } #[test] diff --git a/compiler/rustc_mir_dataflow/src/impls/liveness.rs b/compiler/rustc_mir_dataflow/src/impls/liveness.rs index 7076fbe1bdb..51c59e42101 100644 --- a/compiler/rustc_mir_dataflow/src/impls/liveness.rs +++ b/compiler/rustc_mir_dataflow/src/impls/liveness.rs @@ -30,14 +30,14 @@ impl MaybeLiveLocals { } impl<'tcx> AnalysisDomain<'tcx> for MaybeLiveLocals { - type Domain = BitSet; + type Domain = ChunkedBitSet; type Direction = Backward; const NAME: &'static str = "liveness"; fn bottom_value(&self, body: &mir::Body<'tcx>) -> Self::Domain { // bottom = not live - BitSet::new_empty(body.local_decls.len()) + ChunkedBitSet::new_empty(body.local_decls.len()) } fn initialize_start_block(&self, _: &mir::Body<'tcx>, _: &mut Self::Domain) { diff --git a/compiler/rustc_mir_dataflow/src/rustc_peek.rs b/compiler/rustc_mir_dataflow/src/rustc_peek.rs index 2f884887ad9..e1df482786f 100644 --- a/compiler/rustc_mir_dataflow/src/rustc_peek.rs +++ b/compiler/rustc_mir_dataflow/src/rustc_peek.rs @@ -1,7 +1,7 @@ use rustc_span::symbol::sym; use rustc_span::Span; -use rustc_index::bit_set::BitSet; +use rustc_index::bit_set::ChunkedBitSet; use rustc_middle::mir::MirPass; use rustc_middle::mir::{self, Body, Local, Location}; use rustc_middle::ty::{self, Ty, TyCtxt}; @@ -271,7 +271,7 @@ impl<'tcx> RustcPeekAt<'tcx> for MaybeLiveLocals { &self, tcx: TyCtxt<'tcx>, place: mir::Place<'tcx>, - flow_state: &BitSet, + flow_state: &ChunkedBitSet, call: PeekCall, ) { info!(?place, "peek_at"); diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs index 9eb77f60213..9b7354da841 100644 --- a/compiler/rustc_mir_transform/src/generator.rs +++ b/compiler/rustc_mir_transform/src/generator.rs @@ -495,7 +495,8 @@ fn locals_live_across_suspend_points<'tcx>( let loc = Location { block, statement_index: data.statements.len() }; liveness.seek_to_block_end(block); - let mut live_locals = liveness.get().clone(); + let mut live_locals: BitSet<_> = BitSet::new_empty(body.local_decls.len()); + live_locals.union(liveness.get()); if !movable { // The `liveness` variable contains the liveness of MIR locals ignoring borrows.