Add destroy_context(Device) and call it from finalize() in replica_pool.h
right after the replica is reset. For Device::CUDA it calls
cuda::destroy_handles(), which resets the calling thread's cuBLAS handle and,
when CT2_WITH_CUDNN is defined, its cuDNN handle. The handles become
file-level static thread_local unique_ptrs that are created on first use, and
a failing cublasDestroy/cudnnDestroy is now reported through spdlog::error.

diff --git a/include/ctranslate2/devices.h b/include/ctranslate2/devices.h
index 2691efc3a..c2e27b329 100644
--- a/include/ctranslate2/devices.h
+++ b/include/ctranslate2/devices.h
@@ -22,6 +22,8 @@ namespace ctranslate2 {
   void synchronize_device(Device device, int index);
   void synchronize_stream(Device device);
 
+  void destroy_context(Device device);
+
   class ScopedDeviceSetter {
   public:
     ScopedDeviceSetter(Device device, int index)
diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h
index 7ed39cde8..52b5626f2 100644
--- a/include/ctranslate2/replica_pool.h
+++ b/include/ctranslate2/replica_pool.h
@@ -352,6 +352,8 @@ namespace ctranslate2 {
 
     void finalize() override {
       _replica.reset();
+
+      destroy_context(_device);
     }
 
   private:
diff --git a/src/cuda/utils.cc b/src/cuda/utils.cc
index 2964c7ec5..d74f02e1d 100644
--- a/src/cuda/utils.cc
+++ b/src/cuda/utils.cc
@@ -6,6 +6,8 @@
 #include
 #include
 
+#include <spdlog/spdlog.h>
+
 #include "ctranslate2/utils.h"
 
 #include "env.h"
@@ -81,7 +83,11 @@ namespace ctranslate2 {
       }
       ~CublasHandle() {
         ScopedDeviceSetter scoped_device_setter(Device::CUDA, _device);
-        cublasDestroy(_handle);
+        cublasStatus_t status = cublasDestroy(_handle);
+
+        if (status != CUBLAS_STATUS_SUCCESS)
+          spdlog::error("cublasDestroy failed with status "
+                        + std::string(cuda::cublasGetStatusName(status)));
       }
       cublasHandle_t get() const {
         return _handle;
@@ -92,16 +98,20 @@ namespace ctranslate2 {
     };
 
     // We create one cuBLAS/cuDNN handle per host thread. The handle is destroyed
-    // when the thread exits.
+    // when the thread exits or when destroy_handles is called.
 
     cudaStream_t get_cuda_stream() {
       static thread_local CudaStream cuda_stream;
       return cuda_stream.get();
     }
 
+    static thread_local std::unique_ptr<CublasHandle> cublas_handle;
+
     cublasHandle_t get_cublas_handle() {
-      static thread_local CublasHandle cublas_handle;
-      return cublas_handle.get();
+      if (!cublas_handle)
+        cublas_handle = std::make_unique<CublasHandle>();
+
+      return cublas_handle->get();
     }
 
 #ifdef CT2_WITH_CUDNN
@@ -114,7 +124,11 @@ namespace ctranslate2 {
       }
       ~CudnnHandle() {
         ScopedDeviceSetter scoped_device_setter(Device::CUDA, _device);
-        cudnnDestroy(_handle);
+        cudnnStatus_t status = cudnnDestroy(_handle);
+
+        if (status != CUDNN_STATUS_SUCCESS)
+          spdlog::error("cudnnDestroy failed with status "
+                        + std::string(cudnnGetErrorString(status)));
       }
       cudnnHandle_t get() const {
         return _handle;
@@ -124,9 +138,13 @@ namespace ctranslate2 {
       cudnnHandle_t _handle;
     };
 
+    static thread_local std::unique_ptr<CudnnHandle> cudnn_handle;
+
     cudnnHandle_t get_cudnn_handle() {
-      static thread_local CudnnHandle cudnn_handle;
-      return cudnn_handle.get();
+      if (!cudnn_handle)
+        cudnn_handle = std::make_unique<CudnnHandle>();
+
+      return cudnn_handle->get();
     }
 
     cudnnDataType_t get_cudnn_data_type(DataType dtype) {
@@ -145,6 +163,14 @@ namespace ctranslate2 {
     }
 #endif
 
+    void destroy_handles() {
+#ifdef CT2_WITH_CUDNN
+      cudnn_handle.reset();
+#endif
+
+      cublas_handle.reset();
+    }
+
     int get_gpu_count() {
       int gpu_count = 0;
       cudaError_t status = cudaGetDeviceCount(&gpu_count);
diff --git a/src/cuda/utils.h b/src/cuda/utils.h
index 29bc99a39..4f2025038 100644
--- a/src/cuda/utils.h
+++ b/src/cuda/utils.h
@@ -49,6 +49,9 @@ namespace ctranslate2 {
     cudnnDataType_t get_cudnn_data_type(DataType dtype);
 #endif
 
+    // Destroy cuBLAS and cuDNN handles for the current thread.
+    void destroy_handles();
+
     int get_gpu_count();
     bool has_gpu();
     const cudaDeviceProp& get_device_properties(int device = -1);
diff --git a/src/devices.cc b/src/devices.cc
index 3822cc3c3..6c4fe8cc1 100644
--- a/src/devices.cc
+++ b/src/devices.cc
@@ -116,4 +116,14 @@ namespace ctranslate2 {
 #endif
   }
 
+  void destroy_context(Device device) {
+#ifdef CT2_WITH_CUDA
+    if (device == Device::CUDA) {
+      cuda::destroy_handles();
+    }
+#else
+    (void)device;
+#endif
+  }
+
 }