Skip to content

Commit

Permalink
Implement Yaroslavskiy-Bentley-Bloch Quicksort.
Browse files Browse the repository at this point in the history
  • Loading branch information
n3vu0r committed Jun 21, 2021
1 parent b6628c6 commit 764e33c
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 108 deletions.
234 changes: 134 additions & 100 deletions src/sort.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use indexmap::IndexMap;
use ndarray::prelude::*;
use ndarray::{Data, DataMut, Slice};
use rand::prelude::*;
use rand::thread_rng;

/// Methods for sorting and partitioning 1-D arrays.
pub trait Sort1dExt<A, S>
Expand Down Expand Up @@ -50,26 +48,21 @@ where
S: DataMut,
S2: Data<Elem = usize>;

/// Partitions the array in increasing order based on the value initially
/// located at `pivot_index` and returns the new index of the value.
/// Partitions the array in increasing order based on the values initially located at `0` and
/// `n - 1` where `n` is the number of elements in the array and returns the new indexes of the
/// values.
///
/// The elements are rearranged in such a way that the value initially
/// located at `pivot_index` is moved to the position it would be in an
/// array sorted in increasing order. The return value is the new index of
/// the value after rearrangement. All elements smaller than the value are
/// moved to its left and all elements equal or greater than the value are
/// moved to its right. The ordering of the elements in the two partitions
/// is undefined.
/// The elements are rearranged in such a way that the values initially located at `0` and
/// `n - 1` are moved to the positions they would be in an array sorted in increasing order. The
/// return values are the new indexes of the lower and the upper value after rearrangement.
/// Elements to the left of the lower value are less than it, elements between the two values
/// are greater than or equal to the lower value and less than or equal to the upper value, and
/// elements to the right of the upper value are greater than or equal to it. The ordering of
/// the elements in the three partitions is undefined.
///
/// `self` is shuffled **in place** to operate the desired partition:
/// no copy of the array is allocated.
/// `self` is shuffled **in place**, no copy of the array is allocated.
///
/// The method uses Hoare's partition algorithm.
/// Complexity: O(`n`), where `n` is the number of elements in the array.
/// Average number of element swaps: n/6 - 1/3 (see
/// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
/// This method implements the partitioning scheme of [Yaroslavskiy-Bentley-Bloch Quicksort].
///
/// **Panics** if the array has fewer than two elements.
///
/// **Panics** if `pivot_index` is greater than or equal to `n`.
/// [Yaroslavskiy-Bentley-Bloch Quicksort]: https://www.wild-inter.net/publications/wild-2016
///
/// # Example
///
Expand All @@ -78,23 +71,30 @@ where
/// use ndarray_stats::Sort1dExt;
///
/// let mut data = array![3, 1, 4, 5, 2];
/// let pivot_index = 2;
/// let pivot_value = data[pivot_index];
/// // Sorted pivot values.
/// let (lower_value, upper_value) = (data[data.len() - 1], data[0]);
///
/// // Partition by the value located at `pivot_index`.
/// let new_index = data.partition_mut(pivot_index);
/// // The pivot value is now located at `new_index`.
/// assert_eq!(data[new_index], pivot_value);
/// // Elements less than that value are moved to the left.
/// for i in 0..new_index {
/// assert!(data[i] < pivot_value);
/// // Partitions by the values located at `0` and `data.len() - 1`.
/// let (lower_index, upper_index) = data.partition_mut();
/// // The pivot values are now located at `lower_index` and `upper_index`.
/// assert_eq!(data[lower_index], lower_value);
/// assert_eq!(data[upper_index], upper_value);
/// // Elements lower than the lower pivot value are moved to its left.
/// for i in 0..lower_index {
/// assert!(data[i] < lower_value);
/// }
/// // Elements greater than or equal to the lower pivot value and less than or equal to the
/// // upper pivot value are moved between the two pivot indexes.
/// for i in lower_index + 1..upper_index {
/// assert!(lower_value <= data[i]);
/// assert!(data[i] <= upper_value);
/// }
/// // Elements greater than or equal to that value are moved to the right.
/// for i in (new_index + 1)..data.len() {
/// assert!(data[i] >= pivot_value);
/// // Elements greater than or equal to the upper pivot value are moved to its right.
/// for i in upper_index + 1..data.len() {
/// assert!(upper_value <= data[i]);
/// }
/// ```
fn partition_mut(&mut self, pivot_index: usize) -> usize
fn partition_mut(&mut self) -> (usize, usize)
where
A: Ord + Clone,
S: DataMut;
Expand All @@ -115,17 +115,20 @@ where
if n == 1 {
self[0].clone()
} else {
let mut rng = thread_rng();
let pivot_index = rng.gen_range(0..n);
let partition_index = self.partition_mut(pivot_index);
if i < partition_index {
self.slice_axis_mut(Axis(0), Slice::from(..partition_index))
let (lower_index, upper_index) = self.partition_mut();
if i < lower_index {
self.slice_axis_mut(Axis(0), Slice::from(..lower_index))
.get_from_sorted_mut(i)
} else if i == partition_index {
} else if i == lower_index {
self[i].clone()
} else if i < upper_index {
self.slice_axis_mut(Axis(0), Slice::from(lower_index + 1..upper_index))
.get_from_sorted_mut(i - (lower_index + 1))
} else if i == upper_index {
self[i].clone()
} else {
self.slice_axis_mut(Axis(0), Slice::from(partition_index + 1..))
.get_from_sorted_mut(i - (partition_index + 1))
self.slice_axis_mut(Axis(0), Slice::from(upper_index + 1..))
.get_from_sorted_mut(i - (upper_index + 1))
}
}
}
Expand All @@ -143,42 +146,51 @@ where
get_many_from_sorted_mut_unchecked(self, &deduped_indexes)
}

fn partition_mut(&mut self, pivot_index: usize) -> usize
fn partition_mut(&mut self) -> (usize, usize)
where
A: Ord + Clone,
S: DataMut,
{
let pivot_value = self[pivot_index].clone();
self.swap(pivot_index, 0);
let n = self.len();
let mut i = 1;
let mut j = n - 1;
loop {
loop {
if i > j {
break;
}
if self[i] >= pivot_value {
break;
// Sort `lowermost` and `uppermost` elements and use them as dual pivot.
let lowermost = 0;
let uppermost = self.len() - 1;
if self[lowermost] > self[uppermost] {
self.swap(lowermost, uppermost);
}
// Increasing running and partition index starting after lower pivot.
let mut index = lowermost + 1;
let mut lower = lowermost + 1;
// Decreasing partition index starting before upper pivot.
let mut upper = uppermost - 1;
// Swap elements at `index` into their partitions.
while index <= upper {
if self[index] < self[lowermost] {
// Swap elements into lower partition.
self.swap(index, lower);
lower += 1;
} else if self[index] >= self[uppermost] {
// Search first element of upper partition.
while self[upper] > self[uppermost] && index < upper {
upper -= 1;
}
i += 1;
}
while pivot_value <= self[j] {
if j == 1 {
break;
// Swap elements into upper partition.
self.swap(index, upper);
if self[index] < self[lowermost] {
// Swap swapped elements into lower partition.
self.swap(index, lower);
lower += 1;
}
j -= 1;
}
if i >= j {
break;
} else {
self.swap(i, j);
i += 1;
j -= 1;
upper -= 1;
}
index += 1;
}
self.swap(0, i - 1);
i - 1
lower -= 1;
upper += 1;
// Swap pivots to their new indexes.
self.swap(lowermost, lower);
self.swap(uppermost, upper);
// Lower and upper pivot index.
(lower, upper)
}

private_impl! {}
Expand Down Expand Up @@ -249,50 +261,72 @@ fn _get_many_from_sorted_mut_unchecked<A>(
return;
}

// We pick a random pivot index: the corresponding element is the pivot value
let mut rng = thread_rng();
let pivot_index = rng.gen_range(0..n);
// We partition the array with respect to the two pivot values. The pivot values move to
// `lower_index` and `upper_index`.
//
// Elements strictly less than the lower pivot value have indexes < `lower_index`.
//
// Elements greater than or equal to the lower pivot value and less than or equal to the upper
// pivot value have indexes > `lower_index` and < `upper_index`.
//
// Elements greater than or equal to the upper pivot value have indexes > `upper_index`.
let (lower_index, upper_index) = array.partition_mut();

// We partition the array with respect to the pivot value.
// The pivot value moves to `array_partition_index`.
// Elements strictly smaller than the pivot value have indexes < `array_partition_index`.
// Elements greater or equal to the pivot value have indexes > `array_partition_index`.
let array_partition_index = array.partition_mut(pivot_index);
// We use a divide-and-conquer strategy, splitting the indexes we are searching for (`indexes`)
// and the corresponding portions of the output slice (`values`) into partitions with respect to
// `lower_index` and `upper_index`.
let (found_exact, split_index) = match indexes.binary_search(&lower_index) {
Ok(index) => (true, index),
Err(index) => (false, index),
};
let (lower_indexes, inner_indexes) = indexes.split_at_mut(split_index);
let (lower_values, inner_values) = values.split_at_mut(split_index);
let (upper_indexes, upper_values) = if found_exact {
inner_values[0] = array[lower_index].clone(); // Write exactly found value.
(&mut inner_indexes[1..], &mut inner_values[1..])
} else {
(inner_indexes, inner_values)
};

// We use a divide-and-conquer strategy, splitting the indexes we are
// searching for (`indexes`) and the corresponding portions of the output
// slice (`values`) into pieces with respect to `array_partition_index`.
let (found_exact, index_split) = match indexes.binary_search(&array_partition_index) {
let (found_exact, split_index) = match upper_indexes.binary_search(&upper_index) {
Ok(index) => (true, index),
Err(index) => (false, index),
};
let (smaller_indexes, other_indexes) = indexes.split_at_mut(index_split);
let (smaller_values, other_values) = values.split_at_mut(index_split);
let (bigger_indexes, bigger_values) = if found_exact {
other_values[0] = array[array_partition_index].clone(); // Write exactly found value.
(&mut other_indexes[1..], &mut other_values[1..])
let (inner_indexes, upper_indexes) = upper_indexes.split_at_mut(split_index);
let (inner_values, upper_values) = upper_values.split_at_mut(split_index);
let (upper_indexes, upper_values) = if found_exact {
upper_values[0] = array[upper_index].clone(); // Write exactly found value.
(&mut upper_indexes[1..], &mut upper_values[1..])
} else {
(other_indexes, other_values)
(upper_indexes, upper_values)
};

// We search recursively for the values corresponding to strictly smaller
// indexes to the left of `partition_index`.
// We search recursively for the values corresponding to indexes strictly less than
// `lower_index` in the lower partition.
_get_many_from_sorted_mut_unchecked(
array.slice_axis_mut(Axis(0), Slice::from(..lower_index)),
lower_indexes,
lower_values,
);

// We search recursively for the values corresponding to indexes strictly greater than
// `lower_index` and strictly less than `upper_index` in the inner partition, that is between the
// lower and upper partition. Since only the inner partition of the array is passed in, the
// indexes need to be shifted by the length of the lower partition plus one for the lower pivot.
inner_indexes.iter_mut().for_each(|x| *x -= lower_index + 1);
_get_many_from_sorted_mut_unchecked(
array.slice_axis_mut(Axis(0), Slice::from(..array_partition_index)),
smaller_indexes,
smaller_values,
array.slice_axis_mut(Axis(0), Slice::from(lower_index + 1..upper_index)),
inner_indexes,
inner_values,
);

// We search recursively for the values corresponding to strictly bigger
// indexes to the right of `partition_index`. Since only the right portion
// of the array is passed in, the indexes need to be shifted by length of
// the removed portion.
bigger_indexes
.iter_mut()
.for_each(|x| *x -= array_partition_index + 1);
// We search recursively for the values corresponding to indexes strictly greater than
// `upper_index` in the upper partition. Since only the upper partition of the array is passed
// in, the indexes need to be shifted by the combined length of the lower and inner partition
// plus two for the pivots.
upper_indexes.iter_mut().for_each(|x| *x -= upper_index + 1);
_get_many_from_sorted_mut_unchecked(
array.slice_axis_mut(Axis(0), Slice::from(array_partition_index + 1..)),
bigger_indexes,
bigger_values,
array.slice_axis_mut(Axis(0), Slice::from(upper_index + 1..)),
upper_indexes,
upper_values,
);
}
23 changes: 15 additions & 8 deletions tests/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,22 @@ fn test_partition_mut() {
];
for a in l.iter_mut() {
let n = a.len();
let pivot_index = n - 1;
let pivot_value = a[pivot_index].clone();
let partition_index = a.partition_mut(pivot_index);
for i in 0..partition_index {
assert!(a[i] < pivot_value);
let (mut lower_value, mut upper_value) = (a[0].clone(), a[n - 1].clone());
if lower_value > upper_value {
std::mem::swap(&mut lower_value, &mut upper_value);
}
assert_eq!(a[partition_index], pivot_value);
for j in (partition_index + 1)..n {
assert!(pivot_value <= a[j]);
let (lower_index, upper_index) = a.partition_mut();
for i in 0..lower_index {
assert!(a[i] < lower_value);
}
assert_eq!(a[lower_index], lower_value);
for i in lower_index + 1..upper_index {
assert!(lower_value <= a[i]);
assert!(a[i] <= upper_value);
}
assert_eq!(a[upper_index], upper_value);
for i in (upper_index + 1)..n {
assert!(upper_value <= a[i]);
}
}
}
Expand Down

0 comments on commit 764e33c

Please sign in to comment.