improve worst-case performance of BTreeSet difference and intersection

This commit is contained in:
Stein Somers 2019-03-13 23:01:12 +01:00
parent 4fec737f9a
commit f5fee8fd7d
3 changed files with 352 additions and 123 deletions

View File

@ -3,59 +3,49 @@ use std::collections::BTreeSet;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use test::{black_box, Bencher}; use test::{black_box, Bencher};
fn random(n1: u32, n2: u32) -> [BTreeSet<usize>; 2] { fn random(n: usize) -> BTreeSet<usize> {
let mut rng = thread_rng(); let mut rng = thread_rng();
let mut set1 = BTreeSet::new(); let mut set = BTreeSet::new();
let mut set2 = BTreeSet::new(); while set.len() < n {
for _ in 0..n1 { set.insert(rng.gen());
let i = rng.gen::<usize>();
set1.insert(i);
} }
for _ in 0..n2 { assert_eq!(set.len(), n);
let i = rng.gen::<usize>(); set
set2.insert(i);
}
[set1, set2]
} }
fn staggered(n1: u32, n2: u32) -> [BTreeSet<u32>; 2] { fn neg(n: usize) -> BTreeSet<i32> {
let mut even = BTreeSet::new(); let mut set = BTreeSet::new();
let mut odd = BTreeSet::new(); for i in -(n as i32)..=-1 {
for i in 0..n1 { set.insert(i);
even.insert(i * 2);
} }
for i in 0..n2 { assert_eq!(set.len(), n);
odd.insert(i * 2 + 1); set
}
[even, odd]
} }
fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] { fn pos(n: usize) -> BTreeSet<i32> {
let mut neg = BTreeSet::new(); let mut set = BTreeSet::new();
let mut pos = BTreeSet::new(); for i in 1..=(n as i32) {
for i in -(n1 as i32)..=-1 { set.insert(i);
neg.insert(i);
} }
for i in 1..=(n2 as i32) { assert_eq!(set.len(), n);
pos.insert(i); set
}
[neg, pos]
} }
fn pos_vs_neg(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
let mut neg = BTreeSet::new(); fn stagger(n1: usize, factor: usize) -> [BTreeSet<u32>; 2] {
let mut pos = BTreeSet::new(); let n2 = n1 * factor;
for i in -(n1 as i32)..=-1 { let mut sets = [BTreeSet::new(), BTreeSet::new()];
neg.insert(i); for i in 0..(n1 + n2) {
let b = i % (factor + 1) != 0;
sets[b as usize].insert(i as u32);
} }
for i in 1..=(n2 as i32) { assert_eq!(sets[0].len(), n1);
pos.insert(i); assert_eq!(sets[1].len(), n2);
} sets
[pos, neg]
} }
macro_rules! set_intersection_bench { macro_rules! set_bench {
($name: ident, $sets: expr) => { ($name: ident, $set_func: ident, $result_func: ident, $sets: expr) => {
#[bench] #[bench]
pub fn $name(b: &mut Bencher) { pub fn $name(b: &mut Bencher) {
// setup // setup
@ -63,26 +53,36 @@ macro_rules! set_intersection_bench {
// measure // measure
b.iter(|| { b.iter(|| {
let x = sets[0].intersection(&sets[1]).count(); let x = sets[0].$set_func(&sets[1]).$result_func();
black_box(x); black_box(x);
}) })
} }
}; };
} }
set_intersection_bench! {intersect_random_100, random(100, 100)} set_bench! {intersection_100_neg_vs_100_pos, intersection, count, [neg(100), pos(100)]}
set_intersection_bench! {intersect_random_10k, random(10_000, 10_000)} set_bench! {intersection_100_neg_vs_10k_pos, intersection, count, [neg(100), pos(10_000)]}
set_intersection_bench! {intersect_random_10_vs_10k, random(10, 10_000)} set_bench! {intersection_100_pos_vs_100_neg, intersection, count, [pos(100), neg(100)]}
set_intersection_bench! {intersect_random_10k_vs_10, random(10_000, 10)} set_bench! {intersection_100_pos_vs_10k_neg, intersection, count, [pos(100), neg(10_000)]}
set_intersection_bench! {intersect_staggered_100, staggered(100, 100)} set_bench! {intersection_10k_neg_vs_100_pos, intersection, count, [neg(10_000), pos(100)]}
set_intersection_bench! {intersect_staggered_10k, staggered(10_000, 10_000)} set_bench! {intersection_10k_neg_vs_10k_pos, intersection, count, [neg(10_000), pos(10_000)]}
set_intersection_bench! {intersect_staggered_10_vs_10k, staggered(10, 10_000)} set_bench! {intersection_10k_pos_vs_100_neg, intersection, count, [pos(10_000), neg(100)]}
set_intersection_bench! {intersect_staggered_10k_vs_10, staggered(10_000, 10)} set_bench! {intersection_10k_pos_vs_10k_neg, intersection, count, [pos(10_000), neg(10_000)]}
set_intersection_bench! {intersect_neg_vs_pos_100, neg_vs_pos(100, 100)} set_bench! {intersection_random_100_vs_100, intersection, count, [random(100), random(100)]}
set_intersection_bench! {intersect_neg_vs_pos_10k, neg_vs_pos(10_000, 10_000)} set_bench! {intersection_random_100_vs_10k, intersection, count, [random(100), random(10_000)]}
set_intersection_bench! {intersect_neg_vs_pos_10_vs_10k,neg_vs_pos(10, 10_000)} set_bench! {intersection_random_10k_vs_100, intersection, count, [random(10_000), random(100)]}
set_intersection_bench! {intersect_neg_vs_pos_10k_vs_10,neg_vs_pos(10_000, 10)} set_bench! {intersection_random_10k_vs_10k, intersection, count, [random(10_000), random(10_000)]}
set_intersection_bench! {intersect_pos_vs_neg_100, pos_vs_neg(100, 100)} set_bench! {intersection_staggered_100_vs_100, intersection, count, stagger(100, 1)}
set_intersection_bench! {intersect_pos_vs_neg_10k, pos_vs_neg(10_000, 10_000)} set_bench! {intersection_staggered_10k_vs_10k, intersection, count, stagger(10_000, 1)}
set_intersection_bench! {intersect_pos_vs_neg_10_vs_10k,pos_vs_neg(10, 10_000)} set_bench! {intersection_staggered_100_vs_10k, intersection, count, stagger(100, 100)}
set_intersection_bench! {intersect_pos_vs_neg_10k_vs_10,pos_vs_neg(10_000, 10)} set_bench! {difference_random_100_vs_100, difference, count, [random(100), random(100)]}
set_bench! {difference_random_100_vs_10k, difference, count, [random(100), random(10_000)]}
set_bench! {difference_random_10k_vs_100, difference, count, [random(10_000), random(100)]}
set_bench! {difference_random_10k_vs_10k, difference, count, [random(10_000), random(10_000)]}
// `stagger(n1, factor)` builds two interleaved sets holding `n1` and
// `n1 * factor` elements, so `stagger(100, 100)` is the 100-vs-10k case.
set_bench! {difference_staggered_100_vs_100, difference, count, stagger(100, 1)}
set_bench! {difference_staggered_10k_vs_10k, difference, count, stagger(10_000, 1)}
set_bench! {difference_staggered_100_vs_10k, difference, count, stagger(100, 100)}
// `clone` is passed as `$result_func` because the macro always calls a method
// on the result, and `is_subset` yields a plain `bool` rather than an iterator.
set_bench! {is_subset_100_vs_100, is_subset, clone, [pos(100), pos(100)]}
set_bench! {is_subset_100_vs_10k, is_subset, clone, [pos(100), pos(10_000)]}
set_bench! {is_subset_10k_vs_100, is_subset, clone, [pos(10_000), pos(100)]}
set_bench! {is_subset_10k_vs_10k, is_subset, clone, [pos(10_000), pos(10_000)]}

View File

@ -3,7 +3,7 @@
use core::borrow::Borrow; use core::borrow::Borrow;
use core::cmp::Ordering::{self, Less, Greater, Equal}; use core::cmp::Ordering::{self, Less, Greater, Equal};
use core::cmp::{min, max}; use core::cmp::max;
use core::fmt::{self, Debug}; use core::fmt::{self, Debug};
use core::iter::{Peekable, FromIterator, FusedIterator}; use core::iter::{Peekable, FromIterator, FusedIterator};
use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds}; use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds};
@ -118,17 +118,36 @@ pub struct Range<'a, T: 'a> {
/// [`difference`]: struct.BTreeSet.html#method.difference /// [`difference`]: struct.BTreeSet.html#method.difference
#[stable(feature = "rust1", since = "1.0.0")] #[stable(feature = "rust1", since = "1.0.0")]
pub struct Difference<'a, T: 'a> { pub struct Difference<'a, T: 'a> {
a: Peekable<Iter<'a, T>>, inner: DifferenceInner<'a, T>,
b: Peekable<Iter<'a, T>>, }
enum DifferenceInner<'a, T: 'a> {
Stitch {
self_iter: Iter<'a, T>,
other_iter: Peekable<Iter<'a, T>>,
},
Search {
self_iter: Iter<'a, T>,
other_set: &'a BTreeSet<T>,
},
} }
#[stable(feature = "collection_debug", since = "1.17.0")] #[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> { impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Difference") match &self.inner {
.field(&self.a) DifferenceInner::Stitch {
.field(&self.b) self_iter,
.finish() other_iter,
} => f
.debug_tuple("Difference")
.field(&self_iter)
.field(&other_iter)
.finish(),
DifferenceInner::Search {
self_iter,
other_set: _,
} => f.debug_tuple("Difference").field(&self_iter).finish(),
}
} }
} }
@ -164,17 +183,36 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
/// [`intersection`]: struct.BTreeSet.html#method.intersection /// [`intersection`]: struct.BTreeSet.html#method.intersection
#[stable(feature = "rust1", since = "1.0.0")] #[stable(feature = "rust1", since = "1.0.0")]
pub struct Intersection<'a, T: 'a> { pub struct Intersection<'a, T: 'a> {
a: Peekable<Iter<'a, T>>, inner: IntersectionInner<'a, T>,
b: Peekable<Iter<'a, T>>, }
enum IntersectionInner<'a, T: 'a> {
Stitch {
small_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets
other_iter: Iter<'a, T>,
},
Search {
small_iter: Iter<'a, T>,
large_set: &'a BTreeSet<T>,
},
} }
#[stable(feature = "collection_debug", since = "1.17.0")] #[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> { impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Intersection") match &self.inner {
.field(&self.a) IntersectionInner::Stitch {
.field(&self.b) small_iter,
.finish() other_iter,
} => f
.debug_tuple("Intersection")
.field(&small_iter)
.field(&other_iter)
.finish(),
IntersectionInner::Search {
small_iter,
large_set: _,
} => f.debug_tuple("Intersection").field(&small_iter).finish(),
}
} }
} }
@ -201,6 +239,14 @@ impl<T: fmt::Debug> fmt::Debug for Union<'_, T> {
} }
} }
// Size ratio used by the functions that compare two sets (`difference`,
// `intersection` and `is_subset`) to choose between two strategies:
// iterating both sets jointly, or iterating the small set and searching
// for each of its elements in the large set.
//
// It estimates the relative size at which searching performs better
// than iterating, based on the benchmarks in
// https://github.com/ssomers/rust_bench_btreeset_intersection.
//
// It's used to divide rather than multiply sizes, to rule out overflow,
// and it's a power of two to make that division cheap.
const ITER_PERFORMANCE_TIPPING_SIZE_DIFF: usize = 16;
impl<T: Ord> BTreeSet<T> { impl<T: Ord> BTreeSet<T> {
/// Makes a new `BTreeSet` with a reasonable choice of B. /// Makes a new `BTreeSet` with a reasonable choice of B.
/// ///
@ -268,9 +314,24 @@ impl<T: Ord> BTreeSet<T> {
/// ``` /// ```
#[stable(feature = "rust1", since = "1.0.0")] #[stable(feature = "rust1", since = "1.0.0")]
pub fn difference<'a>(&'a self, other: &'a BTreeSet<T>) -> Difference<'a, T> { pub fn difference<'a>(&'a self, other: &'a BTreeSet<T>) -> Difference<'a, T> {
Difference { if self.len() > other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
a: self.iter().peekable(), // Self is bigger than or not much smaller than other set.
b: other.iter().peekable(), // Iterate both sets jointly, spotting matches along the way.
Difference {
inner: DifferenceInner::Stitch {
self_iter: self.iter(),
other_iter: other.iter().peekable(),
},
}
} else {
// Self is much smaller than other set, or both sets are empty.
// Iterate the small set, searching for matches in the large set.
Difference {
inner: DifferenceInner::Search {
self_iter: self.iter(),
other_set: other,
},
}
} }
} }
@ -326,9 +387,29 @@ impl<T: Ord> BTreeSet<T> {
/// ``` /// ```
#[stable(feature = "rust1", since = "1.0.0")] #[stable(feature = "rust1", since = "1.0.0")]
pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T> { pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T> {
Intersection { let (small, other) = if self.len() <= other.len() {
a: self.iter().peekable(), (self, other)
b: other.iter().peekable(), } else {
(other, self)
};
if small.len() > other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
// Small set is not much smaller than other set.
// Iterate both sets jointly, spotting matches along the way.
Intersection {
inner: IntersectionInner::Stitch {
small_iter: small.iter(),
other_iter: other.iter(),
},
}
} else {
// Big difference in number of elements, or both sets are empty.
// Iterate the small set, searching for matches in the large set.
Intersection {
inner: IntersectionInner::Search {
small_iter: small.iter(),
large_set: other,
},
}
} }
} }
@ -462,28 +543,44 @@ impl<T: Ord> BTreeSet<T> {
/// ``` /// ```
#[stable(feature = "rust1", since = "1.0.0")] #[stable(feature = "rust1", since = "1.0.0")]
pub fn is_subset(&self, other: &BTreeSet<T>) -> bool { pub fn is_subset(&self, other: &BTreeSet<T>) -> bool {
// Stolen from TreeMap // Same result as self.difference(other).next().is_none()
let mut x = self.iter(); // but the 3 paths below are faster (in order: hugely, 20%, 5%).
let mut y = other.iter(); if self.len() > other.len() {
let mut a = x.next(); false
let mut b = y.next(); } else if self.len() > other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
while a.is_some() { // Self is not much smaller than other set.
if b.is_none() { // Stolen from TreeMap
return false; let mut x = self.iter();
let mut y = other.iter();
let mut a = x.next();
let mut b = y.next();
while a.is_some() {
if b.is_none() {
return false;
}
let a1 = a.unwrap();
let b1 = b.unwrap();
match b1.cmp(a1) {
Less => (),
Greater => return false,
Equal => a = x.next(),
}
b = y.next();
} }
true
let a1 = a.unwrap(); } else {
let b1 = b.unwrap(); // Big difference in number of elements, or both sets are empty.
// Iterate the small set, searching for matches in the large set.
match b1.cmp(a1) { for next in self {
Less => (), if !other.contains(next) {
Greater => return false, return false;
Equal => a = x.next(), }
} }
true
b = y.next();
} }
true
} }
/// Returns `true` if the set is a superset of another, /// Returns `true` if the set is a superset of another,
@ -1001,8 +1098,22 @@ fn cmp_opt<T: Ord>(x: Option<&T>, y: Option<&T>, short: Ordering, long: Ordering
impl<T> Clone for Difference<'_, T> { impl<T> Clone for Difference<'_, T> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Difference { Difference {
a: self.a.clone(), inner: match &self.inner {
b: self.b.clone(), DifferenceInner::Stitch {
self_iter,
other_iter,
} => DifferenceInner::Stitch {
self_iter: self_iter.clone(),
other_iter: other_iter.clone(),
},
DifferenceInner::Search {
self_iter,
other_set,
} => DifferenceInner::Search {
self_iter: self_iter.clone(),
other_set,
},
},
} }
} }
} }
@ -1011,24 +1122,52 @@ impl<'a, T: Ord> Iterator for Difference<'a, T> {
type Item = &'a T; type Item = &'a T;
fn next(&mut self) -> Option<&'a T> { fn next(&mut self) -> Option<&'a T> {
loop { match &mut self.inner {
match cmp_opt(self.a.peek(), self.b.peek(), Less, Less) { DifferenceInner::Stitch {
Less => return self.a.next(), self_iter,
Equal => { other_iter,
self.a.next(); } => {
self.b.next(); let mut self_next = self_iter.next()?;
} loop {
Greater => { match other_iter
self.b.next(); .peek()
.map_or(Less, |other_next| Ord::cmp(self_next, other_next))
{
Less => return Some(self_next),
Equal => {
self_next = self_iter.next()?;
other_iter.next();
}
Greater => {
other_iter.next();
}
}
} }
} }
DifferenceInner::Search {
self_iter,
other_set,
} => loop {
let self_next = self_iter.next()?;
if !other_set.contains(&self_next) {
return Some(self_next);
}
},
} }
} }
fn size_hint(&self) -> (usize, Option<usize>) { fn size_hint(&self) -> (usize, Option<usize>) {
let a_len = self.a.len(); let (self_len, other_len) = match &self.inner {
let b_len = self.b.len(); DifferenceInner::Stitch {
(a_len.saturating_sub(b_len), Some(a_len)) self_iter,
other_iter
} => (self_iter.len(), other_iter.len()),
DifferenceInner::Search {
self_iter,
other_set
} => (self_iter.len(), other_set.len()),
};
(self_len.saturating_sub(other_len), Some(self_len))
} }
} }
@ -1073,8 +1212,22 @@ impl<T: Ord> FusedIterator for SymmetricDifference<'_, T> {}
impl<T> Clone for Intersection<'_, T> { impl<T> Clone for Intersection<'_, T> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Intersection { Intersection {
a: self.a.clone(), inner: match &self.inner {
b: self.b.clone(), IntersectionInner::Stitch {
small_iter,
other_iter,
} => IntersectionInner::Stitch {
small_iter: small_iter.clone(),
other_iter: other_iter.clone(),
},
IntersectionInner::Search {
small_iter,
large_set,
} => IntersectionInner::Search {
small_iter: small_iter.clone(),
large_set,
},
},
} }
} }
} }
@ -1083,24 +1236,39 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
type Item = &'a T; type Item = &'a T;
fn next(&mut self) -> Option<&'a T> { fn next(&mut self) -> Option<&'a T> {
loop { match &mut self.inner {
match Ord::cmp(self.a.peek()?, self.b.peek()?) { IntersectionInner::Stitch {
Less => { small_iter,
self.a.next(); other_iter,
} } => {
Equal => { let mut small_next = small_iter.next()?;
self.b.next(); let mut other_next = other_iter.next()?;
return self.a.next(); loop {
} match Ord::cmp(small_next, other_next) {
Greater => { Less => small_next = small_iter.next()?,
self.b.next(); Greater => other_next = other_iter.next()?,
Equal => return Some(small_next),
}
} }
} }
IntersectionInner::Search {
small_iter,
large_set,
} => loop {
let small_next = small_iter.next()?;
if large_set.contains(&small_next) {
return Some(small_next);
}
},
} }
} }
fn size_hint(&self) -> (usize, Option<usize>) { fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(min(self.a.len(), self.b.len()))) let min_len = match &self.inner {
IntersectionInner::Stitch { small_iter, .. } => small_iter.len(),
IntersectionInner::Search { small_iter, .. } => small_iter.len(),
};
(0, Some(min_len))
} }
} }

View File

@ -69,6 +69,20 @@ fn test_intersection() {
check_intersection(&[11, 1, 3, 77, 103, 5, -5], check_intersection(&[11, 1, 3, 77, 103, 5, -5],
&[2, 11, 77, -9, -42, 5, 3], &[2, 11, 77, -9, -42, 5, 3],
&[3, 5, 11, 77]); &[3, 5, 11, 77]);
let large = (0..1000).collect::<Vec<_>>();
check_intersection(&[], &large, &[]);
check_intersection(&large, &[], &[]);
check_intersection(&[-1], &large, &[]);
check_intersection(&large, &[-1], &[]);
check_intersection(&[0], &large, &[0]);
check_intersection(&large, &[0], &[0]);
check_intersection(&[999], &large, &[999]);
check_intersection(&large, &[999], &[999]);
check_intersection(&[1000], &large, &[]);
check_intersection(&large, &[1000], &[]);
check_intersection(&[11, 5000, 1, 3, 77, 8924, 103],
&large,
&[1, 3, 11, 77, 103]);
} }
#[test] #[test]
@ -84,6 +98,18 @@ fn test_difference() {
check_difference(&[-5, 11, 22, 33, 40, 42], check_difference(&[-5, 11, 22, 33, 40, 42],
&[-12, -5, 14, 23, 34, 38, 39, 50], &[-12, -5, 14, 23, 34, 38, 39, 50],
&[11, 22, 33, 40, 42]); &[11, 22, 33, 40, 42]);
let large = (0..1000).collect::<Vec<_>>();
check_difference(&[], &large, &[]);
check_difference(&[-1], &large, &[-1]);
check_difference(&[0], &large, &[]);
check_difference(&[999], &large, &[]);
check_difference(&[1000], &large, &[1000]);
check_difference(&[11, 5000, 1, 3, 77, 8924, 103],
&large,
&[5000, 8924]);
check_difference(&large, &[], &large);
check_difference(&large, &[-1], &large);
check_difference(&large, &[1000], &large);
} }
#[test] #[test]
@ -114,6 +140,41 @@ fn test_union() {
&[-2, 1, 3, 5, 9, 11, 13, 16, 19, 24]); &[-2, 1, 3, 5, 9, 11, 13, 16, 19, 24]);
} }
#[test]
// Minimal check only: `is_disjoint` is a simple function defined in terms of
// intersection, which has its own thorough tests.
fn test_is_disjoint() {
    let only_one: BTreeSet<_> = [1].into_iter().collect();
    let only_two: BTreeSet<_> = [2].into_iter().collect();
    // Two singleton sets with different elements share nothing.
    assert!(only_one.is_disjoint(&only_two));
}
#[test]
// Also tests the trivial function definition of is_superset
fn test_is_subset() {
    // Builds a set from each slice and reports whether `a` is a subset of `b`.
    // Every call also checks that `is_superset`, asked the mirrored question,
    // gives the same answer, so each case below covers both methods.
    fn is_subset(a: &[i32], b: &[i32]) -> bool {
        let set_a = a.iter().collect::<BTreeSet<_>>();
        let set_b = b.iter().collect::<BTreeSet<_>>();
        let subset = set_a.is_subset(&set_b);
        assert_eq!(set_b.is_superset(&set_a), subset);
        subset
    }

    // Sets of similar size: exercises the size shortcut and the joint iteration.
    assert!(is_subset(&[], &[]));
    assert!(is_subset(&[], &[1, 2]));
    assert!(!is_subset(&[0], &[1, 2]));
    assert!(is_subset(&[1], &[1, 2]));
    assert!(is_subset(&[2], &[1, 2]));
    assert!(!is_subset(&[3], &[1, 2]));
    assert!(!is_subset(&[1, 2], &[1]));
    assert!(is_subset(&[1, 2], &[1, 2]));
    assert!(!is_subset(&[1, 2], &[2, 3]));

    // One set much larger than the other: exercises the search strategy.
    let large: Vec<i32> = (0..1000).collect();
    assert!(is_subset(&[], &large));
    assert!(!is_subset(&large, &[]));
    assert!(!is_subset(&[-1], &large));
    assert!(is_subset(&[0], &large));
    assert!(is_subset(&[1, 2], &large));
    assert!(!is_subset(&[999, 1000], &large));
}
#[test] #[test]
fn test_zip() { fn test_zip() {
let mut x = BTreeSet::new(); let mut x = BTreeSet::new();