[STF] Logical token (#3196)
* Split the implementation of the void interface into the definition of the interface and its implementations for the stream and graph backends.

* Add missing files

* Check if a task implementation can match a prototype where the void_interface arguments are ignored

* Implement ctx.abstract_logical_data() which relies on a void data interface

* Illustrate how to use abstract handles in local contexts

* Introduce an is_void_interface() virtual method in the data interface to potentially optimize some stages

* Small improvements in the examples

* Do not try to allocate or move void data

* Do not use I as a variable

* fix linkage error

* rename abstract_logical_data to logical_token

* Document logical token

* fix spelling error

* fix sphinx error

* reflect name changes

* use meaningful variable names

* simplify logical_token implementation because writeback is already disabled

* add a unit test for token elision

* implement token elision in host_launch

* Remove unused type

* Implement helpers to check if a function can be invoked from a tuple, or from a tuple where we removed tokens

* Much simpler is_tuple_invocable_with_filtered implementation

* Fix buggy test

* Factorize code

* Document that we can ignore tokens for task and host_launch

* Documentation for logical data freeze
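
As a quick orientation before the diff itself, here is a minimal usage sketch of the new API, condensed from the void_data_interface.cu example updated in this commit (variable names follow that example; this is an illustration, not the complete API surface):

  context ctx;

  // A token carries a dependency but no payload: nothing is allocated,
  // copied, or written back for it.
  auto token  = ctx.logical_data(shape_of<void_interface>());
  auto token3 = ctx.logical_token();

  // A first task "produces" the token.
  ctx.task(token.write())->*[](cudaStream_t, auto) {};

  // Token elision: the void_interface arguments may be omitted, so this
  // lambda only takes the cudaStream_t. rw() is valid even without a
  // prior write(), since there is no underlying data to initialize.
  ctx.task(token3.rw(), token.read())->*[](cudaStream_t) {};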
caugonnet authored Jan 3, 2025
1 parent b57e065 commit fd2a15d
Showing 16 changed files with 625 additions and 128 deletions.
17 changes: 13 additions & 4 deletions cudax/examples/stf/void_data_interface.cu
@@ -23,14 +23,23 @@ int main()
{
context ctx;

- auto ltask_res = ctx.logical_data(shape_of<void_interface>());
- ctx.task(ltask_res.write())->*[](cudaStream_t, auto) {
+ auto token = ctx.logical_data(shape_of<void_interface>());
+ ctx.task(token.write())->*[](cudaStream_t, auto) {

};

void_interface sync;
- auto ltask2_res = ctx.logical_data(sync);
- ctx.task(ltask2_res.write(), ltask_res.read())->*[](cudaStream_t, auto, auto) {
+ auto token2 = ctx.logical_data(sync);
+
+ auto token3 = ctx.logical_token();
+ ctx.task(token2.write(), token.read())->*[](cudaStream_t, auto, auto) {

};

+
+ // Do not pass useless arguments: the void_interface arguments can simply be elided.
+ // Note that the rw() access is possible even if there was no prior write()
+ // or actual underlying data.
+ ctx.task(token3.rw(), token.read())->*[](cudaStream_t) {
+
+ };

1 change: 1 addition & 0 deletions cudax/include/cuda/experimental/__stf/graph/graph_ctx.cuh
@@ -27,6 +27,7 @@

#include <cuda/experimental/__stf/graph/graph_task.cuh>
#include <cuda/experimental/__stf/graph/interfaces/slice.cuh>
+ #include <cuda/experimental/__stf/graph/interfaces/void_interface.cuh>
#include <cuda/experimental/__stf/internal/acquire_release.cuh>
#include <cuda/experimental/__stf/internal/backend_allocator_setup.cuh>
#include <cuda/experimental/__stf/internal/launch.cuh>
25 changes: 22 additions & 3 deletions cudax/include/cuda/experimental/__stf/graph/graph_task.cuh
@@ -31,6 +31,7 @@
#include <cuda/experimental/__stf/internal/backend_ctx.cuh> // graph_task<> has-a backend_ctx_untyped
#include <cuda/experimental/__stf/internal/frozen_logical_data.cuh>
#include <cuda/experimental/__stf/internal/logical_data.cuh>
+ #include <cuda/experimental/__stf/internal/void_interface.cuh>

namespace cuda::experimental::stf
{
@@ -508,8 +509,12 @@ public:
dot.template add_vertex<task, logical_data_untyped>(*this);
}

+ constexpr bool fun_invocable_stream_deps = ::std::is_invocable_v<Fun, cudaStream_t, Deps...>;
+ constexpr bool fun_invocable_stream_non_void_deps =
+   reserved::is_invocable_with_filtered<Fun, cudaStream_t, Deps...>::value;
+
// Default for the first argument is a `cudaStream_t`.
- if constexpr (::std::is_invocable_v<Fun, cudaStream_t, Deps...>)
+ if constexpr (fun_invocable_stream_deps || fun_invocable_stream_non_void_deps)
{
//
// CAPTURE the lambda
@@ -522,7 +527,16 @@ public:
cuda_safe_call(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeThreadLocal));

// Launch the user provided function
- ::std::apply(f, tuple_prepend(mv(capture_stream), typed_deps()));
+ if constexpr (fun_invocable_stream_deps)
+ {
+   ::std::apply(f, tuple_prepend(mv(capture_stream), typed_deps()));
+ }
+ else if constexpr (fun_invocable_stream_non_void_deps)
+ {
+   // Remove void arguments
+   ::std::apply(::std::forward<Fun>(f),
+                tuple_prepend(mv(capture_stream), reserved::remove_void_interface_types(typed_deps())));
+ }

cuda_safe_call(cudaStreamEndCapture(capture_stream, &childGraph));

@@ -534,7 +548,12 @@ public:
}
else
{
- static_assert(::std::is_invocable_v<Fun, cudaGraph_t, Deps...>, "Incorrect lambda function signature.");
+ constexpr bool fun_invocable_graph_deps = ::std::is_invocable_v<Fun, cudaGraph_t, Deps...>;
+ constexpr bool fun_invocable_graph_non_void_deps =
+   reserved::is_invocable_with_filtered<Fun, cudaGraph_t, Deps...>::value;
+
+ static_assert(fun_invocable_graph_deps || fun_invocable_graph_non_void_deps,
+               "Incorrect lambda function signature.");
//
// Give the lambda a child graph
//
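For context, reserved::is_invocable_with_filtered asks whether the lambda becomes callable once the void_interface dependencies are filtered out of the argument list. A minimal self-contained sketch of such a trait follows; it is illustrative only, and the library's actual implementation may differ:

  #include <tuple>
  #include <type_traits>

  struct void_interface {}; // stand-in for cuda::experimental::stf::void_interface

  // Tuple of the dependency types with every void_interface removed
  template <typename... Ts>
  using filtered_tuple_t = decltype(::std::tuple_cat(
    ::std::declval<::std::conditional_t<::std::is_same_v<::std::decay_t<Ts>, void_interface>,
                                        ::std::tuple<>,
                                        ::std::tuple<Ts>>>()...));

  template <typename Fun, typename First, typename Tuple>
  struct is_invocable_with_filtered_impl;

  template <typename Fun, typename First, typename... Ts>
  struct is_invocable_with_filtered_impl<Fun, First, ::std::tuple<Ts...>>
      : ::std::is_invocable<Fun, First, Ts...>
  {};

  // Can Fun be called with (First, non-void deps...)?
  template <typename Fun, typename First, typename... Deps>
  struct is_invocable_with_filtered
      : is_invocable_with_filtered_impl<Fun, First, filtered_tuple_t<Deps...>>
  {};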
@@ -0,0 +1,114 @@
//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

/**
* @file
*
* @brief This implements a void data interface over the graph_ctx backend
*/

#pragma once

#include <cuda/__cccl_config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/experimental/__stf/graph/graph_data_interface.cuh>
#include <cuda/experimental/__stf/internal/void_interface.cuh>

namespace cuda::experimental::stf
{

template <typename T>
struct graphed_interface_of;

/**
* @brief Data interface to manipulate the void interface in the CUDA graph backend
*/
class void_graph_interface : public graph_data_interface<void_interface>
{
public:
/// @brief Alias for the base class
using base = graph_data_interface<void_interface>;
/// @brief Alias for the shape type
using base::shape_t;

void_graph_interface(void_interface s)
: base(mv(s))
{}
void_graph_interface(shape_of<void_interface> s)
: base(mv(s))
{}

void data_allocate(
backend_ctx_untyped&,
block_allocator_untyped&,
const data_place&,
instance_id_t,
::std::ptrdiff_t& s,
void**,
event_list&) override
{
s = 0;
}

void data_deallocate(
backend_ctx_untyped&, block_allocator_untyped&, const data_place&, instance_id_t, void*, event_list&) final
{}

cudaGraphNode_t graph_data_copy(
cudaMemcpyKind,
instance_id_t,
instance_id_t,
cudaGraph_t graph,
const cudaGraphNode_t* input_nodes,
size_t input_cnt) override
{
cudaGraphNode_t dummy;
cuda_safe_call(cudaGraphAddEmptyNode(&dummy, graph, input_nodes, input_cnt));
return dummy;
}

bool pin_host_memory(instance_id_t) override
{
// no-op
return false;
}

void unpin_host_memory(instance_id_t) override {}

/* This helps detect when we are manipulating a void data interface, so
* that we can skip useless stages such as allocations or copies */
bool is_void_interface() const override final
{
return true;
}
};

/**
* @brief Define how the CUDA graph backend must manipulate this void interface
*
* Note that we specialize cuda::experimental::stf::graphed_interface_of to avoid an ambiguous specialization
*
* @extends graphed_interface_of
*/
template <>
struct graphed_interface_of<void_interface>
{
using type = void_graph_interface;
};

} // end namespace cuda::experimental::stf
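
Because a "copy" of a void interface degenerates into an empty node (see graph_data_copy above), tokens behave on the CUDA graph backend just as they do on streams. Purely as an illustrative sketch, assuming the usual graph_ctx workflow:

  graph_ctx ctx;
  auto token = ctx.logical_token();

  // The dependency is tracked, but no allocation happens and any transfer
  // of the token is emitted as an empty node in the generated CUDA graph.
  ctx.task(token.rw())->*[](cudaStream_t) {
    // work ordered by the token
  };

  ctx.finalize();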
28 changes: 27 additions & 1 deletion cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh
@@ -39,6 +39,7 @@
#include <cuda/experimental/__stf/internal/slice.cuh> // backend_ctx<T> uses shape_of
#include <cuda/experimental/__stf/internal/task_state.cuh> // backend_ctx_untyped::impl has-a ctx_stack
#include <cuda/experimental/__stf/internal/thread_hierarchy.cuh>
+ #include <cuda/experimental/__stf/internal/void_interface.cuh>
#include <cuda/experimental/__stf/localization/composite_slice.cuh>

// XXX there is currently a dependency on this header for places.h
@@ -195,7 +196,22 @@ public:
{
delete w;
};
- ::std::apply(::std::forward<Fun>(w->first), mv(w->second));
+
+ constexpr bool fun_invocable_task_deps = reserved::is_tuple_invocable_v<Fun, decltype(payload)>;
+ constexpr bool fun_invocable_task_non_void_deps =
+   reserved::is_tuple_invocable_with_filtered<Fun, decltype(payload)>::value;
+
+ static_assert(fun_invocable_task_deps || fun_invocable_task_non_void_deps,
+               "Incorrect lambda function signature in host_launch.");
+
+ if constexpr (fun_invocable_task_deps)
+ {
+   ::std::apply(::std::forward<Fun>(w->first), mv(w->second));
+ }
+ else if constexpr (fun_invocable_task_non_void_deps)
+ {
+   ::std::apply(::std::forward<Fun>(w->first), reserved::remove_void_interface_types(mv(w->second)));
+ }
};

if constexpr (::std::is_same_v<Ctx, graph_ctx>)
@@ -1067,6 +1083,16 @@ public:
return logical_data(make_slice(p, n), mv(dplace));
}

+ auto logical_token()
+ {
+   // We do not use a shape because we want the first rw() access to succeed
+   // without an initial write()
+   //
+   // Note that we do not disable write-back here, because the write-back
+   // mechanism already handles void_interface specifically and ignores it.
+   return logical_data(void_interface{});
+ }
+
template <typename T>
frozen_logical_data<T> freeze(cuda::experimental::stf::logical_data<T> d,
access_mode m = access_mode::read,
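The token elision above hinges on reserved::remove_void_interface_types, which drops the void_interface entries from the argument tuple before the call. A possible sketch of such a helper, illustrative rather than the library's actual code:

  #include <tuple>
  #include <type_traits>
  #include <utility>

  struct void_interface {}; // stand-in for the real type

  template <typename... Ts>
  auto remove_void_interface_types(::std::tuple<Ts...> t)
  {
    // Map each element to a 0- or 1-element tuple, then concatenate the results
    auto keep_one = [](auto&& x) {
      using X = ::std::decay_t<decltype(x)>;
      if constexpr (::std::is_same_v<X, void_interface>)
      {
        return ::std::tuple<>{}; // drop tokens
      }
      else
      {
        return ::std::tuple<X>{::std::forward<decltype(x)>(x)}; // keep real data
      }
    };
    return ::std::apply(
      [&](auto&&... xs) { return ::std::tuple_cat(keep_one(::std::forward<decltype(xs)>(xs))...); },
      ::std::move(t));
  }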
Expand Up @@ -296,6 +296,15 @@ public:
return tp.find_data_instance_id(d);
}

+ /**
+  * @brief Indicates whether this is a void data interface, which makes it
+  * possible to skip some operations such as data allocation or movement
+  */
+ virtual bool is_void_interface() const
+ {
+   return false;
+ }
+
private:
/**
* @brief Get the common implementation of the data interface.
12 changes: 4 additions & 8 deletions cudax/include/cuda/experimental/__stf/internal/launch.cuh
@@ -451,21 +451,17 @@ public:
}

size_t p_rank = 0;
- if constexpr (::std::is_same_v<Ctx, stream_ctx>)
+ for (auto p : e_place)
  {
-   for (auto p : e_place)
+   if constexpr (::std::is_same_v<Ctx, stream_ctx>)
    {
      reserved::launch_impl(interpreted_policy, p, f, args, t.get_stream(p_rank), p_rank);
-     p_rank++;
    }
-  }
- else
- {
-   for (auto p : e_place)
+   else
    {
      reserved::graph_launch_impl(t, interpreted_policy, p, f, args, p_rank);
-     p_rank++;
    }
+   p_rank++;
  }
}

19 changes: 17 additions & 2 deletions cudax/include/cuda/experimental/__stf/internal/logical_data.cuh
@@ -467,6 +467,12 @@ public:
return dinterface != nullptr;
}

+ bool is_void_interface() const
+ {
+   _CCCL_ASSERT(has_interface(), "uninitialized logical data");
+   return dinterface->is_void_interface();
+ }
+
bool has_ref() const
{
assert(refcnt.load() >= 0);
@@ -1255,6 +1261,15 @@ public:
return pimpl->dinterface != nullptr;
}

+ /**
+  * @brief Returns true if the data uses a void data interface
+  */
+ bool is_void_interface() const
+ {
+   assert(pimpl);
+   return pimpl->is_void_interface();
+ }
+
// This function applies the reduction operator over 2 instances: the one
// identified by "in_instance_id" is not modified, and the result is stored
// in the one identified by "inout_instance_id".
@@ -1727,7 +1742,7 @@ inline void reserved::logical_data_untyped_impl::erase()

/* If there is a reference instance id, it needs to be updated with a
* valid copy if that is not the case yet */
- if (enable_write_back)
+ if (enable_write_back && !is_void_interface())
{
instance_id_t ref_id = reference_instance_id;
assert(ref_id != instance_id_t::invalid);
@@ -2032,7 +2047,7 @@ inline void fetch_data(
{
event_list stf_prereq = reserved::enforce_stf_deps_before(ctx, d, instance_id, t, mode, eplace);

- if (d.has_interface())
+ if (d.has_interface() && !d.is_void_interface())
{
// Allocate data if needed (and possibly reclaim memory to do so)
reserved::dep_allocate(ctx, d, mode, dplace, eplace, instance_id, stf_prereq);
