diff --git a/crates/bevy_tasks/src/slice.rs b/crates/bevy_tasks/src/slice.rs
index 8410478322ee0..93568fd15dad0 100644
--- a/crates/bevy_tasks/src/slice.rs
+++ b/crates/bevy_tasks/src/slice.rs
@@ -6,6 +6,9 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
     /// in parallel across the provided `task_pool`. One task is spawned in the task pool
     /// for every chunk.
     ///
+    /// The iteration function takes the index of the chunk in the original slice as the
+    /// first argument, and the chunk as the second argument.
+    ///
     /// Returns a `Vec` of the mapped results in the same order as the input.
     ///
     /// # Example
@@ -15,7 +18,7 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
     /// # use bevy_tasks::TaskPool;
     /// let task_pool = TaskPool::new();
     /// let counts = (0..10000).collect::<Vec<u32>>();
-    /// let incremented = counts.par_chunk_map(&task_pool, 100, |chunk| {
+    /// let incremented = counts.par_chunk_map(&task_pool, 100, |_index, chunk| {
     ///     let mut results = Vec::new();
     ///     for count in chunk {
     ///         results.push(*count + 2);
@@ -32,14 +35,14 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
     /// - [`ParallelSlice::par_splat_map`] for mapping when a specific chunk size is unknown.
     fn par_chunk_map<F, R>(&self, task_pool: &TaskPool, chunk_size: usize, f: F) -> Vec<R>
     where
-        F: Fn(&[T]) -> R + Send + Sync,
+        F: Fn(usize, &[T]) -> R + Send + Sync,
         R: Send + 'static,
     {
         let slice = self.as_ref();
         let f = &f;
         task_pool.scope(|scope| {
-            for chunk in slice.chunks(chunk_size) {
-                scope.spawn(async move { f(chunk) });
+            for (index, chunk) in slice.chunks(chunk_size).enumerate() {
+                scope.spawn(async move { f(index, chunk) });
             }
         })
     }
@@ -50,6 +53,9 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
     /// If `max_tasks` is `None`, this function will attempt to use one chunk per thread in
     /// `task_pool`.
     ///
+    /// The iteration function takes the index of the chunk in the original slice as the
+    /// first argument, and the chunk as the second argument.
+    ///
     /// Returns a `Vec` of the mapped results in the same order as the input.
     ///
     /// # Example
@@ -59,7 +65,7 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
     /// # use bevy_tasks::TaskPool;
     /// let task_pool = TaskPool::new();
     /// let counts = (0..10000).collect::<Vec<u32>>();
-    /// let incremented = counts.par_splat_map(&task_pool, None, |chunk| {
+    /// let incremented = counts.par_splat_map(&task_pool, None, |_index, chunk| {
     ///     let mut results = Vec::new();
     ///     for count in chunk {
     ///         results.push(*count + 2);
@@ -76,7 +82,7 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
     /// [`ParallelSlice::par_chunk_map`] for mapping when a specific chunk size is desirable.
     fn par_splat_map<F, R>(&self, task_pool: &TaskPool, max_tasks: Option<usize>, f: F) -> Vec<R>
     where
-        F: Fn(&[T]) -> R + Send + Sync,
+        F: Fn(usize, &[T]) -> R + Send + Sync,
         R: Send + 'static,
     {
         let slice = self.as_ref();
@@ -100,6 +106,9 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
     /// in parallel across the provided `task_pool`. One task is spawned in the task pool
     /// for every chunk.
     ///
+    /// The iteration function takes the index of the chunk in the original slice as the
+    /// first argument, and the chunk as the second argument.
+    ///
     /// Returns a `Vec` of the mapped results in the same order as the input.
     ///
     /// # Example
@@ -109,7 +118,7 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
     /// # use bevy_tasks::TaskPool;
     /// let task_pool = TaskPool::new();
     /// let mut counts = (0..10000).collect::<Vec<u32>>();
-    /// let incremented = counts.par_chunk_map_mut(&task_pool, 100, |chunk| {
+    /// let incremented = counts.par_chunk_map_mut(&task_pool, 100, |_index, chunk| {
     ///     let mut results = Vec::new();
     ///     for count in chunk {
     ///         *count += 5;
@@ -129,14 +138,14 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
     /// [`ParallelSliceMut::par_splat_map_mut`] for mapping when a specific chunk size is unknown.
     fn par_chunk_map_mut<F, R>(&mut self, task_pool: &TaskPool, chunk_size: usize, f: F) -> Vec<R>
     where
-        F: Fn(&mut [T]) -> R + Send + Sync,
+        F: Fn(usize, &mut [T]) -> R + Send + Sync,
         R: Send + 'static,
     {
         let slice = self.as_mut();
         let f = &f;
         task_pool.scope(|scope| {
-            for chunk in slice.chunks_mut(chunk_size) {
-                scope.spawn(async move { f(chunk) });
+            for (index, chunk) in slice.chunks_mut(chunk_size).enumerate() {
+                scope.spawn(async move { f(index, chunk) });
             }
         })
     }
@@ -147,6 +156,9 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
     /// If `max_tasks` is `None`, this function will attempt to use one chunk per thread in
     /// `task_pool`.
     ///
+    /// The iteration function takes the index of the chunk in the original slice as the
+    /// first argument, and the chunk as the second argument.
+    ///
     /// Returns a `Vec` of the mapped results in the same order as the input.
     ///
     /// # Example
@@ -156,7 +168,7 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
     /// # use bevy_tasks::TaskPool;
     /// let task_pool = TaskPool::new();
     /// let mut counts = (0..10000).collect::<Vec<u32>>();
-    /// let incremented = counts.par_splat_map_mut(&task_pool, None, |chunk| {
+    /// let incremented = counts.par_splat_map_mut(&task_pool, None, |_index, chunk| {
     ///     let mut results = Vec::new();
     ///     for count in chunk {
     ///         *count += 5;
@@ -181,7 +193,7 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
         f: F,
     ) -> Vec<R>
     where
-        F: Fn(&mut [T]) -> R + Send + Sync,
+        F: Fn(usize, &mut [T]) -> R + Send + Sync,
         R: Send + 'static,
     {
         let mut slice = self.as_mut();
@@ -207,7 +219,9 @@ mod tests {
     fn test_par_chunks_map() {
         let v = vec![42; 1000];
         let task_pool = TaskPool::new();
-        let outputs = v.par_splat_map(&task_pool, None, |numbers| -> i32 { numbers.iter().sum() });
+        let outputs = v.par_splat_map(&task_pool, None, |_, numbers| -> i32 {
+            numbers.iter().sum()
+        });
 
         let mut sum = 0;
         for output in outputs {
@@ -222,7 +236,7 @@ mod tests {
         let mut v = vec![42; 1000];
         let task_pool = TaskPool::new();
 
-        let outputs = v.par_splat_map_mut(&task_pool, None, |numbers| -> i32 {
+        let outputs = v.par_splat_map_mut(&task_pool, None, |_, numbers| -> i32 {
             for number in numbers.iter_mut() {
                 *number *= 2;
             }
@@ -237,4 +251,15 @@ mod tests {
         assert_eq!(sum, 1000 * 42 * 2);
         assert_eq!(v[0], 84);
     }
+
+    #[test]
+    fn test_par_chunks_map_index() {
+        let v = vec![1; 1000];
+        let task_pool = TaskPool::new();
+        let outputs = v.par_chunk_map(&task_pool, 100, |index, numbers| -> i32 {
+            numbers.iter().sum::<i32>() * index as i32
+        });
+
+        assert_eq!(outputs.iter().sum::<i32>(), 100 * (9 * 10) / 2);
+    }
 }
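// Usage sketch (not part of the patch above): one way a caller might use the new
// chunk-index argument to recover each element's position in the original slice.
// The `data` binding, chunk size, and offset math below are illustrative
// assumptions, not code taken from this diff.
use bevy_tasks::{ParallelSlice, TaskPool};

fn main() {
    let task_pool = TaskPool::new();
    let data: Vec<u32> = (0..1_000).collect();

    // Each closure call receives `(chunk_index, chunk)`; with a fixed chunk size,
    // `chunk_index * chunk_size` is the offset of the chunk's first element.
    let indexed = data.par_chunk_map(&task_pool, 100, |chunk_index, chunk| {
        let base = chunk_index * 100;
        chunk
            .iter()
            .enumerate()
            .map(|(i, value)| (base + i, *value))
            .collect::<Vec<_>>()
    });

    // Results come back in input order, so chunk 3 starts at element 300.
    assert_eq!(indexed[3][0], (300, 300));
}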