std::vec: implement a stable merge sort, deferring to insertion sort for

very small runs.

This uses a lot of unsafe code for speed, otherwise we would be having
to sort by sorting lists of indices and then do a pile of swaps to put
everything in the correct place.

Fixes #9819.
This commit is contained in:
Huon Wilson 2013-12-19 09:24:26 +11:00
parent 3906823765
commit 721609e4ae
2 changed files with 348 additions and 1 deletions

View File

@ -1921,6 +1921,150 @@ impl<T:Eq> OwnedEqVector<T> for ~[T] {
}
}
fn merge_sort<T>(v: &mut [T], less_eq: |&T, &T| -> bool) {
// warning: this wildly uses unsafe.
static INSERTION: uint = 8;
let len = v.len();
// allocate some memory to use as scratch memory, we keep the
// length 0 so we can keep shallow copies of the contents of `v`
// without risking the dtors running on an object twice if
// `less_eq` fails.
let mut working_space = with_capacity(2 * len);
// these both are buffers of length `len`.
let mut buf_dat = working_space.as_mut_ptr();
let mut buf_tmp = unsafe {buf_dat.offset(len as int)};
// length `len`.
let buf_v = v.as_ptr();
// step 1. sort short runs with insertion sort. This takes the
// values from `v` and sorts them into `buf_dat`, leaving that
// with sorted runs of length INSERTION.
// We could hardcode the sorting comparisons here, and we could
// manipulate/step the pointers themselves, rather than repeatedly
// .offset-ing.
for start in range_step(0, len, INSERTION) {
// start <= i <= len;
for i in range(start, cmp::min(start + INSERTION, len)) {
// j satisfies: start <= j <= i;
let mut j = i as int;
unsafe {
// `i` is in bounds.
let read_ptr = buf_v.offset(i as int);
// find where to insert, we need to do strict <,
// rather than <=, to maintain stability.
// start <= j - 1 < len, so .offset(j - 1) is in
// bounds.
while j > start as int && !less_eq(&*buf_dat.offset(j - 1), &*read_ptr) {
j -= 1;
}
// shift everything to the right, to make space to
// insert this value.
// j + 1 could be `len` (for the last `i`), but in
// that case, `i == j` so we don't copy. The
// `.offset(j)` is always in bounds.
ptr::copy_memory(buf_dat.offset(j + 1),
buf_dat.offset(j),
i - j as uint);
ptr::copy_nonoverlapping_memory(buf_dat.offset(j), read_ptr, 1);
}
}
}
// step 2. merge the sorted runs.
let mut width = INSERTION;
while width < len {
// merge the sorted runs of length `width` in `buf_dat` two at
// a time, placing the result in `buf_tmp`.
// 0 <= start <= len.
for start in range_step(0, len, 2 * width) {
// manipulate pointers directly for speed (rather than
// using a `for` loop with `range` and `.offset` inside
// that loop).
unsafe {
// the end of the first run & start of the
// second. Offset of `len` is defined, since this is
// precisely one byte past the end of the object.
let right_start = buf_dat.offset(cmp::min(start + width, len) as int);
// end of the second. Similar reasoning to the above re safety.
let right_end_idx = cmp::min(start + 2 * width, len);
let right_end = buf_dat.offset(right_end_idx as int);
// the pointers to the elements under consideration
// from the two runs.
// both of these are in bounds.
let mut left = buf_dat.offset(start as int);
let mut right = right_start;
// where we're putting the results, it is a run of
// length `2*width`, so we step it once for each step
// of either `left` or `right`. `buf_tmp` has length
// `len`, so these are in bounds.
let mut out = buf_tmp.offset(start as int);
let out_end = buf_tmp.offset(right_end_idx as int);
while out < out_end {
// Either the left or the right run are exhausted,
// so just copy the remainder from the other run
// and move on; this gives a huge speed-up (order
// of 25%) for mostly sorted vectors (the best
// case).
if left == right_start {
// the number remaining in this run.
let elems = (right_end as uint - right as uint) / mem::size_of::<T>();
ptr::copy_nonoverlapping_memory(out, right, elems);
break;
} else if right == right_end {
let elems = (right_start as uint - left as uint) / mem::size_of::<T>();
ptr::copy_nonoverlapping_memory(out, left, elems);
break;
}
// check which side is smaller, and that's the
// next element for the new run.
// `left < right_start` and `right < right_end`,
// so these are valid.
let to_copy = if less_eq(&*left, &*right) {
step(&mut left)
} else {
step(&mut right)
};
ptr::copy_nonoverlapping_memory(out, to_copy, 1);
step(&mut out);
}
}
}
util::swap(&mut buf_dat, &mut buf_tmp);
width *= 2;
}
// write the result to `v` in one go, so that there are never two copies
// of the same object in `v`.
unsafe {
ptr::copy_nonoverlapping_memory(v.as_mut_ptr(), buf_dat, len);
}
// increment the pointer, returning the old pointer.
#[inline(always)]
unsafe fn step<T>(ptr: &mut *mut T) -> *mut T {
let old = *ptr;
*ptr = ptr.offset(1);
old
}
}
/// Extension methods for vectors such that their elements are
/// mutable.
pub trait MutableVector<'a, T> {
@ -2020,6 +2164,25 @@ pub trait MutableVector<'a, T> {
/// Reverse the order of elements in a vector, in place
fn reverse(self);
/// Sort the vector, in place, using `less_eq` to compare `a <=
/// b`.
///
/// This sort is `O(n log n)` worst-case and stable, but allocates
/// approximately `2 * n`, where `n` is the length of `self`.
///
/// # Example
///
/// ```rust
/// let mut v = [5, 4, 1, 3, 2];
/// v.sort(|a, b| *a <= *b);
/// assert_eq!(v, [1, 2, 3, 4, 5]);
///
/// // reverse sorting
/// v.sort(|a, b| *b <= *a);
/// assert_eq!(v, [5, 4, 3, 2, 1]);
/// ```
fn sort(self, less_eq: |&T, &T| -> bool);
/**
* Consumes `src` and moves as many elements as it can into `self`
* from the range [start,end).
@ -2164,6 +2327,11 @@ impl<'a,T> MutableVector<'a, T> for &'a mut [T] {
}
}
#[inline]
fn sort(self, less_eq: |&T, &T| -> bool) {
merge_sort(self, less_eq)
}
#[inline]
fn move_from(self, mut src: ~[T], start: uint, end: uint) -> uint {
for (a, b) in self.mut_iter().zip(src.mut_slice(start, end).mut_iter()) {
@ -2692,6 +2860,7 @@ mod tests {
use vec::*;
use cmp::*;
use prelude::*;
use rand::{Rng, task_rng};
fn square(n: uint) -> uint { n * n }
@ -3298,6 +3467,57 @@ mod tests {
assert!(v3.is_empty());
}
#[test]
fn test_sort() {
for len in range(4u, 25) {
for _ in range(0, 100) {
let mut v = task_rng().gen_vec::<uint>(len);
v.sort(|a,b| a <= b);
assert!(v.windows(2).all(|w| w[0] <= w[1]));
}
}
// shouldn't fail/crash
let mut v: [uint, .. 0] = [];
v.sort(|a,b| a <= b);
let mut v = [0xDEADBEEF];
v.sort(|a,b| a <= b);
assert_eq!(v, [0xDEADBEEF]);
}
#[test]
fn test_sort_stability() {
for len in range(4, 25) {
for _ in range(0 , 10) {
let mut counts = [0, .. 10];
// create a vector like [(6, 1), (5, 1), (6, 2), ...],
// where the first item of each tuple is random, but
// the second item represents which occurrence of that
// number this element is, i.e. the second elements
// will occur in sorted order.
let mut v = range(0, len).map(|_| {
let n = task_rng().gen::<uint>() % 10;
counts[n] += 1;
(n, counts[n])
}).to_owned_vec();
// only sort on the first element, so an unstable sort
// may mix up the counts.
v.sort(|&(a,_), &(b,_)| a <= b);
// this comparison includes the count (the second item
// of the tuple), so elements with equal first items
// will need to be ordered with increasing
// counts... i.e. exactly asserting that this sort is
// stable.
assert!(v.windows(2).all(|w| w[0] <= w[1]));
}
}
}
#[test]
fn test_partition() {
assert_eq!((~[]).partition(|x: &int| *x < 3), (~[], ~[]));
@ -4124,7 +4344,8 @@ mod bench {
use vec::VectorVector;
use option::*;
use ptr;
use rand::{weak_rng, Rng};
use rand::{weak_rng, task_rng, Rng};
use mem;
#[bench]
fn iterator(bh: &mut BenchHarness) {
@ -4325,4 +4546,42 @@ mod bench {
}
})
}
fn sort_random_small(bh: &mut BenchHarness) {
let mut rng = weak_rng();
bh.iter(|| {
let mut v: ~[f64] = rng.gen_vec(5);
v.sort(|a,b| *a <= *b);
});
bh.bytes = 5 * mem::size_of::<f64>() as u64;
}
#[bench]
fn sort_random_medium(bh: &mut BenchHarness) {
let mut rng = weak_rng();
bh.iter(|| {
let mut v: ~[f64] = rng.gen_vec(100);
v.sort(|a,b| *a <= *b);
});
bh.bytes = 100 * mem::size_of::<f64>() as u64;
}
#[bench]
fn sort_random_large(bh: &mut BenchHarness) {
let mut rng = weak_rng();
bh.iter(|| {
let mut v: ~[f64] = rng.gen_vec(10000);
v.sort(|a,b| *a <= *b);
});
bh.bytes = 10000 * mem::size_of::<f64>() as u64;
}
#[bench]
fn sort_sorted(bh: &mut BenchHarness) {
let mut v = vec::from_fn(10000, |i| i);
bh.iter(|| {
v.sort(|a,b| *a <= *b);
});
bh.bytes = (v.len() * mem::size_of_val(&v[0])) as u64;
}
}

View File

@ -0,0 +1,88 @@
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use std::rand::{task_rng, Rng};
static MAX_LEN: uint = 20;
static mut drop_counts: [uint, .. MAX_LEN] = [0, .. MAX_LEN];
static mut clone_count: uint = 0;
#[deriving(Rand, Ord)]
struct DropCounter { x: uint, clone_num: uint }
impl Clone for DropCounter {
fn clone(&self) -> DropCounter {
let num = unsafe { clone_count };
unsafe { clone_count += 1; }
DropCounter {
x: self.x,
clone_num: num
}
}
}
impl Drop for DropCounter {
fn drop(&mut self) {
unsafe {
// Rand creates some with arbitrary clone_nums
if self.clone_num < MAX_LEN {
drop_counts[self.clone_num] += 1;
}
}
}
}
pub fn main() {
// len can't go above 64.
for len in range(2u, MAX_LEN) {
for _ in range(0, 10) {
let main = task_rng().gen_vec::<DropCounter>(len);
// work out the total number of comparisons required to sort
// this array...
let mut count = 0;
main.clone().sort(|a, b| { count += 1; a <= b });
// ... and then fail on each and every single one.
for fail_countdown in range(0, count) {
// refresh the counters.
unsafe {
drop_counts = [0, .. MAX_LEN];
clone_count = 0;
}
let v = main.clone();
std::task::try(proc() {
let mut v = v;
let mut fail_countdown = fail_countdown;
v.sort(|a, b| {
if fail_countdown == 0 {
fail!()
}
fail_countdown -= 1;
a <= b
})
});
// check that the number of things dropped is exactly
// what we expect (i.e. the contents of `v`).
unsafe {
for (i, &c) in drop_counts.iter().enumerate() {
let expected = if i < len {1} else {0};
assert!(c == expected,
"found drop count == {} for i == {}, len == {}",
c, i, len);
}
}
}
}
}
}