From d2d809dda14029f04826b72c169adc925cbe7525 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 16 Jan 2024 01:19:18 -0800 Subject: [PATCH] Fix a sync bug in `stream_ref::wait` (#1238) We were calling the wrong function --- libcudacxx/include/cuda/stream_ref | 2 +- .../cuda/stream_ref/stream_ref.wait.pass.cpp | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/libcudacxx/include/cuda/stream_ref b/libcudacxx/include/cuda/stream_ref index 08937f56bf5..b2a83c463a9 100644 --- a/libcudacxx/include/cuda/stream_ref +++ b/libcudacxx/include/cuda/stream_ref @@ -139,7 +139,7 @@ public: */ void wait() const { - const auto __result = ::cudaStreamQuery(get()); + const auto __result = ::cudaStreamSynchronize(get()); switch (__result) { case ::cudaSuccess: diff --git a/libcudacxx/test/libcudacxx/cuda/stream_ref/stream_ref.wait.pass.cpp b/libcudacxx/test/libcudacxx/cuda/stream_ref/stream_ref.wait.pass.cpp index 3a52af0b7d9..963a3fa533c 100644 --- a/libcudacxx/test/libcudacxx/cuda/stream_ref/stream_ref.wait.pass.cpp +++ b/libcudacxx/test/libcudacxx/cuda/stream_ref/stream_ref.wait.pass.cpp @@ -14,6 +14,16 @@ #include #include +#include +#include +#include + +void CUDART_CB callback(cudaStream_t, cudaError_t, void* flag) +{ + std::chrono::milliseconds sleep_duration{1000}; + std::this_thread::sleep_for(sleep_duration); + assert(!reinterpret_cast(flag)->test_and_set()); +} void test_wait(cuda::stream_ref& ref) { #ifndef _LIBCUDACXX_NO_EXCEPTIONS @@ -31,8 +41,11 @@ int main(int argc, char** argv) { NV_IF_TARGET(NV_IS_HOST,( // passing case cudaStream_t stream; cudaStreamCreate(&stream); + std::atomic_flag flag = ATOMIC_FLAG_INIT; + cudaStreamAddCallback(stream, callback, &flag, 0); cuda::stream_ref ref{stream}; test_wait(ref); + assert(flag.test_and_set()); cudaStreamDestroy(stream); ))