Auto merge of #78681 - m-ou-se:binary-heap-retain, r=Amanieu

Improve rebuilding behaviour of BinaryHeap::retain.

This changes `BinaryHeap::retain` so that it no longer always rebuilds the entire heap, but only rebuilds the parts that actually need it.

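This makes use of the fact that `retain` gives out `&T`s and not `&mut T`s: the closure cannot modify the elements it inspects, so the elements before the first removed one keep their positions and still form a valid heap.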

Retaining every element or removing only elements at the end results in no rebuilding at all. Retaining most elements results in only reordering the elements that got moved (those after the first removed element), using the same logic as was already used for `append`.

cc `@KodrAus` `@sfackler` - We briefly discussed this possibility in the meeting last week while we talked about stabilization of this function (#71503).
This commit is contained in:
bors 2021-04-23 00:07:19 +00:00
commit f4a8cf0a00
2 changed files with 69 additions and 35 deletions

View File

@ -652,6 +652,43 @@ impl<T: Ord> BinaryHeap<T> {
unsafe { self.sift_up(start, pos) };
}
/// Restores the heap property after the tail `data[start..]` has changed,
/// relying on `data[0..start]` still being a proper heap.
///
/// Either rebuilds the whole heap or sifts up each tail element in turn,
/// whichever the cost estimate below predicts to be cheaper.
fn rebuild_tail(&mut self, start: usize) {
    let total = self.len();
    if start == total {
        // Empty tail: nothing to restore.
        return;
    }

    #[inline(always)]
    fn log2_fast(x: usize) -> usize {
        (usize::BITS - x.leading_zeros() - 1) as usize
    }

    let tail_len = total - start;

    // Cost model: a full `rebuild` is O(total) operations and roughly
    // 2 * total comparisons in the worst case, whereas calling `sift_up`
    // for every tail element is O(tail_len * log(start)) operations and
    // roughly tail_len * log_2(start) comparisons in the worst case
    // (assuming start >= tail_len). For larger heaps this reasoning no
    // longer holds and the crossover factor was determined empirically.
    //
    // The `||` short-circuit matters: `log2_fast` is only reached when
    // start >= tail_len >= 1, so it is never called with zero.
    let better_to_rebuild = start < tail_len || {
        let sift_cost_per_item = if total <= 2048 { log2_fast(start) } else { 11 };
        2 * total < tail_len * sift_cost_per_item
    };

    if !better_to_rebuild {
        for i in start..total {
            // SAFETY: `i` ranges over `start..total`, and `total` is `self.len()`,
            // so the index is always less than the heap length.
            unsafe { self.sift_up(0, i) };
        }
        return;
    }
    self.rebuild();
}
fn rebuild(&mut self) {
let mut n = self.len() / 2;
while n > 0 {
@ -689,37 +726,11 @@ impl<T: Ord> BinaryHeap<T> {
swap(self, other);
}
if other.is_empty() {
return;
}
let start = self.data.len();
#[inline(always)]
fn log2_fast(x: usize) -> usize {
(usize::BITS - x.leading_zeros() - 1) as usize
}
self.data.append(&mut other.data);
// `rebuild` takes O(len1 + len2) operations
// and about 2 * (len1 + len2) comparisons in the worst case
// while `extend` takes O(len2 * log(len1)) operations
// and about 1 * len2 * log_2(len1) comparisons in the worst case,
// assuming len1 >= len2. For larger heaps, the crossover point
// no longer follows this reasoning and was determined empirically.
#[inline]
fn better_to_rebuild(len1: usize, len2: usize) -> bool {
let tot_len = len1 + len2;
if tot_len <= 2048 {
2 * tot_len < len2 * log2_fast(len1)
} else {
2 * tot_len < len2 * 11
}
}
if better_to_rebuild(self.len(), other.len()) {
self.data.append(&mut other.data);
self.rebuild();
} else {
self.extend(other.drain());
}
self.rebuild_tail(start);
}
/// Returns an iterator which retrieves elements in heap order.
@ -770,12 +781,22 @@ impl<T: Ord> BinaryHeap<T> {
/// assert_eq!(heap.into_sorted_vec(), [-10, 2, 4])
/// ```
#[unstable(feature = "binary_heap_retain", issue = "71503")]
pub fn retain<F>(&mut self, f: F)
pub fn retain<F>(&mut self, mut f: F)
where
F: FnMut(&T) -> bool,
{
self.data.retain(f);
self.rebuild();
let mut first_removed = self.len();
let mut i = 0;
self.data.retain(|e| {
let keep = f(e);
if !keep && i < first_removed {
first_removed = i;
}
i += 1;
keep
});
// data[0..first_removed] is untouched, so we only need to rebuild the tail:
self.rebuild_tail(first_removed);
}
}

View File

@ -386,10 +386,23 @@ fn assert_covariance() {
#[test]
fn test_retain() {
let mut a = BinaryHeap::from(vec![-10, -5, 1, 2, 4, 13]);
a.retain(|x| x % 2 == 0);
let mut a = BinaryHeap::from(vec![100, 10, 50, 1, 2, 20, 30]);
a.retain(|&x| x != 2);
assert_eq!(a.into_sorted_vec(), [-10, 2, 4])
// Check that 20 moved into 10's place.
assert_eq!(a.clone().into_vec(), [100, 20, 50, 1, 10, 30]);
a.retain(|_| true);
assert_eq!(a.clone().into_vec(), [100, 20, 50, 1, 10, 30]);
a.retain(|&x| x < 50);
assert_eq!(a.clone().into_vec(), [30, 20, 10, 1]);
a.retain(|_| false);
assert!(a.is_empty());
}
// old binaryheap failed this test