Skip to content

Commit 64df01d

Browse files
committed
RFC: Add a hook for detecting task switches.
Certain libraries are configured using global or thread-local state instead of passing handles to every function. CUDA, for example, has a `cudaSetDevice` function that binds a device to the current thread for all future API calls. This is at odds with Julia's task-based concurrency, which presents an execution environment that's local to the current task (e.g., in the case of CUDA, using a different device). This PR adds a hook mechanism that can be used to detect task switches, and synchronize Julia's task-local environment with the library's global or thread-local state.
1 parent d234931 commit 64df01d

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

src/task.c

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,13 @@ static jl_ptls_t NOINLINE refetch_ptls(void)
480480
return jl_get_ptls_states();
481481
}
482482

483+
typedef void *(*jl_task_switch_hook_t)(jl_task_t *t JL_PROPAGATES_ROOT);
484+
jl_task_switch_hook_t task_switch_hook = NULL;
485+
JL_DLLEXPORT void jl_hook_task_switch(jl_task_switch_hook_t hook)
486+
{
487+
task_switch_hook = hook;
488+
}
489+
483490
JL_DLLEXPORT void jl_switch(void)
484491
{
485492
jl_ptls_t ptls = jl_get_ptls_states();
@@ -497,7 +504,7 @@ JL_DLLEXPORT void jl_switch(void)
497504
if (ptls->in_finalizer)
498505
jl_error("task switch not allowed from inside gc finalizer");
499506
if (ptls->in_pure_callback)
500-
jl_error("task switch not allowed from inside staged nor pure functions");
507+
jl_error("task switch not allowed from inside staged nor pure functions or callbacks");
501508
if (t->sticky && jl_atomic_load_acquire(&t->tid) == -1) {
502509
// manually yielding to a task
503510
if (jl_atomic_compare_exchange(&t->tid, -1, ptls->tid) != -1)
@@ -507,6 +514,13 @@ JL_DLLEXPORT void jl_switch(void)
507514
jl_error("cannot switch to task running on another thread");
508515
}
509516

517+
if (task_switch_hook) {
518+
int last_in = ptls->in_pure_callback;
519+
ptls->in_pure_callback = 1;
520+
task_switch_hook(t);
521+
ptls->in_pure_callback = last_in;
522+
}
523+
510524
// Store old values on the stack and reset
511525
sig_atomic_t defer_signal = ptls->defer_signal;
512526
int8_t gc_state = jl_gc_unsafe_enter(ptls);

0 commit comments

Comments
 (0)