-
Notifications
You must be signed in to change notification settings - Fork 188
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CUDAX] Branch out an experimental version of stream_ref (#2343)
* Branch out experimental version of stream_ref * Add tests for the experimental part of stream_ref * Move inequality check * typo * Remove not needed using declaration * Add a TODO to remove NULL stream_ref * Remove TODO and remove NULL stream ref constructor * move runtime api include after the system header decl Co-authored-by: Michael Schellenberger Costa <[email protected]>
- Loading branch information
Showing
3 changed files
with
121 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
107 changes: 107 additions & 0 deletions
107
cudax/include/cuda/experimental/__stream/stream_ref.cuh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of CUDA Experimental in CUDA C++ Core Libraries, | ||
// under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef _CUDAX__STREAM_STREAM_REF | ||
#define _CUDAX__STREAM_STREAM_REF | ||
|
||
#include <cuda/std/detail/__config> | ||
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) | ||
# pragma GCC system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) | ||
# pragma clang system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) | ||
# pragma system_header | ||
#endif // no system header | ||
|
||
#include <cuda_runtime_api.h> | ||
|
||
#include <cuda/std/__cuda/api_wrapper.h> | ||
#include <cuda/stream_ref> | ||
|
||
#include <cuda/experimental/__device/device_ref.cuh> | ||
#include <cuda/experimental/__event/timed_event.cuh> | ||
#include <cuda/experimental/__utility/ensure_current_device.cuh> | ||
|
||
namespace cuda::experimental | ||
{ | ||
|
||
//! @brief A non-owning wrapper for cudaStream_t. | ||
struct stream_ref : ::cuda::stream_ref | ||
{ | ||
using ::cuda::stream_ref::stream_ref; | ||
|
||
stream_ref() = delete; | ||
|
||
//! @brief Create a new event and record it into this stream | ||
//! | ||
//! @return A new event that was recorded into this stream | ||
//! | ||
//! @throws cuda_error if event creation or record failed | ||
_CCCL_NODISCARD event record_event(event::flags __flags = event::flags::none) const | ||
{ | ||
return event(*this, __flags); | ||
} | ||
|
||
//! @brief Create a new timed event and record it into this stream | ||
//! | ||
//! @return A new timed event that was recorded into this stream | ||
//! | ||
//! @throws cuda_error if event creation or record failed | ||
_CCCL_NODISCARD timed_event record_timed_event(event::flags __flags = event::flags::none) const | ||
{ | ||
return timed_event(*this, __flags); | ||
} | ||
|
||
using ::cuda::stream_ref::wait; | ||
|
||
//! @brief Make all future work submitted into this stream depend on completion of the specified event | ||
//! | ||
//! @param __ev Event that this stream should wait for | ||
//! | ||
//! @throws cuda_error if inserting the dependency fails | ||
void wait(event_ref __ev) const | ||
{ | ||
assert(__ev.get() != nullptr); | ||
// Need to use driver API, cudaStreamWaitEvent would push dev 0 if stack was empty | ||
detail::driver::streamWaitEvent(get(), __ev.get()); | ||
} | ||
|
||
//! @brief Make all future work submitted into this stream depend on completion of all work from the specified | ||
//! stream | ||
//! | ||
//! @param __other Stream that this stream should wait for | ||
//! | ||
//! @throws cuda_error if inserting the dependency fails | ||
void wait(stream_ref __other) const | ||
{ | ||
// TODO consider an optimization to not create an event every time and instead have one persistent event or one | ||
// per stream | ||
assert(__stream != detail::invalid_stream); | ||
event __tmp(__other); | ||
wait(__tmp); | ||
} | ||
|
||
//! @brief Get device under which this stream was created. | ||
//! | ||
//! @throws cuda_error if device check fails | ||
device_ref device() const | ||
{ | ||
// Because the stream can come from_native_handle, we can't just loop over devices comparing contexts, | ||
// lower to CUDART for this instead | ||
__ensure_current_device __dev_setter(*this); | ||
int result; | ||
_CCCL_TRY_CUDA_API(cudaGetDevice, "Could not get device from a stream", &result); | ||
return result; | ||
} | ||
}; | ||
|
||
} // namespace cuda::experimental | ||
|
||
#endif // _CUDAX__STREAM_STREAM_REF |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters