Skip to content

Commit

Permalink
Rewrite ChunkByProducer to be more like SplitProducer
Browse files Browse the repository at this point in the history
In particular, this now ensures we only call the predicate once on
each pair of items.
  • Loading branch information
cuviper committed Mar 23, 2024
1 parent ad41345 commit e37ec9e
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 115 deletions.
239 changes: 124 additions & 115 deletions src/slice/chunk_by.rs
Original file line number Diff line number Diff line change
@@ -1,83 +1,141 @@
use crate::iter::plumbing::*;
use crate::iter::*;
use std::fmt;
use std::marker::PhantomData;
use std::{fmt, mem};

fn find_first_index<T, P>(xs: &[T], pred: &P) -> Option<usize>
where
P: Fn(&T, &T) -> bool,
{
xs.windows(2)
.position(|w| !pred(&w[0], &w[1]))
.map(|i| i + 1)
trait ChunkBySlice<T>: AsRef<[T]> + Default + Send {
fn split(self, index: usize) -> (Self, Self);

fn find(&self, pred: &impl Fn(&T, &T) -> bool, start: usize, end: usize) -> Option<usize> {
self.as_ref()[start..end]
.windows(2)
.position(move |w| !pred(&w[0], &w[1]))
.map(|i| i + 1)
}

fn rfind(&self, pred: &impl Fn(&T, &T) -> bool, end: usize) -> Option<usize> {
self.as_ref()[..end]
.windows(2)
.rposition(move |w| !pred(&w[0], &w[1]))
.map(|i| i + 1)
}
}

fn find_index<T, P>(xs: &[T], pred: &P) -> Option<usize>
where
P: Fn(&T, &T) -> bool,
{
let n = (xs.len() / 2).saturating_sub(1);

for m in (1..((n / 2) + 1)).map(|x| 2 * x) {
let start = n.saturating_sub(m);
let end = std::cmp::min(n + m, xs.len());
let fsts = &xs[start..end];
let (_, snds) = fsts.split_first()?;
match fsts.iter().zip(snds).position(|(x, y)| !pred(x, y)) {
None => (),
Some(i) => return Some(start + i + 1),
}
impl<T: Sync> ChunkBySlice<T> for &[T] {
fn split(self, index: usize) -> (Self, Self) {
self.split_at(index)
}
None
}

struct ChunkByProducer<'data, 'p, T, P> {
pred: &'p P,
slice: &'data [T],
impl<T: Send> ChunkBySlice<T> for &mut [T] {
fn split(self, index: usize) -> (Self, Self) {
self.split_at_mut(index)
}
}

struct ChunkByProducer<'p, T, Slice, Pred> {
slice: Slice,
pred: &'p Pred,
tail: usize,
marker: PhantomData<fn(&T)>,
}

impl<'data, 'p, T, P> UnindexedProducer for ChunkByProducer<'data, 'p, T, P>
// Note: this implementation is very similar to `SplitProducer`.
impl<T, Slice, Pred> UnindexedProducer for ChunkByProducer<'_, T, Slice, Pred>
where
T: Sync,
P: Fn(&T, &T) -> bool + Send + Sync,
Slice: ChunkBySlice<T>,
Pred: Fn(&T, &T) -> bool + Send + Sync,
{
type Item = &'data [T];
type Item = Slice;

fn split(self) -> (Self, Option<Self>) {
match find_index(self.slice, self.pred) {
Some(i) => {
let (ys, zs) = self.slice.split_at(i);
(
Self {
pred: self.pred,
slice: ys,
},
Some(Self {
pred: self.pred,
slice: zs,
}),
)
}
None => (self, None),
if self.tail < 2 {
return (Self { tail: 0, ..self }, None);
}

// Look forward for the separator, and failing that look backward.
let mid = self.tail / 2;
let index = match self.slice.find(self.pred, mid, self.tail) {
Some(i) => Some(mid + i),
None => self.slice.rfind(self.pred, mid + 1),
};

if let Some(index) = index {
let (left, right) = self.slice.split(index);

let (left_tail, right_tail) = if index <= mid {
// If we scanned backwards to find the separator, everything in
// the right side is exhausted, with no separators left to find.
(index, 0)
} else {
(mid + 1, self.tail - index)
};

// Create the left split before the separator.
let left = Self {
slice: left,
tail: left_tail,
..self
};

// Create the right split following the separator.
let right = Self {
slice: right,
tail: right_tail,
..self
};

(left, Some(right))
} else {
// The search is exhausted, no more separators...
(Self { tail: 0, ..self }, None)
}
}

fn fold_with<F>(mut self, folder: F) -> F
fn fold_with<F>(self, mut folder: F) -> F
where
F: Folder<Self::Item>,
{
// TODO (MSRV 1.77):
// folder.consume_iter(self.slice.chunk_by(self.pred))
let Self {
slice, pred, tail, ..
} = self;

let (slice, tail) = if tail == slice.as_ref().len() {
// No tail section, so just let `consume_iter` do it all.
(Some(slice), None)
} else if let Some(index) = slice.rfind(pred, tail) {
// We found the last separator to complete the tail, so
// end with that slice after `consume_iter` finds the rest.
let (left, right) = slice.split(index);
(Some(left), Some(right))
} else {
// We know there are no separators at all, so it's all "tail".
(None, Some(slice))
};

if let Some(mut slice) = slice {
// TODO (MSRV 1.77) use either:
// folder.consume_iter(slice.chunk_by(pred))
// folder.consume_iter(slice.chunk_by_mut(pred))

folder = folder.consume_iter(std::iter::from_fn(move || {
let len = slice.as_ref().len();
if len > 0 {
let i = slice.find(pred, 0, len).unwrap_or(len);
let (head, tail) = mem::take(&mut slice).split(i);
slice = tail;
Some(head)
} else {
None
}
}));
}

folder.consume_iter(std::iter::from_fn(move || {
if self.slice.is_empty() {
None
} else {
let i = find_first_index(self.slice, self.pred).unwrap_or(self.slice.len());
let (head, tail) = self.slice.split_at(i);
self.slice = tail;
Some(head)
}
}))
if let Some(tail) = tail {
folder = folder.consume(tail);
}

folder
}
}

Expand Down Expand Up @@ -127,67 +185,16 @@ where
{
bridge_unindexed(
ChunkByProducer {
pred: &self.pred,
tail: self.slice.len(),
slice: self.slice,
pred: &self.pred,
marker: PhantomData,
},
consumer,
)
}
}

// Mutable

struct ChunkByMutProducer<'data, 'p, T, P> {
pred: &'p P,
slice: &'data mut [T],
}

impl<'data, 'p, T, P> UnindexedProducer for ChunkByMutProducer<'data, 'p, T, P>
where
T: Send,
P: Fn(&T, &T) -> bool + Send + Sync,
{
type Item = &'data mut [T];

fn split(self) -> (Self, Option<Self>) {
match find_index(self.slice, self.pred) {
Some(i) => {
let (ys, zs) = self.slice.split_at_mut(i);
(
Self {
pred: self.pred,
slice: ys,
},
Some(Self {
pred: self.pred,
slice: zs,
}),
)
}
None => (self, None),
}
}

fn fold_with<F>(mut self, folder: F) -> F
where
F: Folder<Self::Item>,
{
// TODO (MSRV 1.77):
// folder.consume_iter(self.slice.chunk_by_mut(self.pred))

folder.consume_iter(std::iter::from_fn(move || {
if self.slice.is_empty() {
None
} else {
let i = find_first_index(self.slice, self.pred).unwrap_or(self.slice.len());
let (head, tail) = std::mem::take(&mut self.slice).split_at_mut(i);
self.slice = tail;
Some(head)
}
}))
}
}

/// Parallel iterator over slice in (non-overlapping) mutable chunks
/// separated by a predicate.
///
Expand Down Expand Up @@ -225,9 +232,11 @@ where
C: UnindexedConsumer<Self::Item>,
{
bridge_unindexed(
ChunkByMutProducer {
pred: &self.pred,
ChunkByProducer {
tail: self.slice.len(),
slice: self.slice,
pred: &self.pred,
marker: PhantomData,
},
consumer,
)
Expand Down
46 changes: 46 additions & 0 deletions src/slice/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use rand::distributions::Uniform;
use rand::seq::SliceRandom;
use rand::{thread_rng, Rng};
use std::cmp::Ordering::{Equal, Greater, Less};
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

macro_rules! sort {
($f:ident, $name:ident) => {
Expand Down Expand Up @@ -168,3 +169,48 @@ fn test_par_rchunks_exact_mut_remainder() {
assert_eq!(c.take_remainder(), &[]);
assert_eq!(c.len(), 2);
}

#[test]
fn slice_chunk_by() {
let v: Vec<_> = (0..1000).collect();
assert_eq!(v[..0].par_chunk_by(|_, _| todo!()).count(), 0);
assert_eq!(v[..1].par_chunk_by(|_, _| todo!()).count(), 1);
assert_eq!(v[..2].par_chunk_by(|_, _| true).count(), 1);
assert_eq!(v[..2].par_chunk_by(|_, _| false).count(), 2);

let count = AtomicUsize::new(0);
let par: Vec<_> = v
.par_chunk_by(|x, y| {
count.fetch_add(1, Relaxed);
(x % 10 < 3) == (y % 10 < 3)
})
.collect();
assert_eq!(count.into_inner(), v.len() - 1);

let seq: Vec<_> = v.chunk_by(|x, y| (x % 10 < 3) == (y % 10 < 3)).collect();
assert_eq!(par, seq);
}

#[test]
fn slice_chunk_by_mut() {
let mut v: Vec<_> = (0..1000).collect();
assert_eq!(v[..0].par_chunk_by_mut(|_, _| todo!()).count(), 0);
assert_eq!(v[..1].par_chunk_by_mut(|_, _| todo!()).count(), 1);
assert_eq!(v[..2].par_chunk_by_mut(|_, _| true).count(), 1);
assert_eq!(v[..2].par_chunk_by_mut(|_, _| false).count(), 2);

let mut v2 = v.clone();
let count = AtomicUsize::new(0);
let par: Vec<_> = v
.par_chunk_by_mut(|x, y| {
count.fetch_add(1, Relaxed);
(x % 10 < 3) == (y % 10 < 3)
})
.collect();
assert_eq!(count.into_inner(), v2.len() - 1);

let seq: Vec<_> = v2
.chunk_by_mut(|x, y| (x % 10 < 3) == (y % 10 < 3))
.collect();
assert_eq!(par, seq);
}

0 comments on commit e37ec9e

Please sign in to comment.