[STF] Implement kernel chains in the graph backend without child graphs (#3707)

* Reimplement chain of CUDA kernels in the CUDA graph backend to avoid child graphs

* simplify code

* Revert compilation environment changes that should not be committed

* remove unused var

* minor code improvements

* Update cudax/include/cuda/experimental/__stf/graph/graph_task.cuh

Cleaner code

Co-authored-by: Bernhard Manfred Gruber <[email protected]>

---------

Co-authored-by: Bernhard Manfred Gruber <[email protected]>
caugonnet and bernhardmgruber authored Feb 10, 2025
1 parent d19c9a2 commit f745c97
Showing 2 changed files with 35 additions and 17 deletions.
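For context: the previous implementation collected each task's CUDA kernels in a private graph and attached it to the context graph as a single child-graph node; this commit instead adds the kernel nodes directly to the context graph and chains them with explicit dependencies (sketched after each file below). As a rough sketch of the child-graph pattern being avoided — illustrative names only, error checking omitted, not code from the repository:

#include <cuda_runtime.h>

// Attach a task's private graph to the context graph as one opaque child node.
// `deps` are the nodes the task must wait for before it can start.
cudaGraphNode_t add_task_as_child_graph(
  cudaGraph_t ctx_graph, cudaGraph_t task_graph, const cudaGraphNode_t* deps, size_t num_deps)
{
  cudaGraphNode_t child = nullptr;
  cudaGraphAddChildGraphNode(&child, ctx_graph, deps, num_deps, task_graph);
  return child;
}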
36 changes: 30 additions & 6 deletions cudax/include/cuda/experimental/__stf/graph/graph_task.cuh
@@ -116,9 +116,9 @@ public:
     }
     else
     {
-      // We either created independent task nodes, or a child graph. We need
-      // to inject input dependencies, and make the task completion depend on
-      // task nodes or the child graph.
+      // We either created independent task nodes, a chain of tasks, or a child
+      // graph. We need to inject input dependencies, and make the task
+      // completion depend on task nodes, task chain, or the child graph.
       if (task_nodes.size() > 0)
       {
         for (auto& node : task_nodes)
@@ -146,6 +146,18 @@ public:
           done_prereqs.add(mv(gnp));
         }
       }
+      else if (chained_task_nodes.size() > 0)
+      {
+        // First node depends on ready_dependencies
+        ::std::vector<cudaGraphNode_t> out_array(ready_dependencies.size(), chained_task_nodes[0]);
+        cuda_safe_call(
+          cudaGraphAddDependencies(ctx_graph, ready_dependencies.data(), out_array.data(), ready_dependencies.size()));
+
+        // Overall the task depends on the completion of the last node
+        auto gnp = reserved::graph_event(chained_task_nodes.back(), epoch);
+        gnp->set_symbol(ctx, "done " + get_symbol());
+        done_prereqs.add(mv(gnp));
+      }
       else
       {
         // Note that if nothing was done in the task, this will create a child
@@ -327,7 +339,8 @@ public:
   cudaGraph_t& get_graph()
   {
     // We either use a child graph or task nodes, not both
-    assert(task_nodes.empty());
+    _CCCL_ASSERT(task_nodes.empty(), "cannot use both get_graph() and get_node()");
+    _CCCL_ASSERT(chained_task_nodes.empty(), "cannot use both get_graph() and get_node_chain()");
 
     // Lazy creation
     if (child_graph == nullptr)
@@ -342,14 +355,23 @@ public:
   // Create a node in the graph
   cudaGraphNode_t& get_node()
   {
-    // We either use a child graph or task nodes, not both
-    assert(!child_graph);
+    _CCCL_ASSERT(!child_graph, "cannot use both get_node() and get_graph()");
+    _CCCL_ASSERT(chained_task_nodes.empty(), "cannot use both get_node() and get_node_chain()");
 
     // Create a new entry and return it
     task_nodes.emplace_back();
     return task_nodes.back();
   }
 
+  // Create a node in the graph
+  ::std::vector<cudaGraphNode_t>& get_node_chain()
+  {
+    _CCCL_ASSERT(!child_graph, "cannot use both get_node_chain() and get_graph()");
+    _CCCL_ASSERT(task_nodes.empty(), "cannot use both get_node_chain() and get_node()");
+
+    return chained_task_nodes;
+  }
+
   const auto& get_ready_dependencies() const
   {
     return ready_dependencies;
@@ -402,6 +424,8 @@ private:
    * child graph, but add nodes directly */
   ::std::vector<cudaGraphNode_t> task_nodes;
 
+  ::std::vector<cudaGraphNode_t> chained_task_nodes;
+
   /* This is the support graph associated to the entire context */
   cudaGraph_t ctx_graph = nullptr;
   size_t epoch = 0;
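A note on the dependency injection added above: cudaGraphAddDependencies takes two parallel arrays and creates one edge from[i] -> to[i] per entry, so filling out_array with the first node of the chain makes that node depend on every entry of ready_dependencies at once. A minimal standalone sketch of this fan-in, with illustrative names (not code from the repository):

#include <cuda_runtime.h>
#include <vector>

// Make `first` depend on every node in `deps`: one edge deps[i] -> first per entry.
void add_fan_in(cudaGraph_t graph, const std::vector<cudaGraphNode_t>& deps, cudaGraphNode_t first)
{
  // Parallel "to" array: the same destination node repeated once per predecessor.
  std::vector<cudaGraphNode_t> to(deps.size(), first);
  cudaGraphAddDependencies(graph, deps.data(), to.data(), deps.size());
}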
16 changes: 5 additions & 11 deletions cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh
@@ -418,24 +418,18 @@
       }
       else
       {
-        // Get the (child) graph associated to the task
-        auto g = t.get_graph();
+        ::std::vector<cudaGraphNode_t>& chain = t.get_node_chain();
+        chain.resize(res.size());
 
-        cudaGraphNode_t n = nullptr;
-        cudaGraphNode_t prev_n = nullptr;
+        auto& g = t.get_ctx_graph();
 
         // Create a chain of kernels
         for (size_t i = 0; i < res.size(); i++)
         {
+          insert_one_kernel(res[i], chain[i], g);
           if (i > 0)
           {
-            prev_n = n;
-          }
-
-          insert_one_kernel(res[i], n, g);
-          if (i > 0)
-          {
-            cuda_safe_call(cudaGraphAddDependencies(g, &prev_n, &n, 1));
+            cuda_safe_call(cudaGraphAddDependencies(g, &chain[i - 1], &chain[i], 1));
           }
         }
       }
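The rewritten loop above inserts each kernel into the context graph and then serializes consecutive nodes with a single dependency edge. A self-contained sketch of the same pattern, where step_kernel and add_kernel_node are illustrative stand-ins for the commit's insert_one_kernel (error checking omitted, not code from the repository):

#include <cuda_runtime.h>
#include <vector>

__global__ void step_kernel(float* data) { *data += 1.0f; } // illustrative kernel

// Stand-in for insert_one_kernel: add one kernel node with no dependencies yet.
void add_kernel_node(cudaGraphNode_t& node, cudaGraph_t graph, float* dev_data)
{
  void* args[] = {&dev_data};
  cudaKernelNodeParams params{};
  params.func         = (void*) step_kernel;
  params.gridDim      = dim3(1);
  params.blockDim     = dim3(128);
  params.kernelParams = args;
  cudaGraphAddKernelNode(&node, graph, nullptr, 0, &params);
}

void build_kernel_chain(cudaGraph_t graph, float* dev_data, size_t count)
{
  std::vector<cudaGraphNode_t> chain(count);
  for (size_t i = 0; i < count; i++)
  {
    add_kernel_node(chain[i], graph, dev_data);
    if (i > 0)
    {
      // Serialize the chain: chain[i - 1] must complete before chain[i] starts.
      cudaGraphAddDependencies(graph, &chain[i - 1], &chain[i], 1);
    }
  }
}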
