diff --git a/cudax/include/cuda/experimental/__device/device.cuh b/cudax/include/cuda/experimental/__device/device.cuh index f91b0089d5f..35e0cfe2d4c 100644 --- a/cudax/include/cuda/experimental/__device/device.cuh +++ b/cudax/include/cuda/experimental/__device/device.cuh @@ -21,7 +21,13 @@ # pragma system_header #endif // no system header +#include + #include +#include + +#include +#include namespace cuda::experimental { @@ -33,7 +39,7 @@ struct __emplace_device { int __id_; - _CCCL_NODISCARD constexpr operator device() const noexcept; + _CCCL_NODISCARD operator device() const noexcept; _CCCL_NODISCARD constexpr const __emplace_device* operator->() const noexcept; }; @@ -56,6 +62,24 @@ public: # endif #endif + CUcontext primary_context() const + { + ::std::call_once(__init_once, [this]() { + __device = detail::driver::deviceGet(__id_); + __primary_ctx = detail::driver::primaryCtxRetain(__device); + }); + assert(__primary_ctx != nullptr); + return __primary_ctx; + } + + ~device() + { + if (__primary_ctx) + { + detail::driver::primaryCtxRelease(__device); + } + } + private: // TODO: put a mutable thread-safe (or thread_local) cache of device // properties here. @@ -63,6 +87,10 @@ private: friend class device_ref; friend struct detail::__emplace_device; + mutable CUcontext __primary_ctx = nullptr; + mutable CUdevice __device{}; + mutable ::std::once_flag __init_once; + explicit constexpr device(int __id) noexcept : device_ref(__id) {} @@ -76,7 +104,7 @@ private: namespace detail { -_CCCL_NODISCARD inline constexpr __emplace_device::operator device() const noexcept +_CCCL_NODISCARD inline __emplace_device::operator device() const noexcept { return device(__id_); } diff --git a/cudax/include/cuda/experimental/__device/device_ref.cuh b/cudax/include/cuda/experimental/__device/device_ref.cuh index f5945914da0..7f2635611f4 100644 --- a/cudax/include/cuda/experimental/__device/device_ref.cuh +++ b/cudax/include/cuda/experimental/__device/device_ref.cuh @@ -22,7 +22,6 @@ #endif // no system header #include -#include namespace cuda::experimental { @@ -103,69 +102,6 @@ public: } }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document - -//! @brief RAII helper which saves the current device and switches to the -//! specified device on construction and switches to the saved device on -//! destruction. -//! -struct __scoped_device -{ -private: - // The original device ordinal, or -1 if the device was not changed. - int const __old_device; - - //! @brief Returns the current device ordinal. - //! - //! @throws cuda_error if the device query fails. - static int __current_device() - { - int device = -1; - _CCCL_TRY_CUDA_API(cudaGetDevice, "failed to get the current device", &device); - return device; - } - - explicit __scoped_device(int new_device, int old_device) noexcept - : __old_device(new_device == old_device ? -1 : old_device) - {} - -public: - //! @brief Construct a new `__scoped_device` object and switch to the specified - //! device. - //! - //! @param new_device The device to switch to - //! - //! @throws cuda_error if the device switch fails - explicit __scoped_device(device_ref new_device) - : __scoped_device(new_device.get(), __current_device()) - { - if (__old_device != -1) - { - _CCCL_TRY_CUDA_API(cudaSetDevice, "failed to set the current device", new_device.get()); - } - } - - __scoped_device(__scoped_device&&) = delete; - __scoped_device(__scoped_device const&) = delete; - __scoped_device& operator=(__scoped_device&&) = delete; - __scoped_device& operator=(__scoped_device const&) = delete; - - //! @brief Destroy the `__scoped_device` object and switch back to the original - //! device. - //! - //! @throws cuda_error if the device switch fails. If the destructor is called - //! during stack unwinding, the program is automatically terminated. - ~__scoped_device() noexcept(false) - { - if (__old_device != -1) - { - _CCCL_TRY_CUDA_API(cudaSetDevice, "failed to restore the current device", __old_device); - } - } -}; - -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace cuda::experimental #endif // _CUDAX__DEVICE_DEVICE_REF diff --git a/cudax/include/cuda/experimental/__event/event.cuh b/cudax/include/cuda/experimental/__event/event.cuh index 0b6b7802b22..3ce997c55c3 100644 --- a/cudax/include/cuda/experimental/__event/event.cuh +++ b/cudax/include/cuda/experimental/__event/event.cuh @@ -30,6 +30,7 @@ #include #include +#include namespace cuda::experimental { @@ -54,7 +55,7 @@ public: //! //! @throws cuda_error if the event creation fails. explicit event(stream_ref __stream, flags __flags = flags::none) - : event(static_cast(__flags) | cudaEventDisableTiming) + : event(__stream, static_cast(__flags) | cudaEventDisableTiming) { record(__stream); } @@ -85,7 +86,9 @@ public: { if (__event_ != nullptr) { - [[maybe_unused]] auto __status = ::cudaEventDestroy(__event_); + // Needs to call driver API in case current device is not set, runtime version would set dev 0 current + // Alternative would be to store the device and push/pop here + [[maybe_unused]] auto __status = detail::driver::eventDestroy(__event_); } } @@ -144,9 +147,10 @@ private: : event_ref(__evnt) {} - explicit event(unsigned int __flags) + explicit event(stream_ref __stream, unsigned int __flags) : event_ref(::cudaEvent_t{}) { + [[maybe_unused]] __ensure_current_device __dev_setter(__stream); _CCCL_TRY_CUDA_API( ::cudaEventCreateWithFlags, "Failed to create CUDA event", &__event_, static_cast(__flags)); } diff --git a/cudax/include/cuda/experimental/__event/event_ref.cuh b/cudax/include/cuda/experimental/__event/event_ref.cuh index b795d46a77b..3b0ccc6dbcd 100644 --- a/cudax/include/cuda/experimental/__event/event_ref.cuh +++ b/cudax/include/cuda/experimental/__event/event_ref.cuh @@ -30,6 +30,8 @@ #include #include +#include + namespace cuda::experimental { class event; @@ -74,7 +76,8 @@ public: { assert(__event_ != nullptr); assert(__stream.get() != nullptr); - _CCCL_TRY_CUDA_API(::cudaEventRecord, "Failed to record CUDA event", __event_, __stream.get()); + // Need to use driver API, cudaEventRecord will push dev 0 if stack is empty + detail::driver::eventRecord(__event_, __stream.get()); } //! @brief Waits until all the work in the stream prior to the record of the diff --git a/cudax/include/cuda/experimental/__event/timed_event.cuh b/cudax/include/cuda/experimental/__event/timed_event.cuh index debcbcd26e5..48b9b0f1a5a 100644 --- a/cudax/include/cuda/experimental/__event/timed_event.cuh +++ b/cudax/include/cuda/experimental/__event/timed_event.cuh @@ -42,7 +42,7 @@ public: //! //! @throws cuda_error if the event creation fails. explicit timed_event(stream_ref __stream, flags __flags = flags::none) - : event(static_cast(__flags)) + : event(__stream, static_cast(__flags)) { record(__stream); } diff --git a/cudax/include/cuda/experimental/__launch/launch.cuh b/cudax/include/cuda/experimental/__launch/launch.cuh index 790af2a9d58..1a49cafa405 100644 --- a/cudax/include/cuda/experimental/__launch/launch.cuh +++ b/cudax/include/cuda/experimental/__launch/launch.cuh @@ -16,6 +16,7 @@ #include #include +#include #if _CCCL_STD_VER >= 2017 namespace cuda::experimental @@ -119,6 +120,7 @@ template & conf, const Kernel& kernel, Args... args) { + [[maybe_unused]] __ensure_current_device __dev_setter(stream); cudaError_t status; if constexpr (::cuda::std::is_invocable_v, Args...>) { @@ -181,6 +183,7 @@ void launch( template void launch(::cuda::stream_ref stream, const hierarchy_dimensions& dims, const Kernel& kernel, Args... args) { + [[maybe_unused]] __ensure_current_device __dev_setter(stream); cudaError_t status; if constexpr (::cuda::std::is_invocable_v, Args...>) { @@ -245,6 +248,7 @@ void launch(::cuda::stream_ref stream, void (*kernel)(kernel_config, ExpArgs...), ActArgs&&... args) { + [[maybe_unused]] __ensure_current_device __dev_setter(stream); cudaError_t status = [&](ExpArgs... args) { return detail::launch_impl(stream, conf, kernel, conf, args...); }(std::forward(args)...); @@ -299,6 +303,7 @@ void launch(::cuda::stream_ref stream, void (*kernel)(hierarchy_dimensions, ExpArgs...), ActArgs&&... args) { + [[maybe_unused]] __ensure_current_device __dev_setter(stream); cudaError_t status = [&](ExpArgs... args) { return detail::launch_impl(stream, kernel_config(dims), kernel, dims, args...); }(std::forward(args)...); @@ -354,6 +359,7 @@ void launch(::cuda::stream_ref stream, void (*kernel)(ExpArgs...), ActArgs&&... args) { + [[maybe_unused]] __ensure_current_device __dev_setter(stream); cudaError_t status = [&](ExpArgs... args) { return detail::launch_impl(stream, conf, kernel, args...); }(std::forward(args)...); @@ -406,6 +412,7 @@ template void launch( ::cuda::stream_ref stream, const hierarchy_dimensions& dims, void (*kernel)(ExpArgs...), ActArgs&&... args) { + [[maybe_unused]] __ensure_current_device __dev_setter(stream); cudaError_t status = [&](ExpArgs... args) { return detail::launch_impl(stream, kernel_config(dims), kernel, args...); }(std::forward(args)...); diff --git a/cudax/include/cuda/experimental/__stream/stream.cuh b/cudax/include/cuda/experimental/__stream/stream.cuh index 4859e9fabcb..0ba125269bd 100644 --- a/cudax/include/cuda/experimental/__stream/stream.cuh +++ b/cudax/include/cuda/experimental/__stream/stream.cuh @@ -27,6 +27,7 @@ #include #include +#include namespace cuda::experimental { @@ -51,7 +52,7 @@ struct stream : stream_ref //! @throws cuda_error if stream creation fails explicit stream(device_ref __dev, int __priority = default_priority) { - __scoped_device dev_setter(__dev); + [[maybe_unused]] __ensure_current_device __dev_setter(__dev); _CCCL_TRY_CUDA_API( ::cudaStreamCreateWithPriority, "Failed to create a stream", &__stream, cudaStreamDefault, __priority); } @@ -89,7 +90,9 @@ struct stream : stream_ref { if (__stream != detail::invalid_stream) { - [[maybe_unused]] auto status = ::cudaStreamDestroy(__stream); + // Needs to call driver API in case current device is not set, runtime version would set dev 0 current + // Alternative would be to store the device and push/pop here + [[maybe_unused]] auto status = detail::driver::streamDestroy(__stream); } } @@ -139,18 +142,20 @@ struct stream : stream_ref void wait(event_ref __ev) const { assert(__ev.get() != nullptr); - _CCCL_TRY_CUDA_API(::cudaStreamWaitEvent, "Failed to make a stream wait for an event", get(), __ev.get()); + // Need to use driver API, cudaStreamWaitEvent would push dev 0 if stack was empty + detail::driver::streamWaitEvent(get(), __ev.get()); } - //! @brief Make all future work submitted into this stream depend on completion of all work from the specified stream + //! @brief Make all future work submitted into this stream depend on completion of all work from the specified + //! stream //! //! @param __other Stream that this stream should wait for //! //! @throws cuda_error if inserting the dependency fails void wait(stream_ref __other) const { - // TODO consider an optimization to not create an event every time and instead have one persistent event or one per - // stream + // TODO consider an optimization to not create an event every time and instead have one persistent event or one + // per stream assert(__stream != detail::invalid_stream); event __tmp(__other); wait(__tmp); diff --git a/cudax/include/cuda/experimental/__utility/driver_api.cuh b/cudax/include/cuda/experimental/__utility/driver_api.cuh index 21b8c4d7425..8a52dd89fca 100644 --- a/cudax/include/cuda/experimental/__utility/driver_api.cuh +++ b/cudax/include/cuda/experimental/__utility/driver_api.cuh @@ -25,7 +25,13 @@ inline void* get_driver_entry_point(const char* name) { void* fn; cudaDriverEntryPointQueryResult result; +#if CUDART_VERSION >= 12050 + // For minor version compatibility request the 12.0 version of everything for now + cudaGetDriverEntryPointByVersion(name, &fn, 12000, cudaEnableDefault, &result); +#else + // Versioned get entry point not available before 12.5, but we don't need anything versioned before that cudaGetDriverEntryPoint(name, &fn, cudaEnableDefault, &result); +#endif if (result != cudaDriverEntryPointSuccess) { if (result == cudaDriverEntryPointVersionNotSufficent) @@ -56,11 +62,12 @@ inline void ctxPush(CUcontext ctx) call_driver_fn(driver_fn, "Failed to push context", ctx); } -inline void ctxPop() +inline CUcontext ctxPop() { static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuCtxPopCurrent); - CUcontext dummy; - call_driver_fn(driver_fn, "Failed to pop context", &dummy); + CUcontext result; + call_driver_fn(driver_fn, "Failed to pop context", &result); + return result; } inline CUcontext ctxGetCurrent() @@ -71,6 +78,38 @@ inline CUcontext ctxGetCurrent() return result; } +inline CUdevice deviceGet(int ordinal) +{ + static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuDeviceGet); + CUdevice result; + call_driver_fn(driver_fn, "Failed to get device", &result, ordinal); + return result; +} + +inline CUcontext primaryCtxRetain(CUdevice dev) +{ + static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuDevicePrimaryCtxRetain); + CUcontext result; + call_driver_fn(driver_fn, "Failed to retain context for a device", &result, dev); + return result; +} + +inline void primaryCtxRelease(CUdevice dev) +{ + static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuDevicePrimaryCtxRelease); + // TODO we might need to ignore failure here + call_driver_fn(driver_fn, "Failed to release context for a device", dev); +} + +inline bool isPrimaryCtxActive(CUdevice dev) +{ + static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuDevicePrimaryCtxGetState); + int result; + unsigned int dummy; + call_driver_fn(driver_fn, "Failed to check the primary ctx state", dev, &dummy, &result); + return result == 1; +} + inline CUcontext streamGetCtx(CUstream stream) { static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuStreamGetCtx); @@ -78,6 +117,31 @@ inline CUcontext streamGetCtx(CUstream stream) call_driver_fn(driver_fn, "Failed to get context from a stream", stream, &result); return result; } + +inline void streamWaitEvent(CUstream stream, CUevent event) +{ + static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuStreamWaitEvent); + call_driver_fn(driver_fn, "Failed to make a stream wait for an event", stream, event, CU_EVENT_WAIT_DEFAULT); +} + +inline void eventRecord(CUevent event, CUstream stream) +{ + static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuEventRecord); + call_driver_fn(driver_fn, "Failed to record CUDA event", event, stream); +} + +// Destroy calls return error codes to let the calling code decide if the error should be ignored +inline cudaError_t streamDestroy(CUstream stream) +{ + static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuStreamDestroy); + return static_cast(driver_fn(stream)); +} + +inline cudaError_t eventDestroy(CUevent event) +{ + static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuEventDestroy); + return static_cast(driver_fn(event)); +} } // namespace cuda::experimental::detail::driver #undef CUDAX_GET_DRIVER_FUNCTION diff --git a/cudax/include/cuda/experimental/__utility/ensure_current_device.cuh b/cudax/include/cuda/experimental/__utility/ensure_current_device.cuh new file mode 100644 index 00000000000..2431d028187 --- /dev/null +++ b/cudax/include/cuda/experimental/__utility/ensure_current_device.cuh @@ -0,0 +1,80 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +#ifndef _CUDAX__UTILITY_ENSURE_CURRENT_DEVICE +#define _CUDAX__UTILITY_ENSURE_CURRENT_DEVICE + +#include + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include + +#include +#include + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +namespace cuda::experimental +{ +//! @brief RAII helper which on construction sets the current device to the specified one or one a +//! stream was created under. It sets the state back on destruction. +//! +struct __ensure_current_device +{ + //! @brief Construct a new `__ensure_current_device` object and switch to the specified + //! device. + //! + //! @param new_device The device to switch to + //! + //! @throws cuda_error if the device switch fails + explicit __ensure_current_device(device_ref new_device) + { + auto ctx = devices[new_device.get()].primary_context(); + detail::driver::ctxPush(ctx); + } + + //! @brief Construct a new `__ensure_current_device` object and switch to the device + //! under which the specified stream was created. + //! + //! @param stream Stream indicating the device to switch to + //! + //! @throws cuda_error if the device switch fails + explicit __ensure_current_device(stream_ref stream) + { + auto ctx = detail::driver::streamGetCtx(stream.get()); + detail::driver::ctxPush(ctx); + } + + __ensure_current_device(__ensure_current_device&&) = delete; + __ensure_current_device(__ensure_current_device const&) = delete; + __ensure_current_device& operator=(__ensure_current_device&&) = delete; + __ensure_current_device& operator=(__ensure_current_device const&) = delete; + + //! @brief Destroy the `__ensure_current_device` object and switch back to the original + //! device. + //! + //! @throws cuda_error if the device switch fails. If the destructor is called + //! during stack unwinding, the program is automatically terminated. + ~__ensure_current_device() noexcept(false) + { + // TODO would it make sense to assert here that we pushed and popped the same thing? + detail::driver::ctxPop(); + } +}; +} // namespace cuda::experimental +#endif // DOXYGEN_SHOULD_SKIP_THIS +#endif // _CUDAX__UTILITY_ENSURE_CURRENT_DEVICE diff --git a/cudax/test/CMakeLists.txt b/cudax/test/CMakeLists.txt index bb8a7d7c545..4752f8b9645 100644 --- a/cudax/test/CMakeLists.txt +++ b/cudax/test/CMakeLists.txt @@ -29,6 +29,7 @@ function(cudax_add_catch2_test target_name_var test_name cn_target) # ARGN=test target_link_libraries(${test_target} PRIVATE ${cn_target} Catch2::Catch2 catch2_main) target_link_libraries(${test_target} PRIVATE ${cn_target} cudax::Thrust) target_compile_options(${test_target} PRIVATE "-DLIBCUDACXX_ENABLE_EXPERIMENTAL_MEMORY_RESOURCE") + target_compile_options(${test_target} PRIVATE "-DLIBCUDACXX_ENABLE_EXCEPTIONS") target_compile_options(${test_target} PRIVATE $<$:--extended-lambda>) cudax_clone_target_properties(${test_target} ${cn_target}) set_target_properties(${test_target} PROPERTIES @@ -80,6 +81,7 @@ foreach(cn_target IN LISTS cudax_TARGETS) cudax_add_catch2_test(test_target misc_tests ${cn_target} utility/driver_api.cu + utility/ensure_current_device.cu ) cudax_add_catch2_test(test_target containers ${cn_target} diff --git a/cudax/test/common/utility.cuh b/cudax/test/common/utility.cuh index 2d7254c0699..64a54e1b480 100644 --- a/cudax/test/common/utility.cuh +++ b/cudax/test/common/utility.cuh @@ -137,6 +137,11 @@ struct spin_until_80 } }; +struct empty_kernel +{ + __device__ void operator()() const noexcept {} +}; + /// A kernel that takes a callable object and invokes it with a set of arguments template __global__ void invokernel(Fn fn, Args... args) @@ -144,5 +149,28 @@ __global__ void invokernel(Fn fn, Args... args) fn(args...); } +inline int count_driver_stack() +{ + if (cudax::detail::driver::ctxGetCurrent() != nullptr) + { + auto ctx = cudax::detail::driver::ctxPop(); + auto result = 1 + count_driver_stack(); + cudax::detail::driver::ctxPush(ctx); + return result; + } + else + { + return 0; + } +} + +inline void empty_driver_stack() +{ + while (cudax::detail::driver::ctxGetCurrent() != nullptr) + { + cudax::detail::driver::ctxPop(); + } +} + } // namespace test } // namespace diff --git a/cudax/test/device/device_smoke.cu b/cudax/test/device/device_smoke.cu index 86c9625e21c..6f772de08aa 100644 --- a/cudax/test/device/device_smoke.cu +++ b/cudax/test/device/device_smoke.cu @@ -8,7 +8,6 @@ // //===----------------------------------------------------------------------===// -#define LIBCUDACXX_ENABLE_EXCEPTIONS #include #include "../hierarchy/testing_common.cuh" @@ -260,9 +259,9 @@ TEST_CASE("global devices vector", "[device]") CUDAX_REQUIRE(1 == std::next(cudax::devices.begin())->get()); CUDAX_REQUIRE(1 == cudax::devices.begin()[1].get()); - CUDAX_REQUIRE(0 == (*std::prev(cudax::devices.end())).get()); - CUDAX_REQUIRE(0 == std::prev(cudax::devices.end())->get()); - CUDAX_REQUIRE(0 == cudax::devices.end()[-1].get()); + CUDAX_REQUIRE(cudax::devices.size() - 1 == (*std::prev(cudax::devices.end())).get()); + CUDAX_REQUIRE(cudax::devices.size() - 1 == std::prev(cudax::devices.end())->get()); + CUDAX_REQUIRE(cudax::devices.size() - 1 == cudax::devices.end()[-1].get()); } try diff --git a/cudax/test/launch/configuration.cu b/cudax/test/launch/configuration.cu index a47eea25908..9e7f98df1b0 100644 --- a/cudax/test/launch/configuration.cu +++ b/cudax/test/launch/configuration.cu @@ -8,7 +8,6 @@ // //===----------------------------------------------------------------------===// -#define LIBCUDACXX_ENABLE_EXCEPTIONS // Test translation of launch function arguments to cudaLaunchConfig_t sent to cudaLaunchKernelEx internally // We replace cudaLaunchKernelEx with a test function here through a macro to intercept the cudaLaunchConfig_t #define cudaLaunchKernelEx cudaLaunchKernelExTestReplacement diff --git a/cudax/test/launch/launch_smoke.cu b/cudax/test/launch/launch_smoke.cu index 554cabd015c..810e65c3908 100644 --- a/cudax/test/launch/launch_smoke.cu +++ b/cudax/test/launch/launch_smoke.cu @@ -7,7 +7,6 @@ // SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. // //===----------------------------------------------------------------------===// -#define LIBCUDACXX_ENABLE_EXCEPTIONS #include #include diff --git a/cudax/test/stream/get_stream.cu b/cudax/test/stream/get_stream.cu index 277a10246af..0654c3be393 100644 --- a/cudax/test/stream/get_stream.cu +++ b/cudax/test/stream/get_stream.cu @@ -8,7 +8,6 @@ // //===----------------------------------------------------------------------===// -#define LIBCUDACXX_ENABLE_EXCEPTIONS #include #include "../common/utility.cuh" diff --git a/cudax/test/stream/stream_smoke.cu b/cudax/test/stream/stream_smoke.cu index e6b86ccf16f..cbee3520806 100644 --- a/cudax/test/stream/stream_smoke.cu +++ b/cudax/test/stream/stream_smoke.cu @@ -8,7 +8,6 @@ // //===----------------------------------------------------------------------===// -#define LIBCUDACXX_ENABLE_EXCEPTIONS #include #include diff --git a/cudax/test/utility/driver_api.cu b/cudax/test/utility/driver_api.cu index 513d6476eb5..e5fd64d14f2 100644 --- a/cudax/test/utility/driver_api.cu +++ b/cudax/test/utility/driver_api.cu @@ -7,14 +7,14 @@ // SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. // //===----------------------------------------------------------------------===// -#define LIBCUDACXX_ENABLE_EXCEPTIONS #include #include "../hierarchy/testing_common.cuh" -TEST_CASE("Call each one", "[driver api]") +TEST_CASE("Call each driver api", "[utility]") { + namespace driver = cuda::experimental::detail::driver; cudaStream_t stream; // Assumes the ctx stack was empty or had one ctx, should be the case unless some other // test leaves 2+ ctxs on the stack @@ -22,23 +22,48 @@ TEST_CASE("Call each one", "[driver api]") // Pushes the primary context if the stack is empty CUDART(cudaStreamCreate(&stream)); - auto ctx = cuda::experimental::detail::driver::ctxGetCurrent(); + auto ctx = driver::ctxGetCurrent(); CUDAX_REQUIRE(ctx != nullptr); - cuda::experimental::detail::driver::ctxPop(); - CUDAX_REQUIRE(cuda::experimental::detail::driver::ctxGetCurrent() == nullptr); + // Confirm pop will leave the stack empty + driver::ctxPop(); + CUDAX_REQUIRE(driver::ctxGetCurrent() == nullptr); - cuda::experimental::detail::driver::ctxPush(ctx); - CUDAX_REQUIRE(cuda::experimental::detail::driver::ctxGetCurrent() == ctx); + // Confirm we can push multiple times + driver::ctxPush(ctx); + CUDAX_REQUIRE(driver::ctxGetCurrent() == ctx); - cuda::experimental::detail::driver::ctxPush(ctx); - CUDAX_REQUIRE(cuda::experimental::detail::driver::ctxGetCurrent() == ctx); + driver::ctxPush(ctx); + CUDAX_REQUIRE(driver::ctxGetCurrent() == ctx); - cuda::experimental::detail::driver::ctxPop(); - CUDAX_REQUIRE(cuda::experimental::detail::driver::ctxGetCurrent() == ctx); + driver::ctxPop(); + CUDAX_REQUIRE(driver::ctxGetCurrent() == ctx); - auto stream_ctx = cuda::experimental::detail::driver::streamGetCtx(stream); + // Confirm stream ctx match + auto stream_ctx = driver::streamGetCtx(stream); CUDAX_REQUIRE(ctx == stream_ctx); CUDART(cudaStreamDestroy(stream)); + + CUDAX_REQUIRE(driver::deviceGet(0) == 0); + + // Confirm we can retain the primary ctx that cudart retained first + auto primary_ctx = driver::primaryCtxRetain(0); + CUDAX_REQUIRE(ctx == primary_ctx); + + driver::ctxPop(); + CUDAX_REQUIRE(driver::ctxGetCurrent() == nullptr); + + CUDAX_REQUIRE(driver::isPrimaryCtxActive(0)); + // Confirm we can reset the primary context with double release + driver::primaryCtxRelease(0); + driver::primaryCtxRelease(0); + + CUDAX_REQUIRE(!driver::isPrimaryCtxActive(0)); + + // Confirm cudart can recover + CUDART(cudaStreamCreate(&stream)); + CUDAX_REQUIRE(driver::ctxGetCurrent() == ctx); + + CUDART(driver::streamDestroy(stream)); } diff --git a/cudax/test/utility/ensure_current_device.cu b/cudax/test/utility/ensure_current_device.cu new file mode 100644 index 00000000000..89efc7d4f6c --- /dev/null +++ b/cudax/test/utility/ensure_current_device.cu @@ -0,0 +1,135 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include "../common/utility.cuh" + +namespace driver = cuda::experimental::detail::driver; + +void recursive_check_device_setter(int id) +{ + int cudart_id; + cudax::__ensure_current_device setter(cudax::device_ref{id}); + CUDAX_REQUIRE(test::count_driver_stack() == cudax::devices.size() - id); + auto ctx = driver::ctxGetCurrent(); + CUDART(cudaGetDevice(&cudart_id)); + CUDAX_REQUIRE(cudart_id == id); + + if (id != 0) + { + recursive_check_device_setter(id - 1); + + CUDAX_REQUIRE(test::count_driver_stack() == cudax::devices.size() - id); + CUDAX_REQUIRE(ctx == driver::ctxGetCurrent()); + CUDART(cudaGetDevice(&cudart_id)); + CUDAX_REQUIRE(cudart_id == id); + } +} + +TEST_CASE("ensure current device", "[device]") +{ + test::empty_driver_stack(); + // If possible use something different than CUDART default 0 + int target_device = static_cast(cudax::devices.size() - 1); + int dev_id = 0; + + SECTION("device setter") + { + recursive_check_device_setter(target_device); + + CUDAX_REQUIRE(test::count_driver_stack() == 0); + } + + SECTION("stream interactions with driver stack") + { + { + cudax::stream stream(target_device); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + { + cudax::__ensure_current_device setter(cudax::device_ref{target_device}); + CUDAX_REQUIRE(driver::ctxGetCurrent() == driver::streamGetCtx(stream.get())); + } + { + auto ev = stream.record_event(); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + } + CUDAX_REQUIRE(test::count_driver_stack() == 0); + { + auto ev = stream.record_timed_event(); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + } + { + auto lambda = [&](int dev_id) { + cudax::stream another_stream(dev_id); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + stream.wait(another_stream); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + another_stream.wait(stream); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + }; + lambda(target_device); + if (cudax::devices.size() > 1) + { + lambda(0); + } + } + + cudax::__ensure_current_device setter(stream); + CUDAX_REQUIRE(test::count_driver_stack() == 1); + CUDART(cudaGetDevice(&dev_id)); + CUDAX_REQUIRE(dev_id == target_device); + CUDAX_REQUIRE(driver::ctxGetCurrent() == driver::streamGetCtx(stream.get())); + } + + CHECK(test::count_driver_stack() == 0); + + { + // Check NULL stream ref is handled ok + cudax::__ensure_current_device setter1(cudax::device_ref{target_device}); + cudaStream_t null_stream = nullptr; + auto ref = cuda::stream_ref(null_stream); + auto ctx = driver::ctxGetCurrent(); + CUDAX_REQUIRE(test::count_driver_stack() == 1); + + cudax::__ensure_current_device setter2(ref); + CUDAX_REQUIRE(test::count_driver_stack() == 2); + CUDAX_REQUIRE(ctx == driver::ctxGetCurrent()); + CUDART(cudaGetDevice(&dev_id)); + CUDAX_REQUIRE(dev_id == target_device); + } + } + + SECTION("event interactions with driver stack") + { + { + cudax::stream stream(target_device); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + + cudax::event event(stream); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + + event.record(stream); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + } + CUDAX_REQUIRE(test::count_driver_stack() == 0); + } + + SECTION("launch interactions with driver stack") + { + cudax::stream stream(target_device); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + cudax::launch(stream, cudax::make_hierarchy(cudax::block_dims<1>(), cudax::grid_dims<1>()), test::empty_kernel{}); + CUDAX_REQUIRE(test::count_driver_stack() == 0); + } +}