Skip to content

Commit

Permalink
Fix a sync bug in stream_ref::wait (#1238)
Browse files Browse the repository at this point in the history
We were calling the wrong function
  • Loading branch information
PointKernel authored Jan 16, 2024
1 parent 2b38693 commit d2d809d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/stream_ref
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public:
*/
void wait() const
{
const auto __result = ::cudaStreamQuery(get());
const auto __result = ::cudaStreamSynchronize(get());
switch (__result)
{
case ::cudaSuccess:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@

#include <cuda/stream_ref>
#include <cuda/std/cassert>
#include <atomic>
#include <chrono>
#include <thread>

void CUDART_CB callback(cudaStream_t, cudaError_t, void* flag)
{
std::chrono::milliseconds sleep_duration{1000};
std::this_thread::sleep_for(sleep_duration);
assert(!reinterpret_cast<std::atomic_flag*>(flag)->test_and_set());
}

void test_wait(cuda::stream_ref& ref) {
#ifndef _LIBCUDACXX_NO_EXCEPTIONS
Expand All @@ -31,8 +41,11 @@ int main(int argc, char** argv) {
NV_IF_TARGET(NV_IS_HOST,( // passing case
cudaStream_t stream;
cudaStreamCreate(&stream);
std::atomic_flag flag = ATOMIC_FLAG_INIT;
cudaStreamAddCallback(stream, callback, &flag, 0);
cuda::stream_ref ref{stream};
test_wait(ref);
assert(flag.test_and_set());
cudaStreamDestroy(stream);
))

Expand Down

0 comments on commit d2d809d

Please sign in to comment.