diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs index d30f815bd..8ecb4fbab 100644 --- a/rayon-core/src/registry.rs +++ b/rayon-core/src/registry.rs @@ -161,7 +161,7 @@ static THE_REGISTRY_SET: Once = Once::new(); /// Starts the worker threads (if that has not already happened). If /// initialization has not already occurred, use the default /// configuration. -pub(super) fn global_registry() -> &'static Arc<Registry> { +fn global_registry() -> &'static Arc<Registry> { set_global_registry(default_global_registry) .or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) }) .expect("The global thread pool has not been initialized.") @@ -217,6 +217,36 @@ fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> { result } +// This is used to temporarily overwrite the current registry. +// +// This is either null, a pointer to the global registry if it was +// ever used to access the global registry, or a pointer to a +// registry which is temporarily made current because the current +// thread is not a worker thread but is running a scope associated +// with a specific thread pool. +thread_local! { + static CURRENT_REGISTRY: Cell<*const Arc<Registry>> = const { Cell::new(ptr::null()) }; +} + +#[cold] +fn set_current_registry_to_global_registry() -> *const Arc<Registry> { + let global = global_registry(); + + CURRENT_REGISTRY.with(|current_registry| current_registry.set(global)); + + global +} + +pub(super) fn current_registry() -> *const Arc<Registry> { + let mut current = CURRENT_REGISTRY.with(Cell::get); + + if current.is_null() { + current = set_current_registry_to_global_registry(); + } + + current +} + struct Terminator<'a>(&'a Arc<Registry>); impl<'a> Drop for Terminator<'a> { @@ -315,7 +345,7 @@ impl Registry { unsafe { let worker_thread = WorkerThread::current(); let registry = if worker_thread.is_null() { - global_registry() + &*current_registry() } else { &(*worker_thread).registry }; @@ -323,6 +353,39 @@ impl Registry { } } + /// Optionally install a specific registry as the current one. 
+ /// + /// This is used when a thread which is not a worker executes + /// a scope which should use a specific thread pool instead of + /// the global one. + pub(super) fn with_current<F, R>(registry: Option<&Arc<Registry>>, f: F) -> R + where + F: FnOnce() -> R, + { + struct Guard { + current: *const Arc<Registry>, + } + + impl Guard { + fn new(registry: &Arc<Registry>) -> Self { + let current = + CURRENT_REGISTRY.with(|current_registry| current_registry.replace(registry)); + + Self { current } + } + } + + impl Drop for Guard { + fn drop(&mut self) { + CURRENT_REGISTRY.with(|current_registry| current_registry.set(self.current)); + } + } + + let _guard = registry.map(Guard::new); + + f() + } + /// Returns the number of threads in the current registry. This /// is better than `Registry::current().num_threads()` because it /// avoids incrementing the `Arc`. @@ -330,7 +393,7 @@ impl Registry { unsafe { let worker_thread = WorkerThread::current(); if worker_thread.is_null() { - global_registry().num_threads() + (*current_registry()).num_threads() } else { (*worker_thread).registry.num_threads() } @@ -946,7 +1009,7 @@ where // invalidated until we return. 
op(&*owner_thread, false) } else { - global_registry().in_worker(op) + (*current_registry()).in_worker(op) } } } diff --git a/rayon-core/src/scope/mod.rs b/rayon-core/src/scope/mod.rs index 1d8732fea..8dd0234be 100644 --- a/rayon-core/src/scope/mod.rs +++ b/rayon-core/src/scope/mod.rs @@ -8,7 +8,7 @@ use crate::broadcast::BroadcastContext; use crate::job::{ArcJob, HeapJob, JobFifo, JobRef}; use crate::latch::{CountLatch, Latch}; -use crate::registry::{global_registry, in_worker, Registry, WorkerThread}; +use crate::registry::{current_registry, in_worker, Registry, WorkerThread}; use crate::unwind; use std::any::Any; use std::fmt; @@ -416,9 +416,11 @@ pub(crate) fn do_in_place_scope<'scope, OP, R>(registry: Option<&Arc<Registry>>, op: OP) -> R where OP: FnOnce(&Scope<'scope>) -> R, { - let thread = unsafe { WorkerThread::current().as_ref() }; - let scope = Scope::<'scope>::new(thread, registry); - scope.base.complete(thread, || op(&scope)) + Registry::with_current(registry, || { + let thread = unsafe { WorkerThread::current().as_ref() }; + let scope = Scope::<'scope>::new(thread, registry); + scope.base.complete(thread, || op(&scope)) + }) } /// Creates a "fork-join" scope `s` with FIFO order, and invokes the @@ -453,9 +455,11 @@ pub(crate) fn do_in_place_scope_fifo<'scope, OP, R>(registry: Option<&Arc<Registry>>, op: OP) -> R, where OP: FnOnce(&ScopeFifo<'scope>) -> R, { - let thread = unsafe { WorkerThread::current().as_ref() }; - let scope = ScopeFifo::<'scope>::new(thread, registry); - scope.base.complete(thread, || op(&scope)) + Registry::with_current(registry, || { + let thread = unsafe { WorkerThread::current().as_ref() }; + let scope = ScopeFifo::<'scope>::new(thread, registry); + scope.base.complete(thread, || op(&scope)) + }) } impl<'scope> Scope<'scope> { @@ -625,7 +629,11 @@ impl<'scope> ScopeBase<'scope> { fn new(owner: Option<&WorkerThread>, registry: Option<&Arc<Registry>>) -> Self { let registry = registry.unwrap_or_else(|| match owner { Some(owner) => owner.registry(), - None => global_registry(), + // SAFETY: `current_registry` will either 
return a pointer to + // the global registry which has a 'static lifetime or + // to a temporary one kept alive by `with_current`. + // In both cases we can safely dereference it here to clone the `Arc`. + None => unsafe { &*current_registry() }, }); ScopeBase { diff --git a/rayon-core/src/thread_pool/mod.rs b/rayon-core/src/thread_pool/mod.rs index 5ae6e0f60..bcdf78856 100644 --- a/rayon-core/src/thread_pool/mod.rs +++ b/rayon-core/src/thread_pool/mod.rs @@ -198,6 +198,21 @@ impl ThreadPool { unsafe { broadcast::broadcast_in(op, &self.registry) } } + /// Returns a handle to the thread pool which the calling thread currently uses: the pool of the current worker thread if there is one, otherwise the pool made current via `with_current`, falling back to the global thread pool. + pub fn current() -> Self { + Self { + registry: Registry::current(), + } + } + + /// Runs `f` with this thread pool temporarily installed as the current one on the calling thread, so that implicitly pool-targeted operations inside `f` use this pool instead of the global one. + pub fn with_current<F, R>(&self, f: F) -> R + where + F: FnOnce() -> R, + { + Registry::with_current(Some(&self.registry), f) + } + /// Returns the (current) number of threads in the thread pool. /// /// # Future compatibility note diff --git a/rayon-core/src/thread_pool/test.rs b/rayon-core/src/thread_pool/test.rs index 88b36282d..aeaa44b8e 100644 --- a/rayon-core/src/thread_pool/test.rs +++ b/rayon-core/src/thread_pool/test.rs @@ -3,6 +3,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc::channel; use std::sync::{Arc, Mutex}; +use std::thread; use crate::{join, Scope, ScopeFifo, ThreadPool, ThreadPoolBuilder}; @@ -381,6 +382,98 @@ fn in_place_scope_fifo_no_deadlock() { }); } +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn in_place_scope_which_pool() { + let pool = ThreadPoolBuilder::new() + .num_threads(1) + .thread_name(|_| "worker".to_owned()) + .build() + .unwrap(); + + // Determine which pool is currently installed here + // by checking the thread name seen by spawned work items. 
+ pool.in_place_scope(|scope| { + let (name_send, name_recv) = channel(); + + scope.spawn(move |_| { + let name = thread::current().name().map(ToOwned::to_owned); + + name_send.send(name).unwrap(); + }); + + let name = name_recv.recv().unwrap(); + + assert_eq!(name.as_deref(), Some("worker")); + + let (name_send, name_recv) = channel(); + + crate::spawn(move || { + let name = thread::current().name().map(ToOwned::to_owned); + + name_send.send(name).unwrap(); + }); + + let name = name_recv.recv().unwrap(); + + assert_eq!(name.as_deref(), Some("worker")); + + let (lhs_name, rhs_name) = crate::join( + || thread::current().name().map(ToOwned::to_owned), + || thread::current().name().map(ToOwned::to_owned), + ); + + assert_eq!(lhs_name.as_deref(), Some("worker")); + assert_eq!(rhs_name.as_deref(), Some("worker")); + }); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn in_place_scope_fifo_which_pool() { + let pool = ThreadPoolBuilder::new() + .num_threads(1) + .thread_name(|_| "worker".to_owned()) + .build() + .unwrap(); + + // Determine which pool is currently installed here + // by checking the thread name seen by spawned work items. 
+ pool.in_place_scope_fifo(|scope| { + let (name_send, name_recv) = channel(); + + scope.spawn_fifo(move |_| { + let name = thread::current().name().map(ToOwned::to_owned); + + name_send.send(name).unwrap(); + }); + + let name = name_recv.recv().unwrap(); + + assert_eq!(name.as_deref(), Some("worker")); + + let (name_send, name_recv) = channel(); + + crate::spawn_fifo(move || { + let name = thread::current().name().map(ToOwned::to_owned); + + name_send.send(name).unwrap(); + }); + + let name = name_recv.recv().unwrap(); + + assert_eq!(name.as_deref(), Some("worker")); + + let (lhs_name, rhs_name) = crate::join( + || thread::current().name().map(ToOwned::to_owned), + || thread::current().name().map(ToOwned::to_owned), + ); + + assert_eq!(lhs_name.as_deref(), Some("worker")); + assert_eq!(rhs_name.as_deref(), Some("worker")); + }); +} + #[test] fn yield_now_to_spawn() { let (tx, rx) = channel(); diff --git a/tests/with_current.rs b/tests/with_current.rs new file mode 100644 index 000000000..92925dd87 --- /dev/null +++ b/tests/with_current.rs @@ -0,0 +1,32 @@ +use rayon::prelude::*; +use rayon::{ThreadPool, ThreadPoolBuilder}; +use std::thread; + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn with_current() { + fn high_priority_work() { + assert_eq!(thread::current().name(), Some("high-priority-thread")); + } + + fn regular_work(_item: &()) { + assert_eq!(thread::current().name(), None); + } + + let items = vec![(); 128]; + + let default_pool = ThreadPool::current(); + + let high_priority_pool = ThreadPoolBuilder::new() + .thread_name(|_| "high-priority-thread".to_owned()) + .build() + .unwrap(); + + high_priority_pool.in_place_scope(|scope| { + scope.spawn(|_| high_priority_work()); + + default_pool.with_current(|| { + items.par_iter().for_each(|item| regular_work(item)); + }) + }); +}