[STF] Thread safe graph_ctx (#3925)
* Introduce a mutex to protect the underlying CUDA graph of a graph context so that we can generate tasks concurrently

* Protect CUDA graphs against concurrent accesses

* Check test results

* Remove dead code

* Add a test with graph capture and threads

* Add and use with_locked_graph, also use weak_ptr

* Big code simplification! No need for shared_ptr/weak_ptr since the mutex will outlive its users

* Use std::reference_wrapper to make graph_task move assignable

* Replace the complicated lambda-based API with a simple lock guard

* Restore a comment removed by mistake

* Capture with the lock held, because we need to ensure the captured stream is not used concurrently

* Fix build

* Comment on why we use a reference_wrapper

* Add a sanity check to ensure we have finished all tasks

* Atomic variables are not automatically initialized, so we set them explicitly

* Add missing mutex headers

* Improve readability

---------

Co-authored-by: Andrei Alexandrescu <[email protected]>
caugonnet and andralex authored Feb 26, 2025
1 parent 497843a commit 7fbbd24
Showing 8 changed files with 244 additions and 17 deletions.
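
The heart of the change is easier to see outside the diff. Below is a minimal, self-contained sketch of the pattern this commit introduces — not the CUDASTF code itself, and with illustrative names (shared_graph, graph_mutex, add_nodes): several threads append nodes to one cudaGraph_t, and a std::lock_guard around every CUDA Graph call on that graph makes concurrent task generation safe.

#include <cuda_runtime.h>

#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

int main()
{
  cudaGraph_t shared_graph;
  cudaGraphCreate(&shared_graph, 0);

  ::std::mutex graph_mutex; // plays the role of graph_ctx's new graph_mutex

  auto add_nodes = [&]() {
    for (int i = 0; i < 100; i++)
    {
      // Concurrent edits of the same cudaGraph_t are not safe, so every
      // node insertion happens with the lock held.
      ::std::lock_guard<::std::mutex> lock(graph_mutex);
      cudaGraphNode_t n;
      cudaGraphAddEmptyNode(&n, shared_graph, nullptr, 0);
    }
  };

  ::std::vector<::std::thread> threads;
  for (int t = 0; t < 4; t++)
  {
    threads.emplace_back(add_nodes);
  }
  for (auto& th : threads)
  {
    th.join();
  }

  size_t num_nodes = 0;
  cudaGraphGetNodes(shared_graph, nullptr, &num_nodes);
  ::std::printf("the shared graph has %zu nodes\n", num_nodes); // expect 400

  cudaGraphDestroy(shared_graph);
}
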
8 changes: 7 additions & 1 deletion cudax/include/cuda/experimental/__stf/graph/graph_ctx.cuh
@@ -34,6 +34,8 @@
#include <cuda/experimental/__stf/internal/parallel_for_scope.cuh>
#include <cuda/experimental/__stf/places/blocked_partition.cuh> // for unit test!

#include <mutex>

namespace cuda::experimental::stf
{

@@ -261,6 +263,9 @@ class graph_ctx : public backend_ctx<graph_ctx>
bool submitted = false; // did we submit ?
mutable bool explicit_graph = false;

// To protect _graph against concurrent modifications
::std::mutex graph_mutex;

executable_graph_cache_stat cache_stats;

/* By default, the finalize operation is blocking, unless user provided
@@ -311,7 +316,8 @@ public:
auto task(exec_place e_place, task_dep<Deps>... deps)
{
auto dump_hooks = reserved::get_dump_hooks(this, deps...);
auto result = graph_task<Deps...>(*this, get_graph(), get_graph_epoch(), mv(e_place), mv(deps)...);
auto result =
graph_task<Deps...>(*this, get_graph(), this->state().graph_mutex, get_graph_epoch(), mv(e_place), mv(deps)...);
result.add_post_submission_hook(dump_hooks);
return result;
}
44 changes: 40 additions & 4 deletions cudax/include/cuda/experimental/__stf/graph/graph_task.cuh
@@ -33,6 +33,8 @@
#include <cuda/experimental/__stf/internal/logical_data.cuh>
#include <cuda/experimental/__stf/internal/void_interface.cuh>

#include <mutex>

namespace cuda::experimental::stf
{

@@ -57,9 +59,14 @@ public:
// A cudaGraph_t is needed
graph_task() = delete;

graph_task(backend_ctx_untyped ctx, cudaGraph_t g, size_t epoch, exec_place e_place = exec_place::current_device())
graph_task(backend_ctx_untyped ctx,
cudaGraph_t g,
::std::mutex& graph_mutex,
size_t epoch,
exec_place e_place = exec_place::current_device())
: task(mv(e_place))
, ctx_graph(EXPECT(g))
, graph_mutex(graph_mutex)
, epoch(epoch)
, ctx(mv(ctx))
{
@@ -75,6 +82,8 @@ public:

graph_task& start()
{
::std::lock_guard<::std::mutex> lock(graph_mutex);

event_list prereqs = acquire(ctx);

// The CUDA graph API does not like duplicate dependencies
@@ -98,6 +107,8 @@ public:
/* End the task, but do not clear its data structures yet */
graph_task<>& end_uncleared()
{
::std::lock_guard<::std::mutex> lock(graph_mutex);

cudaGraphNode_t n;

auto done_prereqs = event_list();
@@ -388,6 +399,11 @@ public:
return ctx_graph;
}

[[nodiscard]] auto lock_ctx_graph()
{
return ::std::unique_lock<::std::mutex>(graph_mutex);
}

void set_current_place(pos4 p)
{
get_exec_place().as_grid().set_current_place(ctx, p);
@@ -428,7 +444,15 @@ private:

/* This is the support graph associated to the entire context */
cudaGraph_t ctx_graph = nullptr;
size_t epoch = 0;

// This protects ctx_graph: it is ok to store a reference because the
// context, and thus this mutex, will outlive any point where the lock is
// needed (and most likely the graph_task object itself).
// Note that we use a reference_wrapper instead of a plain reference to keep
// the graph_task class move assignable.
::std::reference_wrapper<::std::mutex> graph_mutex;

size_t epoch = 0;

::std::vector<cudaGraphNode_t> ready_dependencies;

@@ -448,8 +472,13 @@ template <typename... Deps>
class graph_task : public graph_task<>
{
public:
graph_task(backend_ctx_untyped ctx, cudaGraph_t g, size_t epoch, exec_place e_place, task_dep<Deps>... deps)
: graph_task<>(mv(ctx), g, epoch, mv(e_place))
graph_task(backend_ctx_untyped ctx,
cudaGraph_t g,
::std::mutex& graph_mutex,
size_t epoch,
exec_place e_place,
task_dep<Deps>... deps)
: graph_task<>(mv(ctx), g, graph_mutex, epoch, mv(e_place))
{
static_assert(sizeof(*this) == sizeof(graph_task<>), "Cannot add state - it would be lost by slicing.");
add_deps(mv(deps)...);
@@ -544,6 +573,13 @@ public:
// CAPTURE the lambda
//

// The same CUDA stream must not be used by several threads at once, so we
// make sure only one thread can be capturing at any given time.
//
// TODO: provide a per-thread CUDA stream dedicated to capture on that
// execution place.
auto lock = lock_ctx_graph();

// Get a stream from the pool associated to the execution place
cudaStream_t capture_stream = get_exec_place().getStream(ctx.async_resources(), true).stream;

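
A note on the reference_wrapper comment above: a plain reference member would make graph_task non-move-assignable, because a class with a reference data member gets its assignment operators implicitly deleted. A small standalone illustration — the widget_* types are made up for the example, not CUDASTF code:

#include <functional>
#include <mutex>
#include <type_traits>

struct widget_with_reference
{
  ::std::mutex& m; // a reference member deletes the implicit move assignment
};

struct widget_with_wrapper
{
  // a reference_wrapper can be rebound on assignment, so moves keep working
  ::std::reference_wrapper<::std::mutex> m;
};

static_assert(!::std::is_move_assignable_v<widget_with_reference>, "reference member blocks move assignment");
static_assert(::std::is_move_assignable_v<widget_with_wrapper>, "reference_wrapper keeps the class move assignable");

int main()
{
  ::std::mutex mtx;
  widget_with_wrapper a{mtx}, b{mtx};
  a = ::std::move(b); // fine: the wrapper member is simply reassigned
  ::std::lock_guard<::std::mutex> lock(a.m.get()); // .get() yields the mutex&
}
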
@@ -324,6 +324,10 @@ inline void task::release(backend_ctx_untyped& ctx, event_list& done_prereqs)
// This will, in particular, release shared_ptr to logical data captured in
// the dependencies.
pimpl->post_submission_hooks.clear();

#ifndef NDEBUG
ctx.increment_finished_task_count();
#endif
}

} // namespace cuda::experimental::stf
30 changes: 23 additions & 7 deletions cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh
@@ -217,7 +217,9 @@ public:
if constexpr (::std::is_same_v<Ctx, graph_ctx>)
{
cudaHostNodeParams params = {.fn = callback, .userData = wrapper};

// Put this host node into the child graph that implements the graph_task<>
auto lock = t.lock_ctx_graph();
cuda_safe_call(cudaGraphAddHostNode(&t.get_node(), t.get_ctx_graph(), nullptr, 0, &params));
}
else
@@ -410,19 +412,20 @@ public:

if constexpr (::std::is_same_v<Ctx, graph_ctx>)
{
auto lock = t.lock_ctx_graph();
auto& g = t.get_ctx_graph();

// We have two situations: either there is a single kernel and we put the kernel in the context's
// graph, or we rely on a child graph
if (res.size() == 1)
{
insert_one_kernel(res[0], t.get_node(), t.get_ctx_graph());
insert_one_kernel(res[0], t.get_node(), g);
}
else
{
::std::vector<cudaGraphNode_t>& chain = t.get_node_chain();
chain.resize(res.size());

auto& g = t.get_ctx_graph();

// Create a chain of kernels
for (size_t i = 0; i < res.size(); i++)
{
@@ -455,6 +458,7 @@ public:

if constexpr (::std::is_same_v<Ctx, graph_ctx>)
{
auto lock = t.lock_ctx_graph();
insert_one_kernel(res, t.get_node(), t.get_ctx_graph());
}
else
@@ -561,8 +565,9 @@ protected:

virtual ~impl()
{
// We can't assert here because there may be tasks inside tasks
//_CCCL_ASSERT(total_task_cnt == 0, "You created some tasks but forgot to call finalize().");
#ifndef NDEBUG
_CCCL_ASSERT(total_task_cnt == total_finished_task_cnt, "Not all tasks were finished.");
#endif

if (!is_recording_stats)
{
@@ -708,7 +713,6 @@
{
// assert(!stack.hasCurrentTask());
attached_allocators.clear();
total_task_cnt.store(0);
// Leave custom_allocator, auto_scheduler, and auto_reordered as they were.
}

@@ -732,7 +736,12 @@ protected:
transfers;
bool is_recording_stats = false;
// Keep track of the number of tasks generated in the context
::std::atomic<size_t> total_task_cnt;
::std::atomic<size_t> total_task_cnt = 0;

#ifndef NDEBUG
// Keep track of the number of completed tasks in that context
::std::atomic<size_t> total_finished_task_cnt = 0;
#endif

// This data structure contains all resources useful for an efficient
// asynchronous execution. This will for example contain pools of CUDA
@@ -851,6 +860,13 @@ public:
++pimpl->total_task_cnt;
}

#ifndef NDEBUG
void increment_finished_task_count()
{
++pimpl->total_finished_task_cnt;
}
#endif

size_t task_count() const
{
return pimpl->total_task_cnt;
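
Two of the backend_ctx changes above share one rationale: a default-constructed std::atomic holds an indeterminate value before C++20, so the counters now carry explicit = 0 initializers, and a debug-only finished-task counter is compared against the total in the destructor. A minimal sketch of that bookkeeping pattern, with task_registry as an illustrative stand-in for the CUDASTF class:

#include <atomic>
#include <cassert>
#include <cstddef>

class task_registry
{
public:
  ~task_registry()
  {
#ifndef NDEBUG
    // Every task that was created must also have been released.
    assert(total_task_cnt == total_finished_task_cnt && "Not all tasks were finished.");
#endif
  }

  void on_task_created()
  {
    ++total_task_cnt;
  }

  void on_task_finished()
  {
#ifndef NDEBUG
    ++total_finished_task_cnt;
#endif
  }

private:
  // Explicit initializers: a default-constructed std::atomic is not
  // guaranteed to be zero before C++20.
  ::std::atomic<::std::size_t> total_task_cnt = 0;
#ifndef NDEBUG
  ::std::atomic<::std::size_t> total_finished_task_cnt = 0;
#endif
};

int main()
{
  task_registry reg;
  reg.on_task_created();
  reg.on_task_finished(); // the destructor's sanity check passes in debug builds
}
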
@@ -698,7 +698,8 @@ public:
kernel_params.sharedMemBytes = 0;

// This new node will depend on the previous in the chain (allocation)
cuda_safe_call(cudaGraphAddKernelNode(&t.get_node(), t.get_ctx_graph(), NULL, 0, &kernel_params));
auto lock = t.lock_ctx_graph();
cudaGraphAddKernelNode(&t.get_node(), t.get_ctx_graph(), NULL, 0, &kernel_params);
}

return;
@@ -715,7 +716,8 @@
return ::std::pair(size_t(minGridSize), size_t(blockSize));
}();

const auto [block_size, min_blocks] = conf;
const auto block_size = conf.first;
const auto min_blocks = conf.second;

// max_blocks is computed so we have one thread per element processed
const auto max_blocks = (n + block_size - 1) / block_size;
@@ -759,17 +761,17 @@ public:
}
else
{
auto g = t.get_ctx_graph();

_CCCL_ASSERT(sub_exec_place.is_device(), "Invalid execution place");
int dev_id = device_ordinal(sub_exec_place.affine_data_place());
const int dev_id = device_ordinal(sub_exec_place.affine_data_place());

cudaMemAllocNodeParams allocParams{};
allocParams.poolProps.allocType = cudaMemAllocationTypePinned;
allocParams.poolProps.handleTypes = cudaMemHandleTypeNone;
allocParams.poolProps.location = {.type = cudaMemLocationTypeDevice, .id = dev_id};
allocParams.bytesize = blocks * sizeof(redux_vars<deps_tup_t, ops_and_inits>);

auto lock = t.lock_ctx_graph();
auto g = t.get_ctx_graph();
const auto& input_nodes = t.get_ready_dependencies();

/* This first node depends on task's dependencies themselves */
@@ -898,7 +900,9 @@ public:
// This task corresponds to a single graph node, so we set that
// node instead of creating a child graph. Input and output
// dependencies will be filled later.
auto lock = t.lock_ctx_graph();
cuda_safe_call(cudaGraphAddKernelNode(&t.get_node(), t.get_ctx_graph(), nullptr, 0, &kernel_params));

// fprintf(stderr, "KERNEL NODE => graph %p, gridDim %d blockDim %d (n %ld)\n", t.get_graph(),
// kernel_params.gridDim.x, kernel_params.blockDim.x, n);
}
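
The de-sugared structured binding aside, the hunk above is the usual occupancy-driven launch configuration: ask the runtime for a good block size, then cap the grid at one thread per element with (n + block_size - 1) / block_size. A hedged sketch of that arithmetic, with dummy_kernel as a placeholder rather than any CUDASTF kernel:

#include <cuda_runtime.h>

#include <algorithm>
#include <cstdio>

__global__ void dummy_kernel(size_t /*n*/) {}

int main()
{
  const size_t n = 1 << 20; // number of elements to process

  int min_grid_size = 0; // grid size needed to reach full occupancy
  int block_size    = 0; // suggested threads per block
  cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dummy_kernel, 0, 0);

  // One thread per element, rounded up to whole blocks.
  const size_t max_blocks = (n + block_size - 1) / block_size;

  // Do not launch more blocks than there are elements to cover.
  const size_t grid_size = ::std::min<size_t>(min_grid_size, max_blocks);

  ::std::printf("block_size=%d grid_size=%zu\n", block_size, grid_size);

  dummy_kernel<<<static_cast<unsigned>(grid_size), block_size>>>(n);
  cudaDeviceSynchronize();
}
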
2 changes: 2 additions & 0 deletions cudax/test/stf/CMakeLists.txt
@@ -139,6 +139,8 @@ set(stf_test_codegen_sources
stress/launch_vs_parallelfor.cu
stress/parallel_for_overhead.cu
# threads/axpy-threads-pfor.cu # Currently has a difficult-to-reproduce concurrency problem
threads/axpy-threads-graph.cu
threads/axpy-threads-graph-capture.cu
tools/auto_dump/auto_dump.cu
)

85 changes: 85 additions & 0 deletions cudax/test/stf/threads/axpy-threads-graph-capture.cu
@@ -0,0 +1,85 @@
//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

/**
* @file
*
* @brief Ensure a graph_ctx can be used concurrently
*
*/

#include <cuda/experimental/stf.cuh>

#include <mutex>
#include <thread>

using namespace cuda::experimental::stf;

__global__ void axpy(slice<int> y, slice<const int> x, int a)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int nthreads = gridDim.x * blockDim.x;

for (int i = tid; i < x.size(); i += nthreads)
{
y(i) += a * x(i);
}
}

void mytask(graph_ctx ctx, int /*id*/)
{
const size_t N = 16;

int alpha = 3;

auto lX = ctx.logical_data<int>(N);
auto lY = ctx.logical_data<int>(N);

ctx.parallel_for(lX.shape(), lX.write(), lY.write())->*[] __device__(size_t i, auto x, auto y) {
x(i) = (1 + i);
y(i) = (2 + i * i);
};

/* Compute Y = Y + alpha X */
for (size_t i = 0; i < 200; i++)
{
ctx.task(lY.rw(), lX.read())->*[alpha](cudaStream_t stream, auto dY, auto dX) {
axpy<<<128, 64, 0, stream>>>(dY, dX, alpha);
};
}

ctx.host_launch(lX.read(), lY.read())->*[alpha](auto x, auto y) {
for (size_t i = 0; i < N; i++)
{
EXPECT(x(i) == 1 + i);
EXPECT(y(i) == 2 + i * i + 200 * alpha * x(i));
}
};
}

int main()
{
graph_ctx ctx;

::std::vector<::std::thread> threads;
// Launch threads
for (int i = 0; i < 10; ++i)
{
threads.emplace_back(mytask, ctx, i);
}

// Wait for all threads to complete.
for (auto& th : threads)
{
th.join();
}

ctx.finalize();
}
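
The new test drives stream capture from several threads, which is why the commit takes the graph lock around capture: the capture stream itself must not be shared by two threads capturing at once. A standalone sketch of that idea — serialize each capture and splice the captured graph into a shared graph as a child node — with all names illustrative rather than taken from CUDASTF:

#include <cuda_runtime.h>

#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

__global__ void noop() {}

int main()
{
  cudaGraph_t shared_graph;
  cudaGraphCreate(&shared_graph, 0);

  ::std::mutex graph_mutex;
  cudaStream_t capture_stream;
  cudaStreamCreate(&capture_stream);

  auto worker = [&]() {
    for (int i = 0; i < 10; i++)
    {
      // Capture and splice with the lock held: neither the capture stream
      // nor the shared graph may be touched by two threads at once.
      ::std::lock_guard<::std::mutex> lock(graph_mutex);

      cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeThreadLocal);
      noop<<<1, 1, 0, capture_stream>>>();
      cudaGraph_t captured;
      cudaStreamEndCapture(capture_stream, &captured);

      cudaGraphNode_t child;
      cudaGraphAddChildGraphNode(&child, shared_graph, nullptr, 0, captured);
      cudaGraphDestroy(captured); // the child node keeps its own copy
    }
  };

  ::std::vector<::std::thread> threads;
  for (int t = 0; t < 4; t++)
  {
    threads.emplace_back(worker);
  }
  for (auto& th : threads)
  {
    th.join();
  }

  size_t num_nodes = 0;
  cudaGraphGetNodes(shared_graph, nullptr, &num_nodes);
  ::std::printf("shared graph has %zu child-graph nodes\n", num_nodes); // expect 40

  cudaGraphDestroy(shared_graph);
  cudaStreamDestroy(capture_stream);
}
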
