Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to temporarily set the current registry even if it is not associated with a worker thread #1166

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 67 additions & 4 deletions rayon-core/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ static THE_REGISTRY_SET: Once = Once::new();
/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default
/// configuration.
pub(super) fn global_registry() -> &'static Arc<Registry> {
fn global_registry() -> &'static Arc<Registry> {
set_global_registry(default_global_registry)
.or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) })
.expect("The global thread pool has not been initialized.")
Expand Down Expand Up @@ -217,6 +217,36 @@ fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
result
}

// This is used to temporarily overwrite the current registry.
//
// This is either null, a pointer to the global registry if it was
// ever used to access the global registry, or a pointer to a
// registry which is temporarily made current because the current
// thread is not a worker thread but is running a scope associated
// with a specific thread pool.
thread_local! {
    static CURRENT_REGISTRY: Cell<*const Arc<Registry>> = const { Cell::new(ptr::null()) };
}

/// Cold-path fallback for [`current_registry`]: resolves the global
/// registry (initializing it if necessary), caches a pointer to it in
/// this thread's `CURRENT_REGISTRY` slot, and returns that pointer so
/// subsequent look-ups on this thread take the fast path.
#[cold]
fn set_current_registry_to_global_registry() -> *const Arc<Registry> {
    let global: *const Arc<Registry> = global_registry();

    CURRENT_REGISTRY.with(|slot| slot.set(global));

    global
}

/// Returns the registry pointer considered current for this thread.
///
/// The pointer is never null on return: an empty (null) thread-local
/// slot is lazily filled with a pointer to the global registry.
pub(super) fn current_registry() -> *const Arc<Registry> {
    let cached = CURRENT_REGISTRY.with(Cell::get);

    if cached.is_null() {
        set_current_registry_to_global_registry()
    } else {
        cached
    }
}

struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
Expand Down Expand Up @@ -315,22 +345,55 @@ impl Registry {
unsafe {
let worker_thread = WorkerThread::current();
let registry = if worker_thread.is_null() {
global_registry()
&*current_registry()
} else {
&(*worker_thread).registry
};
Arc::clone(registry)
}
}

/// Optionally install a specific registry as the current one.
///
/// This is used when a thread which is not a worker executes
/// a scope which should use the specific thread pool instead of
/// the global one.
pub(super) fn with_current<F, R>(registry: Option<&Arc<Registry>>, f: F) -> R
where
F: FnOnce() -> R,
{
struct Guard {
current: *const Arc<Registry>,
}

impl Guard {
fn new(registry: &Arc<Registry>) -> Self {
let current =
CURRENT_REGISTRY.with(|current_registry| current_registry.replace(registry));

Self { current }
}
}

impl Drop for Guard {
fn drop(&mut self) {
CURRENT_REGISTRY.with(|current_registry| current_registry.set(self.current));
}
}

let _guard = registry.map(Guard::new);

f()
}

/// Returns the number of threads in the current registry. This
/// is better than `Registry::current().num_threads()` because it
/// avoids incrementing the `Arc`.
pub(super) fn current_num_threads() -> usize {
unsafe {
let worker_thread = WorkerThread::current();
if worker_thread.is_null() {
global_registry().num_threads()
(*current_registry()).num_threads()
} else {
(*worker_thread).registry.num_threads()
}
Expand Down Expand Up @@ -946,7 +1009,7 @@ where
// invalidated until we return.
op(&*owner_thread, false)
} else {
global_registry().in_worker(op)
(*current_registry()).in_worker(op)
}
}
}
Expand Down
24 changes: 16 additions & 8 deletions rayon-core/src/scope/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
use crate::broadcast::BroadcastContext;
use crate::job::{ArcJob, HeapJob, JobFifo, JobRef};
use crate::latch::{CountLatch, Latch};
use crate::registry::{global_registry, in_worker, Registry, WorkerThread};
use crate::registry::{current_registry, in_worker, Registry, WorkerThread};
use crate::unwind;
use std::any::Any;
use std::fmt;
Expand Down Expand Up @@ -416,9 +416,11 @@ pub(crate) fn do_in_place_scope<'scope, OP, R>(registry: Option<&Arc<Registry>>,
where
OP: FnOnce(&Scope<'scope>) -> R,
{
let thread = unsafe { WorkerThread::current().as_ref() };
let scope = Scope::<'scope>::new(thread, registry);
scope.base.complete(thread, || op(&scope))
Registry::with_current(registry, || {
let thread = unsafe { WorkerThread::current().as_ref() };
let scope = Scope::<'scope>::new(thread, registry);
scope.base.complete(thread, || op(&scope))
})
}

/// Creates a "fork-join" scope `s` with FIFO order, and invokes the
Expand Down Expand Up @@ -453,9 +455,11 @@ pub(crate) fn do_in_place_scope_fifo<'scope, OP, R>(registry: Option<&Arc<Regist
where
OP: FnOnce(&ScopeFifo<'scope>) -> R,
{
let thread = unsafe { WorkerThread::current().as_ref() };
let scope = ScopeFifo::<'scope>::new(thread, registry);
scope.base.complete(thread, || op(&scope))
Registry::with_current(registry, || {
let thread = unsafe { WorkerThread::current().as_ref() };
let scope = ScopeFifo::<'scope>::new(thread, registry);
scope.base.complete(thread, || op(&scope))
})
}

impl<'scope> Scope<'scope> {
Expand Down Expand Up @@ -625,7 +629,11 @@ impl<'scope> ScopeBase<'scope> {
fn new(owner: Option<&WorkerThread>, registry: Option<&Arc<Registry>>) -> Self {
let registry = registry.unwrap_or_else(|| match owner {
Some(owner) => owner.registry(),
None => global_registry(),
// SAFETY: `current_registry` will either return a pointer to
// the global registry which has a 'static lifetime or
// to temporary one kept alive by `with_current`.
// In both cases we can safely dereference it here to clone the `Arc`.
None => unsafe { &*current_registry() },
});

ScopeBase {
Expand Down
15 changes: 15 additions & 0 deletions rayon-core/src/thread_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,21 @@ impl ThreadPool {
unsafe { broadcast::broadcast_in(op, &self.registry) }
}

/// TODO
pub fn current() -> Self {
Self {
registry: Registry::current(),
}
}

/// TODO
pub fn with_current<F, R>(&self, f: F) -> R
where
F: FnOnce() -> R,
{
Registry::with_current(Some(&self.registry), f)
}

/// Returns the (current) number of threads in the thread pool.
///
/// # Future compatibility note
Expand Down
93 changes: 93 additions & 0 deletions rayon-core/src/thread_pool/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::channel;
use std::sync::{Arc, Mutex};
use std::thread;

use crate::{join, Scope, ScopeFifo, ThreadPool, ThreadPoolBuilder};

Expand Down Expand Up @@ -381,6 +382,98 @@ fn in_place_scope_fifo_no_deadlock() {
});
}

#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn in_place_scope_which_pool() {
    // Name of the thread currently running this call, owned.
    fn observed_name() -> Option<String> {
        thread::current().name().map(ToOwned::to_owned)
    }

    let pool = ThreadPoolBuilder::new()
        .num_threads(1)
        .thread_name(|_| "worker".to_owned())
        .build()
        .unwrap();

    // Determine which pool is currently installed here
    // by checking the thread name seen by spawned work items.
    pool.in_place_scope(|scope| {
        let (send, recv) = channel();
        scope.spawn(move |_| send.send(observed_name()).unwrap());
        assert_eq!(recv.recv().unwrap().as_deref(), Some("worker"));

        // A free-standing spawn inside the scope must also land on
        // this pool, not the global one.
        let (send, recv) = channel();
        crate::spawn(move || send.send(observed_name()).unwrap());
        assert_eq!(recv.recv().unwrap().as_deref(), Some("worker"));

        // Both halves of a join should run on the pool's worker too.
        let (lhs_name, rhs_name) = crate::join(observed_name, observed_name);
        assert_eq!(lhs_name.as_deref(), Some("worker"));
        assert_eq!(rhs_name.as_deref(), Some("worker"));
    });
}

#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn in_place_scope_fifo_which_pool() {
    // Name of the thread currently running this call, owned.
    fn observed_name() -> Option<String> {
        thread::current().name().map(ToOwned::to_owned)
    }

    let pool = ThreadPoolBuilder::new()
        .num_threads(1)
        .thread_name(|_| "worker".to_owned())
        .build()
        .unwrap();

    // Determine which pool is currently installed here
    // by checking the thread name seen by spawned work items.
    pool.in_place_scope_fifo(|scope| {
        let (send, recv) = channel();
        scope.spawn_fifo(move |_| send.send(observed_name()).unwrap());
        assert_eq!(recv.recv().unwrap().as_deref(), Some("worker"));

        // A free-standing FIFO spawn inside the scope must also land
        // on this pool, not the global one.
        let (send, recv) = channel();
        crate::spawn_fifo(move || send.send(observed_name()).unwrap());
        assert_eq!(recv.recv().unwrap().as_deref(), Some("worker"));

        // Both halves of a join should run on the pool's worker too.
        let (lhs_name, rhs_name) = crate::join(observed_name, observed_name);
        assert_eq!(lhs_name.as_deref(), Some("worker"));
        assert_eq!(rhs_name.as_deref(), Some("worker"));
    });
}

#[test]
fn yield_now_to_spawn() {
let (tx, rx) = channel();
Expand Down
32 changes: 32 additions & 0 deletions tests/with_current.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use rayon::prelude::*;
use rayon::{ThreadPool, ThreadPoolBuilder};
use std::thread;

#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn with_current() {
    let items = vec![(); 128];

    // Capture a handle to the default (global) pool before any other
    // pool is made current.
    let default_pool = ThreadPool::current();

    let high_priority_pool = ThreadPoolBuilder::new()
        .thread_name(|_| "high-priority-thread".to_owned())
        .build()
        .unwrap();

    high_priority_pool.in_place_scope(|scope| {
        // Scope-spawned work runs on the named high-priority pool.
        scope.spawn(|_| {
            assert_eq!(thread::current().name(), Some("high-priority-thread"));
        });

        // Even inside the high-priority scope, explicitly selecting the
        // default pool routes parallel iteration to its unnamed workers.
        default_pool.with_current(|| {
            items.par_iter().for_each(|_item| {
                assert_eq!(thread::current().name(), None);
            });
        })
    });
}