From e37ec9ee56dfaf2e780ad891b67c3251ffefd608 Mon Sep 17 00:00:00 2001
From: Josh Stone
Date: Sat, 23 Mar 2024 13:18:01 -0700
Subject: [PATCH] Rewrite `ChunkByProducer` to be more like `SplitProducer`

In particular, this now ensures we only call the predicate once on each pair of items.
---
 src/slice/chunk_by.rs | 239 ++++++++++++++++++++++--------------------
 src/slice/test.rs     |  46 ++++++++
 2 files changed, 170 insertions(+), 115 deletions(-)

diff --git a/src/slice/chunk_by.rs b/src/slice/chunk_by.rs
index 23bb40363..25833cabe 100644
--- a/src/slice/chunk_by.rs
+++ b/src/slice/chunk_by.rs
@@ -1,83 +1,141 @@
 use crate::iter::plumbing::*;
 use crate::iter::*;
-use std::fmt;
+use std::marker::PhantomData;
+use std::{fmt, mem};
 
-fn find_first_index<T, P>(xs: &[T], pred: &P) -> Option<usize>
-where
-    P: Fn(&T, &T) -> bool,
-{
-    xs.windows(2)
-        .position(|w| !pred(&w[0], &w[1]))
-        .map(|i| i + 1)
+trait ChunkBySlice<T>: AsRef<[T]> + Default + Send {
+    fn split(self, index: usize) -> (Self, Self);
+
+    fn find(&self, pred: &impl Fn(&T, &T) -> bool, start: usize, end: usize) -> Option<usize> {
+        self.as_ref()[start..end]
+            .windows(2)
+            .position(move |w| !pred(&w[0], &w[1]))
+            .map(|i| i + 1)
+    }
+
+    fn rfind(&self, pred: &impl Fn(&T, &T) -> bool, end: usize) -> Option<usize> {
+        self.as_ref()[..end]
+            .windows(2)
+            .rposition(move |w| !pred(&w[0], &w[1]))
+            .map(|i| i + 1)
+    }
 }
 
-fn find_index<T, P>(xs: &[T], pred: &P) -> Option<usize>
-where
-    P: Fn(&T, &T) -> bool,
-{
-    let n = (xs.len() / 2).saturating_sub(1);
-
-    for m in (1..((n / 2) + 1)).map(|x| 2 * x) {
-        let start = n.saturating_sub(m);
-        let end = std::cmp::min(n + m, xs.len());
-        let fsts = &xs[start..end];
-        let (_, snds) = fsts.split_first()?;
-        match fsts.iter().zip(snds).position(|(x, y)| !pred(x, y)) {
-            None => (),
-            Some(i) => return Some(start + i + 1),
-        }
+impl<T: Sync> ChunkBySlice<T> for &[T] {
+    fn split(self, index: usize) -> (Self, Self) {
+        self.split_at(index)
     }
-    None
 }
 
-struct ChunkByProducer<'data, 'p, T, P> {
-    pred: &'p P,
-    slice: &'data [T],
+impl<T: Send> ChunkBySlice<T> for &mut [T] {
+    fn split(self, index: usize) -> (Self, Self) {
+        self.split_at_mut(index)
+    }
+}
+
+struct ChunkByProducer<'p, T, Slice, Pred> {
+    slice: Slice,
+    pred: &'p Pred,
+    tail: usize,
+    marker: PhantomData<fn(&T)>,
 }
 
-impl<'data, 'p, T, P> UnindexedProducer for ChunkByProducer<'data, 'p, T, P>
+// Note: this implementation is very similar to `SplitProducer`.
+impl<T, Slice, Pred> UnindexedProducer for ChunkByProducer<'_, T, Slice, Pred>
 where
-    T: Sync,
-    P: Fn(&T, &T) -> bool + Send + Sync,
+    Slice: ChunkBySlice<T>,
+    Pred: Fn(&T, &T) -> bool + Send + Sync,
 {
-    type Item = &'data [T];
+    type Item = Slice;
 
     fn split(self) -> (Self, Option<Self>) {
-        match find_index(self.slice, self.pred) {
-            Some(i) => {
-                let (ys, zs) = self.slice.split_at(i);
-                (
-                    Self {
-                        pred: self.pred,
-                        slice: ys,
-                    },
-                    Some(Self {
-                        pred: self.pred,
-                        slice: zs,
-                    }),
-                )
-            }
-            None => (self, None),
+        if self.tail < 2 {
+            return (Self { tail: 0, ..self }, None);
+        }
+
+        // Look forward for the separator, and failing that look backward.
+        let mid = self.tail / 2;
+        let index = match self.slice.find(self.pred, mid, self.tail) {
+            Some(i) => Some(mid + i),
+            None => self.slice.rfind(self.pred, mid + 1),
+        };
+
+        if let Some(index) = index {
+            let (left, right) = self.slice.split(index);
+
+            let (left_tail, right_tail) = if index <= mid {
+                // If we scanned backwards to find the separator, everything in
+                // the right side is exhausted, with no separators left to find.
+                (index, 0)
+            } else {
+                (mid + 1, self.tail - index)
+            };
+
+            // Create the left split before the separator.
+            let left = Self {
+                slice: left,
+                tail: left_tail,
+                ..self
+            };
+
+            // Create the right split following the separator.
+            let right = Self {
+                slice: right,
+                tail: right_tail,
+                ..self
+            };
+
+            (left, Some(right))
+        } else {
+            // The search is exhausted, no more separators...
+            (Self { tail: 0, ..self }, None)
         }
     }
 
-    fn fold_with<F>(mut self, folder: F) -> F
+    fn fold_with<F>(self, mut folder: F) -> F
     where
         F: Folder<Self::Item>,
     {
-        // TODO (MSRV 1.77):
-        // folder.consume_iter(self.slice.chunk_by(self.pred))
+        let Self {
+            slice, pred, tail, ..
+        } = self;
+
+        let (slice, tail) = if tail == slice.as_ref().len() {
+            // No tail section, so just let `consume_iter` do it all.
+            (Some(slice), None)
+        } else if let Some(index) = slice.rfind(pred, tail) {
+            // We found the last separator to complete the tail, so
+            // end with that slice after `consume_iter` finds the rest.
+            let (left, right) = slice.split(index);
+            (Some(left), Some(right))
+        } else {
+            // We know there are no separators at all, so it's all "tail".
+            (None, Some(slice))
+        };
+
+        if let Some(mut slice) = slice {
+            // TODO (MSRV 1.77) use either:
+            // folder.consume_iter(slice.chunk_by(pred))
+            // folder.consume_iter(slice.chunk_by_mut(pred))
+
+            folder = folder.consume_iter(std::iter::from_fn(move || {
+                let len = slice.as_ref().len();
+                if len > 0 {
+                    let i = slice.find(pred, 0, len).unwrap_or(len);
+                    let (head, tail) = mem::take(&mut slice).split(i);
+                    slice = tail;
+                    Some(head)
+                } else {
+                    None
+                }
+            }));
+        }
 
-        folder.consume_iter(std::iter::from_fn(move || {
-            if self.slice.is_empty() {
-                None
-            } else {
-                let i = find_first_index(self.slice, self.pred).unwrap_or(self.slice.len());
-                let (head, tail) = self.slice.split_at(i);
-                self.slice = tail;
-                Some(head)
-            }
-        }))
+        if let Some(tail) = tail {
+            folder = folder.consume(tail);
+        }
+
+        folder
     }
 }
 
@@ -127,67 +185,16 @@ where
     {
         bridge_unindexed(
             ChunkByProducer {
-                pred: &self.pred,
+                tail: self.slice.len(),
                 slice: self.slice,
+                pred: &self.pred,
+                marker: PhantomData,
             },
             consumer,
         )
     }
 }
 
-// Mutable
-
-struct ChunkByMutProducer<'data, 'p, T, P> {
-    pred: &'p P,
-    slice: &'data mut [T],
-}
-
-impl<'data, 'p, T, P> UnindexedProducer for ChunkByMutProducer<'data, 'p, T, P>
-where
-    T: Send,
-    P: Fn(&T, &T) -> bool + Send + Sync,
-{
-    type Item = &'data mut [T];
-
-    fn split(self) -> (Self, Option<Self>) {
-        match find_index(self.slice, self.pred) {
-            Some(i) => {
-                let (ys, zs) = self.slice.split_at_mut(i);
-                (
-                    Self {
-                        pred: self.pred,
-                        slice: ys,
-                    },
-                    Some(Self {
-                        pred: self.pred,
-                        slice: zs,
-                    }),
-                )
-            }
-            None => (self, None),
-        }
-    }
-
-    fn fold_with<F>(mut self, folder: F) -> F
-    where
-        F: Folder<Self::Item>,
-    {
-        // TODO (MSRV 1.77):
-        // folder.consume_iter(self.slice.chunk_by_mut(self.pred))
-
-        folder.consume_iter(std::iter::from_fn(move || {
-            if self.slice.is_empty() {
-                None
-            } else {
-                let i = find_first_index(self.slice, self.pred).unwrap_or(self.slice.len());
-                let (head, tail) = std::mem::take(&mut self.slice).split_at_mut(i);
-                self.slice = tail;
-                Some(head)
-            }
-        }))
-    }
-}
-
 /// Parallel iterator over slice in (non-overlapping) mutable chunks
 /// separated by a predicate.
 ///
@@ -225,9 +232,11 @@ where
         C: UnindexedConsumer<Self::Item>,
     {
         bridge_unindexed(
-            ChunkByMutProducer {
-                pred: &self.pred,
+            ChunkByProducer {
+                tail: self.slice.len(),
                 slice: self.slice,
+                pred: &self.pred,
+                marker: PhantomData,
             },
             consumer,
         )
diff --git a/src/slice/test.rs b/src/slice/test.rs
index f74ca0f74..2538a86b9 100644
--- a/src/slice/test.rs
+++ b/src/slice/test.rs
@@ -5,6 +5,7 @@ use rand::distributions::Uniform;
 use rand::seq::SliceRandom;
 use rand::{thread_rng, Rng};
 use std::cmp::Ordering::{Equal, Greater, Less};
+use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
 
 macro_rules! sort {
     ($f:ident, $name:ident) => {
@@ -168,3 +169,48 @@ fn test_par_rchunks_exact_mut_remainder() {
     assert_eq!(c.take_remainder(), &[]);
     assert_eq!(c.len(), 2);
 }
+
+#[test]
+fn slice_chunk_by() {
+    let v: Vec<_> = (0..1000).collect();
+    assert_eq!(v[..0].par_chunk_by(|_, _| todo!()).count(), 0);
+    assert_eq!(v[..1].par_chunk_by(|_, _| todo!()).count(), 1);
+    assert_eq!(v[..2].par_chunk_by(|_, _| true).count(), 1);
+    assert_eq!(v[..2].par_chunk_by(|_, _| false).count(), 2);
+
+    let count = AtomicUsize::new(0);
+    let par: Vec<_> = v
+        .par_chunk_by(|x, y| {
+            count.fetch_add(1, Relaxed);
+            (x % 10 < 3) == (y % 10 < 3)
+        })
+        .collect();
+    assert_eq!(count.into_inner(), v.len() - 1);
+
+    let seq: Vec<_> = v.chunk_by(|x, y| (x % 10 < 3) == (y % 10 < 3)).collect();
+    assert_eq!(par, seq);
+}
+
+#[test]
+fn slice_chunk_by_mut() {
+    let mut v: Vec<_> = (0..1000).collect();
+    assert_eq!(v[..0].par_chunk_by_mut(|_, _| todo!()).count(), 0);
+    assert_eq!(v[..1].par_chunk_by_mut(|_, _| todo!()).count(), 1);
+    assert_eq!(v[..2].par_chunk_by_mut(|_, _| true).count(), 1);
+    assert_eq!(v[..2].par_chunk_by_mut(|_, _| false).count(), 2);
+
+    let mut v2 = v.clone();
+    let count = AtomicUsize::new(0);
+    let par: Vec<_> = v
+        .par_chunk_by_mut(|x, y| {
+            count.fetch_add(1, Relaxed);
+            (x % 10 < 3) == (y % 10 < 3)
+        })
+        .collect();
+    assert_eq!(count.into_inner(), v2.len() - 1);
+
+    let seq: Vec<_> = v2
+        .chunk_by_mut(|x, y| (x % 10 < 3) == (y % 10 < 3))
+        .collect();
+    assert_eq!(par, seq);
+}