
[CUDAX] Add host launch API allowing stream ordered host execution #3555

Open · wants to merge 4 commits into main
103 changes: 103 additions & 0 deletions cudax/include/cuda/experimental/__launch/host_launch.cuh
@@ -0,0 +1,103 @@
//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX__LAUNCH_HOST_LAUNCH
#define _CUDAX__LAUNCH_HOST_LAUNCH
#include <cuda/__cccl_config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/std/__functional/reference_wrapper.h>
#include <cuda/std/__type_traits/decay.h>
#include <cuda/std/__utility/forward.h>
#include <cuda/std/__utility/move.h>
#include <cuda/std/tuple>
#include <cuda/stream_ref>

namespace cuda::experimental
{

template <typename _CallablePtr>
void __stream_callback_caller(cudaStream_t, cudaError_t __status, void* __callable_ptr)
{
auto __casted_callable = static_cast<_CallablePtr>(__callable_ptr);
if (__status == cudaSuccess)
{
(*__casted_callable)();
}
delete __casted_callable;
}

//! @brief Launches a host callable to be executed in stream order on the provided stream
//!
//! The callable and the arguments are copied into an internal dynamic allocation to preserve them
//! until the asynchronous call happens. A lambda capture or reference_wrapper can be used if
//! something needs to be passed by reference.
//!
//! @param __stream Stream to launch the host function on
//! @param __callable Host function or callable object to call in stream order
//! @param __args Arguments to call the supplied callable with
template <typename _Callable, typename... _Args>
void host_launch(stream_ref __stream, _Callable __callable, _Args... __args)
{
static_assert(_CUDA_VSTD::is_invocable_v<_Callable, _Args...>,
"Callable can't be called with the supplied arguments");
auto __lambda_ptr = new auto([__callable = _CUDA_VSTD::move(__callable),
__args_tuple = _CUDA_VSTD::make_tuple(_CUDA_VSTD::move(__args)...)]() mutable {
_CUDA_VSTD::apply(__callable, __args_tuple);
});

// We use cudaStreamAddCallback rather than cudaLaunchHostFunc because the callback runs even if the
// stream is in an error state, and it must run to free the above allocation
_CCCL_TRY_CUDA_API(
cudaStreamAddCallback,
"Failed to launch host function",
__stream.get(),
__stream_callback_caller<decltype(__lambda_ptr)>,
static_cast<void*>(__lambda_ptr),
0);
}

template <typename _CallablePtr>
void __host_func_launcher(void* __callable_ptr)
{
auto __casted_callable = static_cast<_CallablePtr>(__callable_ptr);
(*__casted_callable)();
}

//! @brief Launches a host callable to be executed in stream order on the provided stream
//!
//! The callable will be called through the supplied reference. If the callable was destroyed
//! or moved by the time it is asynchronously called, the behavior is undefined.
//!
//! The callable can't take any arguments; if additional state is required, a lambda can be used
//! to capture it.
//!
//! @param __stream Stream to launch the host function on
//! @param __callable A reference to a host function or callable object to call in stream order
template <typename _Callable>
void host_launch(stream_ref __stream, ::cuda::std::reference_wrapper<_Callable> __callable)
{
static_assert(_CUDA_VSTD::is_invocable_v<_Callable>, "Callable in reference_wrapper can't take any arguments");
_CCCL_TRY_CUDA_API(
cudaLaunchHostFunc,
"Failed to launch host function",
__stream.get(),
__host_func_launcher<_Callable*>,
&__callable.get());
}

} // namespace cuda::experimental

#endif // !_CUDAX__LAUNCH_HOST_LAUNCH
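
For context, here is a minimal usage sketch of the by-value overload above (not part of this diff). The `scale_and_print` function, the vector contents, and the namespace alias are hypothetical names picked for illustration; the key point, per the documentation above, is that the callable and its arguments are copied, so they may go out of scope before the stream reaches the callback.

#include <cuda/experimental/launch.cuh>
#include <cuda/experimental/stream.cuh>

#include <cstdio>
#include <vector>

namespace cudax = cuda::experimental;

// Hypothetical host function to run in stream order.
void scale_and_print(std::vector<int> data, int scale)
{
  for (int v : data)
  {
    std::printf("%d ", v * scale);
  }
  std::printf("\n");
}

int main()
{
  cudax::stream stream;

  {
    std::vector<int> data{1, 2, 3};
    // The callable and its arguments are copied into an internal allocation,
    // so `data` may be destroyed before the callback actually runs.
    cudax::host_launch(stream, scale_and_print, data, 10);
  }

  stream.wait();
  return 0;
}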
1 change: 1 addition & 0 deletions cudax/include/cuda/experimental/launch.cuh
@@ -12,6 +12,7 @@
#define __CUDAX_LAUNCH___

#include <cuda/experimental/__launch/configuration.cuh>
#include <cuda/experimental/__launch/host_launch.cuh>
#include <cuda/experimental/__launch/launch.cuh>
#include <cuda/experimental/__launch/launch_transform.cuh>
#include <cuda/experimental/__launch/param_kind.cuh>
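With host_launch.cuh pulled in by the umbrella header, user code only needs <cuda/experimental/launch.cuh>. Below is a minimal sketch of the reference_wrapper overload documented above (not part of this diff); the `accumulator` type and the namespace alias are hypothetical names, and the referenced object must outlive the asynchronous call, which stream.wait() guarantees here.

#include <cuda/experimental/launch.cuh>
#include <cuda/experimental/stream.cuh>

#include <cuda/std/functional>

namespace cudax = cuda::experimental;

// Hypothetical callable whose state is updated in stream order.
struct accumulator
{
  int total = 0;

  void operator()()
  {
    ++total;
  }
};

int main()
{
  cudax::stream stream;
  accumulator acc;

  // The callable is used through the reference and nothing is copied;
  // `acc` must stay alive until the stream has run the host function.
  cudax::host_launch(stream, cuda::std::ref(acc));
  cudax::host_launch(stream, cuda::std::ref(acc));

  stream.wait();
  // acc.total == 2 at this point.
  return 0;
}
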
142 changes: 142 additions & 0 deletions cudax/test/launch/launch_smoke.cu
@@ -12,6 +12,7 @@
#include <cuda/experimental/launch.cuh>
#include <cuda/experimental/stream.cuh>

#include "cuda/experimental/__stream/stream_ref.cuh"
#include <cooperative_groups.h>
#include <testing.cuh>

@@ -310,3 +311,144 @@ TEST_CASE("Launch with default config")
{
test_default_config();
}

void block_stream(cudax::stream_ref stream, cuda::atomic<int>& atomic)
{
auto block_lambda = [&]() {
while (atomic != 1)
;
};
cudax::host_launch(stream, block_lambda);
}

void unblock_and_wait_stream(cudax::stream_ref stream, cuda::atomic<int>& atomic)
{
CUDAX_REQUIRE(!stream.ready());
atomic = 1;
stream.wait();
atomic = 0;
}

void launch_local_lambda(cudax::stream_ref stream, int& set, int set_to)
{
auto lambda = [&]() {
set = set_to;
};
cudax::host_launch(stream, lambda);
}

template <typename Lambda>
struct lambda_wrapper
{
Lambda lambda;

lambda_wrapper(const Lambda& lambda)
: lambda(lambda)
{}

lambda_wrapper(lambda_wrapper&&) = default;
lambda_wrapper(const lambda_wrapper&) = default;

void operator()()
{
if constexpr (cuda::std::is_same_v<cuda::std::invoke_result_t<Lambda>, void*>)
{
// If lambda returns the address it captured, confirm this object wasn't moved
CUDAX_REQUIRE(lambda() == this);
}
else
{
lambda();
}
}

// Make sure we fail if const is added to this wrapper anywhere
void operator()() const
{
CUDAX_REQUIRE(false);
}
};

TEST_CASE("Host launch")
{
cuda::atomic<int> atomic = 0;
cudax::stream stream;
int i = 0;

auto set_lambda = [&](int set) {
i = set;
};

SECTION("Can do a host launch")
{
block_stream(stream, atomic);

cudax::host_launch(stream, set_lambda, 2);

unblock_and_wait_stream(stream, atomic);
CUDAX_REQUIRE(i == 2);
}

SECTION("Can launch multiple functions")
{
block_stream(stream, atomic);
auto check_lambda = [&]() {
CUDAX_REQUIRE(i == 4);
};

cudax::host_launch(stream, set_lambda, 3);
cudax::host_launch(stream, set_lambda, 4);
cudax::host_launch(stream, check_lambda);
cudax::host_launch(stream, set_lambda, 5);
unblock_and_wait_stream(stream, atomic);
CUDAX_REQUIRE(i == 5);
}

SECTION("Non trivially copyable")
{
std::string s = "hello";

cudax::host_launch(
stream,
[&](auto str_arg) {
CUDAX_REQUIRE(s == str_arg);
},
s);
stream.wait();
}

SECTION("Confirm no const added to the callable")
{
lambda_wrapper wrapped_lambda([&]() {
i = 21;
});

cudax::host_launch(stream, wrapped_lambda);
stream.wait();
CUDAX_REQUIRE(i == 21);
}

SECTION("Can launch a local function and return")
{
block_stream(stream, atomic);
launch_local_lambda(stream, i, 42);
unblock_and_wait_stream(stream, atomic);
CUDAX_REQUIRE(i == 42);
}

SECTION("Launch by reference")
{
// Grab the pointer to confirm callable was not moved
void* wrapper_ptr = nullptr;
lambda_wrapper another_lambda_setter([&]() {
i = 84;
return wrapper_ptr;
});
wrapper_ptr = static_cast<void*>(&another_lambda_setter);

block_stream(stream, atomic);
cudax::host_launch(stream, cuda::std::ref(another_lambda_setter));
unblock_and_wait_stream(stream, atomic);
CUDAX_REQUIRE(i == 84);
}
}