Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an index argument to parallel iteration helpers in bevy_tasks #12169

Merged
merged 2 commits into from
Feb 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions crates/bevy_tasks/src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
/// in parallel across the provided `task_pool`. One task is spawned in the task pool
/// for every chunk.
///
/// The iteration function takes the index of the chunk in the original slice as the
/// first argument, and the chunk as the second argument.
///
/// Returns a `Vec` of the mapped results in the same order as the input.
///
/// # Example
Expand All @@ -15,7 +18,7 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
/// # use bevy_tasks::TaskPool;
/// let task_pool = TaskPool::new();
/// let counts = (0..10000).collect::<Vec<u32>>();
/// let incremented = counts.par_chunk_map(&task_pool, 100, |chunk| {
/// let incremented = counts.par_chunk_map(&task_pool, 100, |_index, chunk| {
/// let mut results = Vec::new();
/// for count in chunk {
/// results.push(*count + 2);
Expand All @@ -32,14 +35,14 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
/// - [`ParallelSlice::par_splat_map`] for mapping when a specific chunk size is unknown.
fn par_chunk_map<F, R>(&self, task_pool: &TaskPool, chunk_size: usize, f: F) -> Vec<R>
where
F: Fn(&[T]) -> R + Send + Sync,
F: Fn(usize, &[T]) -> R + Send + Sync,
R: Send + 'static,
{
let slice = self.as_ref();
let f = &f;
task_pool.scope(|scope| {
for chunk in slice.chunks(chunk_size) {
scope.spawn(async move { f(chunk) });
for (index, chunk) in slice.chunks(chunk_size).enumerate() {
scope.spawn(async move { f(index, chunk) });
}
})
}
Expand All @@ -50,6 +53,9 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
/// If `max_tasks` is `None`, this function will attempt to use one chunk per thread in
/// `task_pool`.
///
/// The iteration function takes the index of the chunk in the original slice as the
/// first argument, and the chunk as the second argument.
///
/// Returns a `Vec` of the mapped results in the same order as the input.
///
/// # Example
Expand All @@ -59,7 +65,7 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
/// # use bevy_tasks::TaskPool;
/// let task_pool = TaskPool::new();
/// let counts = (0..10000).collect::<Vec<u32>>();
/// let incremented = counts.par_splat_map(&task_pool, None, |chunk| {
/// let incremented = counts.par_splat_map(&task_pool, None, |_index, chunk| {
/// let mut results = Vec::new();
/// for count in chunk {
/// results.push(*count + 2);
Expand All @@ -76,7 +82,7 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
/// [`ParallelSlice::par_chunk_map`] for mapping when a specific chunk size is desirable.
fn par_splat_map<F, R>(&self, task_pool: &TaskPool, max_tasks: Option<usize>, f: F) -> Vec<R>
where
F: Fn(&[T]) -> R + Send + Sync,
F: Fn(usize, &[T]) -> R + Send + Sync,
R: Send + 'static,
{
let slice = self.as_ref();
Expand All @@ -100,6 +106,9 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
/// in parallel across the provided `task_pool`. One task is spawned in the task pool
/// for every chunk.
///
/// The iteration function takes the index of the chunk in the original slice as the
/// first argument, and the chunk as the second argument.
///
/// Returns a `Vec` of the mapped results in the same order as the input.
///
/// # Example
Expand All @@ -109,7 +118,7 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
/// # use bevy_tasks::TaskPool;
/// let task_pool = TaskPool::new();
/// let mut counts = (0..10000).collect::<Vec<u32>>();
/// let incremented = counts.par_chunk_map_mut(&task_pool, 100, |chunk| {
/// let incremented = counts.par_chunk_map_mut(&task_pool, 100, |_index, chunk| {
/// let mut results = Vec::new();
/// for count in chunk {
/// *count += 5;
Expand All @@ -129,14 +138,14 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
/// [`ParallelSliceMut::par_splat_map_mut`] for mapping when a specific chunk size is unknown.
fn par_chunk_map_mut<F, R>(&mut self, task_pool: &TaskPool, chunk_size: usize, f: F) -> Vec<R>
where
F: Fn(&mut [T]) -> R + Send + Sync,
F: Fn(usize, &mut [T]) -> R + Send + Sync,
R: Send + 'static,
{
let slice = self.as_mut();
let f = &f;
task_pool.scope(|scope| {
for chunk in slice.chunks_mut(chunk_size) {
scope.spawn(async move { f(chunk) });
for (index, chunk) in slice.chunks_mut(chunk_size).enumerate() {
scope.spawn(async move { f(index, chunk) });
}
})
}
Expand All @@ -147,6 +156,9 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
/// If `max_tasks` is `None`, this function will attempt to use one chunk per thread in
/// `task_pool`.
///
/// The iteration function takes the index of the chunk in the original slice as the
/// first argument, and the chunk as the second argument.
///
/// Returns a `Vec` of the mapped results in the same order as the input.
///
/// # Example
Expand All @@ -156,7 +168,7 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
/// # use bevy_tasks::TaskPool;
/// let task_pool = TaskPool::new();
/// let mut counts = (0..10000).collect::<Vec<u32>>();
/// let incremented = counts.par_splat_map_mut(&task_pool, None, |chunk| {
/// let incremented = counts.par_splat_map_mut(&task_pool, None, |_index, chunk| {
/// let mut results = Vec::new();
/// for count in chunk {
/// *count += 5;
Expand All @@ -181,7 +193,7 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
f: F,
) -> Vec<R>
where
F: Fn(&mut [T]) -> R + Send + Sync,
F: Fn(usize, &mut [T]) -> R + Send + Sync,
R: Send + 'static,
{
let mut slice = self.as_mut();
Expand All @@ -207,7 +219,9 @@ mod tests {
fn test_par_chunks_map() {
let v = vec![42; 1000];
let task_pool = TaskPool::new();
let outputs = v.par_splat_map(&task_pool, None, |numbers| -> i32 { numbers.iter().sum() });
let outputs = v.par_splat_map(&task_pool, None, |_, numbers| -> i32 {
numbers.iter().sum()
});

let mut sum = 0;
for output in outputs {
Expand All @@ -222,7 +236,7 @@ mod tests {
let mut v = vec![42; 1000];
let task_pool = TaskPool::new();

let outputs = v.par_splat_map_mut(&task_pool, None, |numbers| -> i32 {
let outputs = v.par_splat_map_mut(&task_pool, None, |_, numbers| -> i32 {
for number in numbers.iter_mut() {
*number *= 2;
}
Expand All @@ -237,4 +251,15 @@ mod tests {
assert_eq!(sum, 1000 * 42 * 2);
assert_eq!(v[0], 84);
}

#[test]
fn test_par_chunks_map_index() {
let v = vec![1; 1000];
let task_pool = TaskPool::new();
let outputs = v.par_chunk_map(&task_pool, 100, |index, numbers| -> i32 {
numbers.iter().sum::<i32>() * index as i32
});

assert_eq!(outputs.iter().sum::<i32>(), 100 * (9 * 10) / 2);
}
}