diff --git a/src/infer_response.cc b/src/infer_response.cc
index afadc324..cdc4ad89 100644
--- a/src/infer_response.cc
+++ b/src/infer_response.cc
@@ -299,7 +299,7 @@ InferResponse::Send(
           output_buffer = PbMemory::Create(
               shm_pool, actual_memory_type, actual_memory_type_id,
               output_tensor->ByteSize(), reinterpret_cast(buffer),
-              false /* copy_gpu */));
+              false /* copy_gpu */, true /* write_back_data */));
       output_buffer->SetCudaIpcHandle(cuda_ipc_mem_handle_p);
     } else {
       SET_ERROR_AND_RETURN_IF_EXCEPTION(
@@ -307,7 +307,7 @@ InferResponse::Send(
           output_buffer = PbMemory::Create(
               shm_pool, actual_memory_type, actual_memory_type_id,
               output_tensor->ByteSize(), reinterpret_cast(buffer),
-              true /* copy_gpu */));
+              true /* copy_gpu */, true /* write_back_data */));
     }
     gpu_buffer_helper.AddBuffer(output_buffer->ShmHandle());
     output_buffers.push_back({std::move(output_buffer), buffer});
diff --git a/src/ipc_message.h b/src/ipc_message.h
index 7040f2b4..0ac32879 100644
--- a/src/ipc_message.h
+++ b/src/ipc_message.h
@@ -41,6 +41,8 @@ typedef enum PYTHONSTUB_commandtype_enum {
   PYTHONSTUB_ExecuteResponse,
   PYTHONSTUB_InitializeRequest,
   PYTHONSTUB_InitializeResponse,
+  PYTHONSTUB_CUDAPoolInitializeRequest,
+  PYTHONSTUB_CUDAPoolInitializeResponse,
   PYTHONSTUB_FinalizeRequest,
   PYTHONSTUB_FinalizeResponse,
   PYTHONSTUB_LoadGPUBuffers,
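The two command values added here carry a one-shot handshake: the parent sends the IPC handle for Triton's CUDA memory pool, and the stub maps it and reports success or failure. A condensed, illustrative view of what travels in each direction is sketched below; `CudaPoolHandshake` is a hypothetical type, not something added by this diff (the real code keeps the request in `CUDAMemPoolMessage`, defined later in pb_utils.h, and the response in the `response_has_error` / `response_is_error_set` / `response_error` fields filled in pb_stub.cc).

```cpp
#include <cstdint>
#include <cuda_runtime_api.h>

// Illustrative condensation of the new parent <-> stub handshake.
struct CudaPoolHandshake {
  // Parent -> stub (PYTHONSTUB_CUDAPoolInitializeRequest): IPC handle for the
  // base of Triton's CUDA memory pool.
  cudaIpcMemHandle_t cuda_handle;

  // Stub -> parent (PYTHONSTUB_CUDAPoolInitializeResponse).
  bool response_has_error;     // the stub could not open or map the handle
  bool response_is_error_set;  // an error string was serialized as well
  uint64_t response_error;     // shared-memory handle of that error string
};
```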
error: " + + std::string(cudaGetErrorString(err))) + .c_str()); + } + } +} #endif std::unique_ptr PbMemory::Create( + std::unique_ptr& shm_pool, TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, uint64_t byte_size, char* data, char* data_shm, bi::managed_external_buffer::handle_t handle, bool copy_gpu) { + void* backend_memory; PbMemory::FillShmData( - memory_type, memory_type_id, byte_size, data, data_shm, handle, copy_gpu); + shm_pool, &backend_memory, memory_type, memory_type_id, byte_size, data, + data_shm, handle, copy_gpu); if (memory_type == TRITONSERVER_MEMORY_CPU) { data = data_shm + sizeof(MemoryShm); @@ -99,6 +140,13 @@ PbMemory::Create( #ifdef TRITON_ENABLE_GPU if (memory_type == TRITONSERVER_MEMORY_GPU) { +#ifndef TRITON_PB_STUB + if (pb_memory->memory_shm_ptr_->use_cuda_shared_pool) { + pb_memory->backend_memory_.reset( + reinterpret_cast(backend_memory)); + } +#endif + pb_memory->memory_shm_ptr_->gpu_pointer_offset = pb_memory->GetGPUPointerOffset(); } @@ -176,14 +224,17 @@ PbMemory::CopyBuffer( void PbMemory::FillShmData( + std::unique_ptr& shm_pool, void** backend_memory, TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, uint64_t byte_size, char* data, char* data_shm, - bi::managed_external_buffer::handle_t handle, bool copy_gpu) + bi::managed_external_buffer::handle_t handle, bool copy_gpu, + bool write_back_data) { char* memory_data_shm = data_shm + sizeof(MemoryShm); MemoryShm* memory_shm_ptr = reinterpret_cast(data_shm); - memory_shm_ptr->is_cuda_handle_set = copy_gpu; memory_shm_ptr->memory_release_id = 0; + bool use_cuda_shared_pool = false; + *backend_memory = nullptr; if (memory_type == TRITONSERVER_MEMORY_GPU) { #ifdef TRITON_ENABLE_GPU @@ -193,8 +244,62 @@ PbMemory::FillShmData( THROW_IF_CUDA_ERROR(cudaIpcGetMemHandle( reinterpret_cast(memory_data_shm), data)); } +#ifndef TRITON_PB_STUB + // Check if the data is already in the pool by checking the base address. + CUDAHandler& cuda_api = CUDAHandler::getInstance(); + CUdeviceptr cuda_pool_address = 0; + cuda_api.PointerGetAttribute( + &cuda_pool_address, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + reinterpret_cast(data)); + if (shm_pool->CUDAPoolAddress() == + reinterpret_cast(cuda_pool_address)) { + use_cuda_shared_pool = true; + memory_shm_ptr->cuda_pool_offset = + data - reinterpret_cast(shm_pool->CUDAPoolAddress()); + } else { + TRITONSERVER_Error* error = BackendMemory::Create( + reinterpret_cast( + shm_pool->TritonMemoryManager()), + BackendMemory::AllocationType::GPU_POOL, memory_type_id, byte_size, + reinterpret_cast(backend_memory)); + if (error != nullptr) { + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + (std::string( + "Failed to allocate memory from CUDA memory pool: ") + + TRITONSERVER_ErrorMessage(error)) + .c_str()); + } else { + // Copy the data to the new buffer in the CUDA pool. + cudaMemcpyKind kind = cudaMemcpyDeviceToDevice; + cudaError_t err; + err = cudaMemcpy( + (reinterpret_cast(*backend_memory))->MemoryPtr(), + data, byte_size, kind); + if (err != cudaSuccess) { + throw PythonBackendException( + std::string( + "failed to copy data: " + + std::string(cudaGetErrorString(err))) + .c_str()); + } + err = cudaStreamSynchronize(0); + if (err != cudaSuccess) { + throw PythonBackendException( + std::string( + "failed to synchronize the default CUDA stream. 
error: " + + std::string(cudaGetErrorString(err))) + .c_str()); + } + use_cuda_shared_pool = true; + memory_shm_ptr->cuda_pool_offset = + (reinterpret_cast(*backend_memory))->MemoryPtr() - + reinterpret_cast(shm_pool->CUDAPoolAddress()); + } + } +#endif // Not TRITON_PB_STUB } -#endif +#endif // TRITON_ENABLE_GPU } else { if (data != nullptr) { std::copy(data, data + byte_size, memory_data_shm); @@ -204,10 +309,12 @@ PbMemory::FillShmData( memory_shm_ptr->byte_size = byte_size; memory_shm_ptr->memory_type_id = memory_type_id; memory_shm_ptr->memory_type = memory_type; + memory_shm_ptr->use_cuda_shared_pool = use_cuda_shared_pool; } std::unique_ptr PbMemory::LoadFromSharedMemory( + std::unique_ptr& shm_pool, bi::managed_external_buffer::handle_t handle, char* data_shm, bool open_cuda_handle) { @@ -219,21 +326,32 @@ PbMemory::LoadFromSharedMemory( if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU && open_cuda_handle) { #ifdef TRITON_ENABLE_GPU - cudaIpcMemHandle_t* cuda_handle = - reinterpret_cast(memory_data_shm); + if (memory_shm_ptr->use_cuda_shared_pool) { +#ifdef TRITON_PB_STUB + // When CUDA shared memory pool is used, the stub will retrieve the + // data pointer using the offset. + data_ptr = + (reinterpret_cast(shm_pool->CUDAPoolAddress()) + + memory_shm_ptr->cuda_pool_offset); +#endif // TRITON_PB_STUB + } else { + cudaIpcMemHandle_t* cuda_handle = + reinterpret_cast(memory_data_shm); - // The pointer opened by the cudaIpcOpenMemHandle will refer to the base - // address. We need to manually correct the offset. - void* data_ptr_base; - CUDAHandler& cuda_handler = CUDAHandler::getInstance(); - cuda_handler.OpenCudaHandle( - memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base); + // The pointer opened by the cudaIpcOpenMemHandle will refer to the base + // address. We need to manually correct the offset. + void* data_ptr_base; + CUDAHandler& cuda_handler = CUDAHandler::getInstance(); + cuda_handler.OpenCudaHandle( + memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base); - data_ptr = - (reinterpret_cast(data_ptr_base) + - memory_shm_ptr->gpu_pointer_offset); - opened_cuda_ipc_handle = true; -#endif + data_ptr = + (reinterpret_cast(data_ptr_base) + + memory_shm_ptr->gpu_pointer_offset); + opened_cuda_ipc_handle = true; + } + +#endif // TRITON_ENABLE_GPU } else { data_ptr = memory_data_shm; } @@ -258,21 +376,30 @@ PbMemory::LoadFromSharedMemory( if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU) { if (memory_shm_ptr->byte_size > 0 && open_cuda_handle) { #ifdef TRITON_ENABLE_GPU - cudaIpcMemHandle_t* cuda_handle = - reinterpret_cast(memory_data_shm); - - // The pointer opened by the cudaIpcOpenMemHandle will refer to the base - // address. We need to manually correct the offset. - - void* data_ptr_base; - CUDAHandler& cuda_handler = CUDAHandler::getInstance(); - cuda_handler.OpenCudaHandle( - memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base); - - data_ptr = - (reinterpret_cast(data_ptr_base) + - memory_shm_ptr->gpu_pointer_offset); - opened_cuda_ipc_handle = true; + if (memory_shm_ptr->use_cuda_shared_pool) { +#ifdef TRITON_PB_STUB + // When CUDA shared memory pool is used, the stub will retrieve the + // data pointer using the offset. + data_ptr = + (reinterpret_cast(shm_pool->CUDAPoolAddress()) + + memory_shm_ptr->cuda_pool_offset); +#endif // TRITON_PB_STUB + } else { + cudaIpcMemHandle_t* cuda_handle = + reinterpret_cast(memory_data_shm); + + // The pointer opened by the cudaIpcOpenMemHandle will refer to the base + // address. 
@@ -258,21 +376,30 @@ PbMemory::LoadFromSharedMemory(
   if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU) {
     if (memory_shm_ptr->byte_size > 0 && open_cuda_handle) {
 #ifdef TRITON_ENABLE_GPU
-      cudaIpcMemHandle_t* cuda_handle =
-          reinterpret_cast(memory_data_shm);
-
-      // The pointer opened by the cudaIpcOpenMemHandle will refer to the base
-      // address. We need to manually correct the offset.
-
-      void* data_ptr_base;
-      CUDAHandler& cuda_handler = CUDAHandler::getInstance();
-      cuda_handler.OpenCudaHandle(
-          memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base);
-
-      data_ptr =
-          (reinterpret_cast(data_ptr_base) +
-           memory_shm_ptr->gpu_pointer_offset);
-      opened_cuda_ipc_handle = true;
+      if (memory_shm_ptr->use_cuda_shared_pool) {
+#ifdef TRITON_PB_STUB
+        // When CUDA shared memory pool is used, the stub will retrieve the
+        // data pointer using the offset.
+        data_ptr =
+            (reinterpret_cast(shm_pool->CUDAPoolAddress()) +
+             memory_shm_ptr->cuda_pool_offset);
+#endif  // TRITON_PB_STUB
+      } else {
+        cudaIpcMemHandle_t* cuda_handle =
+            reinterpret_cast(memory_data_shm);
+
+        // The pointer opened by the cudaIpcOpenMemHandle will refer to the base
+        // address. We need to manually correct the offset.
+        void* data_ptr_base;
+        CUDAHandler& cuda_handler = CUDAHandler::getInstance();
+        cuda_handler.OpenCudaHandle(
+            memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base);
+
+        data_ptr =
+            (reinterpret_cast(data_ptr_base) +
+             memory_shm_ptr->gpu_pointer_offset);
+        opened_cuda_ipc_handle = true;
+      }
 #endif
     }
   } else {
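In FillShmData above, the GPU path now distinguishes buffers that already live in Triton's CUDA memory pool (record only an offset) from foreign GPU buffers (stage a device-to-device copy into a fresh pool allocation, falling back to the per-tensor cudaIpcMemHandle_t if that allocation fails). A small sketch of the containment test follows; `PointerIsInsidePool` is a hypothetical helper, and `pool_base` stands in for the address returned by `shm_pool->CUDAPoolAddress()`.

```cpp
#include <cuda.h>

// Illustrative sketch of the base-address check used by FillShmData:
// CU_POINTER_ATTRIBUTE_RANGE_START_ADDR returns the start of the allocation
// that contains `ptr`; if that start equals the recorded pool base, the
// buffer already lives inside Triton's CUDA memory pool.
bool PointerIsInsidePool(const void* ptr, void* pool_base)
{
  CUdeviceptr range_start = 0;
  CUresult res = cuPointerGetAttribute(
      &range_start, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
      reinterpret_cast<CUdeviceptr>(ptr));
  if (res != CUDA_SUCCESS) {
    return false;  // treat pointers the driver cannot classify as outside
  }
  return reinterpret_cast<void*>(range_start) == pool_base;
}
```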
diff --git a/src/pb_memory.h b/src/pb_memory.h
index e7986014..6a4011df 100644
--- a/src/pb_memory.h
+++ b/src/pb_memory.h
@@ -1,4 +1,4 @@
-// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -42,13 +42,18 @@ namespace triton { namespace backend { namespace python {
 //
 struct MemoryShm {
   // If the memory type is a GPU pointer, the offset of the GPU pointer from the
-  // base address. For CPU memory type this field contains garbage data.
+  // base address. For CPU memory type this field contains garbage data. This
+  // field will only be used when the memory is not allocated from the CUDA
+  // shared memory pool.
   uint64_t gpu_pointer_offset;
+  bool use_cuda_shared_pool;
+  // The offset of the memory from the base address of the CUDA shared memory
+  // pool.
+  uint64_t cuda_pool_offset;
 
   TRITONSERVER_MemoryType memory_type;
   int64_t memory_type_id;
   uint64_t byte_size;
-  bool is_cuda_handle_set;
   uint64_t memory_release_id;
 };
 
@@ -57,9 +62,11 @@ class PbMemory {
   static std::unique_ptr Create(
       std::unique_ptr& shm_pool,
       TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
-      uint64_t byte_size, char* data, bool copy_gpu = true);
+      uint64_t byte_size, char* data, bool copy_gpu = true,
+      bool write_back_data = false);
 
   static std::unique_ptr Create(
+      std::unique_ptr& shm_pool,
       TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
       uint64_t byte_size, char* data, char* data_shm,
       bi::managed_external_buffer::handle_t handle, bool copy_gpu = true);
@@ -68,6 +75,10 @@ class PbMemory {
   static std::unique_ptr Create(
       std::unique_ptr& shm_pool,
       std::unique_ptr&& backend_memory, bool copy_gpu = true);
+
+  // Copy the data from the CUDA shared memory pool to the output buffer
+  // provided by Triton
+  void WriteBackOutput(std::unique_ptr& shm_pool);
 #endif
 
 #ifdef TRITON_ENABLE_GPU
@@ -83,6 +94,7 @@ class PbMemory {
       bi::managed_external_buffer::handle_t memory_handle,
       bool open_cuda_handle);
   static std::unique_ptr LoadFromSharedMemory(
+      std::unique_ptr& shm_pool,
       bi::managed_external_buffer::handle_t handle, char* data_shm,
       bool open_cuda_handle);
   static uint64_t ShmStructSize(
@@ -117,11 +129,17 @@ class PbMemory {
 
   void SetMemoryReleaseCallback(std::function release_callback);
 
+  bool UseCudaSharedPool() const
+  {
+    return memory_shm_ptr_->use_cuda_shared_pool;
+  }
+
   ~PbMemory();
 
  private:
   AllocatedSharedMemory memory_shm_;
   MemoryShm* memory_shm_ptr_;
+  uint64_t cuda_pool_offset_;
 
 #ifndef TRITON_PB_STUB
   std::unique_ptr backend_memory_;
@@ -133,6 +151,10 @@ class PbMemory {
   // the same as memory_data_shm_ptr_.
   char* data_ptr_;
 
+  // Store the buffer provided by Triton. This is used to write back the data
+  // from the CUDA shared memory pool to the original buffer.
+  char* original_buffer_;
+
   bi::managed_external_buffer::handle_t memory_shm_handle_;
 
   bool opened_cuda_ipc_handle_;
@@ -150,9 +172,11 @@ class PbMemory {
 #endif
 
   static void FillShmData(
+      std::unique_ptr& shm_pool, void** backend_memory,
      TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
       uint64_t byte_size, char* data, char* data_shm,
-      bi::managed_external_buffer::handle_t handle, bool copy_gpu = true);
+      bi::managed_external_buffer::handle_t handle, bool copy_gpu = true,
+      bool write_back_data = false);
 
   PbMemory(
       AllocatedSharedMemory& memory_shm, char* data,
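The new MemoryShm fields make a GPU tensor resolvable on the reading side in one of two ways: directly through the shared pool, or through the older per-tensor IPC handle. A rough sketch of that branch, using the field names above but a hypothetical free function and plain CUDA runtime calls instead of the backend's CUDAHandler wrapper:

```cpp
#include <cstdint>
#include <cuda_runtime_api.h>

// Hypothetical view of the fields relevant to pointer resolution.
struct MemoryShmView {
  uint64_t gpu_pointer_offset;  // used only on the IPC-handle path
  bool use_cuda_shared_pool;
  uint64_t cuda_pool_offset;    // used only on the pool path
};

// ResolveGpuPointer is illustrative only; it is not part of this diff.
char* ResolveGpuPointer(
    const MemoryShmView& shm, char* cuda_pool_base,
    const cudaIpcMemHandle_t& handle)
{
  if (shm.use_cuda_shared_pool) {
    // Pool path: both processes already map the pool, so the pointer is just
    // base + offset; no per-tensor cudaIpcOpenMemHandle call is needed.
    return cuda_pool_base + shm.cuda_pool_offset;
  }
  // Fallback path: open the per-tensor IPC handle and correct for the fact
  // that the opened pointer refers to the base of the original allocation.
  void* base = nullptr;
  if (cudaIpcOpenMemHandle(&base, handle, cudaIpcMemLazyEnablePeerAccess) !=
      cudaSuccess) {
    return nullptr;
  }
  return reinterpret_cast<char*>(base) + shm.gpu_pointer_offset;
}
```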
diff --git a/src/pb_stub.cc b/src/pb_stub.cc
index c5c6b42e..1586d67d 100644
--- a/src/pb_stub.cc
+++ b/src/pb_stub.cc
@@ -364,6 +364,51 @@ Stub::RunCommand()
     } break;
+    case PYTHONSTUB_CommandType::PYTHONSTUB_CUDAPoolInitializeRequest: {
+      bool has_exception = false;
+      std::string error_string;
+
+      std::unique_ptr response_msg =
+          IPCMessage::Create(shm_pool_, false /* inline_response */);
+      response_msg->Command() = PYTHONSTUB_CUDAPoolInitializeResponse;
+      std::unique_ptr error_string_shm;
+      AllocatedSharedMemory response =
+          shm_pool_->Construct();
+
+      ScopedDefer _([this, &response_msg] { SendIPCMessage(response_msg); });
+
+      response.data_->response_has_error = false;
+      response.data_->response_is_error_set = false;
+      response_msg->Args() = response.handle_;
+
+      try {
+        GetCUDAMemoryPoolAddress(ipc_message->Args());
+      }
+      catch (const PythonBackendException& pb_exception) {
+        has_exception = true;
+        error_string = pb_exception.what();
+      }
+
+      if (has_exception) {
+        // Do not delete the region. The region will be deleted by the parent
+        // process.
+        shm_pool_->SetDeleteRegion(false);
+        LOG_INFO
+            << "Failed to initialize CUDA shared memory pool in Python stub: "
+            << error_string;
+        response.data_->response_has_error = true;
+        response.data_->response_is_error_set = false;
+
+        LOG_IF_EXCEPTION(
+            error_string_shm = PbString::Create(shm_pool_, error_string));
+        if (error_string_shm != nullptr) {
+          response.data_->response_is_error_set = true;
+          response.data_->response_error = error_string_shm->ShmHandle();
+        }
+
+        return true;  // Terminate the stub process.
+      }
+    } break;
     default:
       break;
   }
@@ -515,6 +560,8 @@ Stub::Initialize(bi::managed_external_buffer::handle_t map_handle)
     model_config_params[pair.first.c_str()] = pair.second;
   }
 
+  device_id_ = std::stoi(map["model_instance_device_id"]);
+
   LaunchStubToParentQueueMonitor();
   LaunchParentToStubQueueMonitor();
@@ -870,6 +917,18 @@ Stub::SendIPCUtilsMessage(std::unique_ptr& ipc_message)
 
 Stub::~Stub()
 {
+#ifdef TRITON_ENABLE_GPU
+  if (shm_pool_->CUDAPoolAddress() != nullptr) {
+    try {
+      CUDAHandler& cuda_api = CUDAHandler::getInstance();
+      cuda_api.CloseCudaHandle(device_id_, shm_pool_->CUDAPoolAddress());
+    }
+    catch (const PythonBackendException& pb_exception) {
+      std::cerr << "Error when closing CUDA handle: " << pb_exception.what();
+    }
+  }
+#endif
+
   {
     py::gil_scoped_acquire acquire;
     model_instance_ = py::none();
@@ -1225,6 +1284,24 @@ Stub::GetProxyStream(const int& device_id)
 #endif
 }
 
+void
+Stub::GetCUDAMemoryPoolAddress(bi::managed_external_buffer::handle_t handle)
+{
+#ifdef TRITON_ENABLE_GPU
+  AllocatedSharedMemory cuda_handle_shm =
+      shm_pool_->Load(handle);
+  CUDAMemPoolMessage* cuda_handle_shm_ptr = cuda_handle_shm.data_.get();
+
+  CUDAHandler& cuda_api = CUDAHandler::getInstance();
+  void* cuda_pool_address;
+  cuda_api.OpenCudaHandle(
+      device_id_, &cuda_handle_shm_ptr->cuda_handle, &cuda_pool_address);
+  shm_pool_->SetCUDAPoolAddress(cuda_pool_address);
+#else
+  return nullptr;
+#endif
+}
+
 std::unique_ptr Logger::log_instance_;
 
 std::unique_ptr&
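On the stub side, the pool handle is opened exactly once (GetCUDAMemoryPoolAddress above) and closed once in ~Stub(); every tensor afterwards reuses the same mapping. A minimal sketch of that open/close lifecycle with direct CUDA runtime calls (the real code routes them through CUDAHandler; `CudaPoolMapping` is a hypothetical type):

```cpp
#include <cuda_runtime_api.h>

// Illustrative once-per-process lifecycle of the shared CUDA pool mapping.
struct CudaPoolMapping {
  void* base = nullptr;

  cudaError_t Open(int device_id, const cudaIpcMemHandle_t& handle)
  {
    cudaError_t err = cudaSetDevice(device_id);
    if (err != cudaSuccess) {
      return err;
    }
    // Maps the parent's pool allocation into this process; `base` is then
    // valid for base + cuda_pool_offset arithmetic on every tensor.
    return cudaIpcOpenMemHandle(&base, handle, cudaIpcMemLazyEnablePeerAccess);
  }

  cudaError_t Close()
  {
    if (base == nullptr) {
      return cudaSuccess;
    }
    cudaError_t err = cudaIpcCloseMemHandle(base);
    base = nullptr;
    return err;
  }
};
```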
diff --git a/src/pb_stub.h b/src/pb_stub.h
index 6d047d29..079138ec 100644
--- a/src/pb_stub.h
+++ b/src/pb_stub.h
@@ -200,7 +200,9 @@ struct UtilsMessagePayload {
 
 class Stub {
  public:
-  Stub() : stub_to_parent_thread_(false), parent_to_stub_thread_(false){};
+  Stub()
+      : device_id_(0), stub_to_parent_thread_(false),
+        parent_to_stub_thread_(false){};
   static std::unique_ptr& GetOrCreateInstance();
 
   /// Instantiate a new Python backend Stub.
@@ -336,7 +338,11 @@ class Stub {
   /// for provided device
   cudaStream_t GetProxyStream(const int& device_id);
 
+  /// Get the CUDA memory pool address from the parent process.
+  void GetCUDAMemoryPoolAddress(bi::managed_external_buffer::handle_t handle);
+
  private:
+  int32_t device_id_;
   bi::interprocess_mutex* stub_mutex_;
   bi::interprocess_condition* stub_cond_;
   bi::interprocess_mutex* parent_mutex_;
diff --git a/src/pb_tensor.cc b/src/pb_tensor.cc
index 84cd8f3f..d9d47784 100644
--- a/src/pb_tensor.cc
+++ b/src/pb_tensor.cc
@@ -555,7 +555,7 @@ PbTensor::SaveToSharedMemory(
 
   if (!pb_memory_) {
     pb_memory_ = PbMemory::Create(
-        memory_type_, memory_type_id_, byte_size_,
+        shm_pool, memory_type_, memory_type_id_, byte_size_,
         reinterpret_cast(memory_ptr_),
         reinterpret_cast(tensor_shm_ptr_) + pb_memory_offset,
         shm_handle_ + pb_memory_offset, copy_gpu);
@@ -585,7 +585,7 @@ PbTensor::LoadFromSharedMemory(
   if (tensor_shm_ptr->memory == 0) {
     std::size_t pb_memory_offset = name_offset + name_shm->Size();
     pb_memory = PbMemory::LoadFromSharedMemory(
-        pb_memory_offset, tensor_shm.data_.get() + pb_memory_offset,
+        shm_pool, pb_memory_offset, tensor_shm.data_.get() + pb_memory_offset,
         open_cuda_handle);
   } else {
     pb_memory = PbMemory::LoadFromSharedMemory(
diff --git a/src/pb_utils.h b/src/pb_utils.h
index 1d651f3f..f080736c 100644
--- a/src/pb_utils.h
+++ b/src/pb_utils.h
@@ -235,6 +235,10 @@ struct RequestBatch {
   bi::managed_external_buffer::handle_t gpu_buffers_handle;
 };
 
+struct CUDAMemPoolMessage {
+  cudaIpcMemHandle_t cuda_handle;
+};
+
 #ifdef TRITON_ENABLE_GPU
 class CUDAHandler {
  public:
diff --git a/src/python_be.cc b/src/python_be.cc
index 14e0c74b..c12c3b6f 100644
--- a/src/python_be.cc
+++ b/src/python_be.cc
@@ -407,6 +407,9 @@ ModelInstanceState::LaunchStubProcess()
   RETURN_IF_ERROR(Stub()->Setup());
   StartMonitor();
   RETURN_IF_ERROR(Stub()->Launch());
+#ifdef TRITON_ENABLE_GPU
+  Stub()->ShareCUDAMemoryPool(Model()->TritonMemoryManager());
+#endif  // TRITON_ENABLE_GPU
 
   thread_pool_ = std::make_unique(
       model_state->StateForBackend()->thread_pool_size);
@@ -568,7 +571,7 @@ ModelInstanceState::GetInputTensor(
       void* dev_ptr;
       RETURN_IF_CUDA_ERROR(
           cudaMalloc(&dev_ptr, input_byte_size), TRITONSERVER_ERROR_INTERNAL,
-          std::string("Failed to allocated CUDA memory"));
+          std::string("Failed to allocate CUDA memory"));
 
       size_t byte_size = input_byte_size;
 
@@ -1542,6 +1545,12 @@ ModelInstanceState::ProcessRequests(
             pb_memory->ByteSize(), pb_memory->DataPtr(), pointer, CudaStream(),
             &cuda_used));
         cuda_copy |= cuda_used;
+      } else if (
+          pb_memory->MemoryType() == TRITONSERVER_MEMORY_GPU &&
+          pb_memory->UseCudaSharedPool()) {
+        // Copy the data from the CUDA shared memory pool to the output buffer
+        // provided by Triton
+        pb_memory->WriteBackOutput(Stub()->ShmPool());
       }
     }
     response_index++;
@@ -2172,9 +2181,23 @@ TRITONBACKEND_ModelInstanceExecute(
         if (err == nullptr) {
           instance_state->StartMonitor();
         }
-        LOG_IF_ERROR(err, "Failed to restart the stub process.");
+        LOG_IF_ERROR(
+            err,
+            "Failed to restart the stub process: failed to start "
+            "the monitor.");
         err = instance_state->Stub()->Launch();
-        LOG_IF_ERROR(err, "Failed to restart the stub process.");
+        LOG_IF_ERROR(
+            err,
+            "Failed to restart the stub process: failed to launch "
+            "the stub process.");
+#ifdef TRITON_ENABLE_GPU
+        err = instance_state->Stub()->ShareCUDAMemoryPool(
+            instance_state->Model()->TritonMemoryManager());
+        LOG_IF_ERROR(
+            err,
+            "Failed to restart the stub process: failed to share "
+            "CUDA memory pool.");
+#endif  // TRITON_ENABLE_GPU
       }
     } else {
       std::vector> infer_requests;
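ProcessRequests above now finishes GPU outputs that were staged in the CUDA pool by copying them back into the buffer Triton actually allocated for the response. The sketch below shows what that write-back amounts to; `WriteBackToTritonBuffer` is a hypothetical free function standing in for `PbMemory::WriteBackOutput`, and like the diff it relies on the default stream.

```cpp
#include <cstddef>
#include <cuda_runtime_api.h>
#include <stdexcept>
#include <string>

// Copy a response tensor from its pool-backed staging buffer into the output
// buffer provided by Triton, then wait for the copy to complete.
void WriteBackToTritonBuffer(
    void* triton_output_buffer, const void* pool_staging_buffer,
    size_t byte_size)
{
  cudaError_t err = cudaMemcpy(
      triton_output_buffer, pool_staging_buffer, byte_size,
      cudaMemcpyDeviceToDevice);
  if (err != cudaSuccess) {
    throw std::runtime_error(
        std::string("failed to copy data: ") + cudaGetErrorString(err));
  }
  // Device-to-device copies are asynchronous with respect to the host, so
  // synchronize before the response buffer is handed back to Triton.
  err = cudaStreamSynchronize(0);
  if (err != cudaSuccess) {
    throw std::runtime_error(
        std::string("failed to synchronize the default CUDA stream: ") +
        cudaGetErrorString(err));
  }
}
```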
diff --git a/src/shm_manager.cc b/src/shm_manager.cc
index b52d5a4f..58c9fba1 100644
--- a/src/shm_manager.cc
+++ b/src/shm_manager.cc
@@ -36,6 +36,7 @@ namespace triton { namespace backend { namespace python {
 SharedMemoryManager::SharedMemoryManager(
     const std::string& shm_region_name, size_t shm_size,
     size_t shm_growth_bytes, bool create)
+    : cuda_pool_address_(nullptr), triton_memory_manager_(nullptr)
 {
   shm_region_name_ = shm_region_name;
   create_ = create;
@@ -95,6 +96,7 @@ SharedMemoryManager::SharedMemoryManager(
 }
 
 SharedMemoryManager::SharedMemoryManager(const std::string& shm_region_name)
+    : cuda_pool_address_(nullptr), triton_memory_manager_(nullptr)
 {
   shm_region_name_ = shm_region_name;
   create_ = false;
diff --git a/src/shm_manager.h b/src/shm_manager.h
index bd462403..0e047018 100644
--- a/src/shm_manager.h
+++ b/src/shm_manager.h
@@ -157,6 +157,20 @@ class SharedMemoryManager {
 
   void SetDeleteRegion(bool delete_region);
 
+  void SetCUDAPoolAddress(void* cuda_pool_address)
+  {
+    cuda_pool_address_ = cuda_pool_address;
+  }
+
+  void* CUDAPoolAddress() { return cuda_pool_address_; }
+
+  void SetTritonMemoryManager(void* triton_memory_manager)
+  {
+    triton_memory_manager_ = triton_memory_manager;
+  }
+
+  void* TritonMemoryManager() { return triton_memory_manager_; }
+
   ~SharedMemoryManager() noexcept(false);
 
  private:
@@ -171,6 +185,10 @@ class SharedMemoryManager {
   uint64_t* total_size_;
   bool create_;
   bool delete_region_;
+  // The base address of the Triton CUDA memory pool
+  void* cuda_pool_address_;
+  // TRITONBACKEND_MemoryManager
+  void* triton_memory_manager_;
 
   template 
   AllocatedSharedMemory WrapObjectInUniquePtr(
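ShareCUDAMemoryPool in the stub_launcher.cc change that follows discovers the pool base by making a throwaway GPU_POOL allocation, asking the driver for that allocation's range start, and then taking a CUDA IPC handle for the base so the stub can map the whole pool once. A sketch of that derivation, under the assumption that the pool is a single device allocation; `HandleForOwningPool` and `any_pool_allocation` are hypothetical names, with the latter standing in for the 1-byte dummy BackendMemory.

```cpp
#include <cuda.h>
#include <cuda_runtime_api.h>

// Derive an IPC handle for the pool that owns `any_pool_allocation`, and
// report the pool's base address back to the caller.
cudaError_t HandleForOwningPool(
    void* any_pool_allocation, cudaIpcMemHandle_t* out_handle,
    void** out_pool_base)
{
  CUdeviceptr base = 0;
  if (cuPointerGetAttribute(
          &base, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
          reinterpret_cast<CUdeviceptr>(any_pool_allocation)) != CUDA_SUCCESS) {
    return cudaErrorInvalidValue;
  }
  *out_pool_base = reinterpret_cast<void*>(base);
  // The IPC handle is taken for the base of the allocation, so the receiving
  // process can map the whole region once and address tensors by offset.
  return cudaIpcGetMemHandle(out_handle, *out_pool_base);
}
```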
diff --git a/src/stub_launcher.cc b/src/stub_launcher.cc
index de4dd46c..31dc441b 100644
--- a/src/stub_launcher.cc
+++ b/src/stub_launcher.cc
@@ -325,7 +325,7 @@ StubLauncher::Launch()
   // The reason it is broken into two steps is that creation of the health
   // monitoring thread may take longer which can make the server process think
   // that the stub process is unhealthy and return early. Waiting until the
-  // health thread is spawn would make sure would prevent this issue.
+  // health thread is spawn would prevent this issue.
   parent_message_queue_->Pop();
 
   if (stub_process_kind_ == "AUTOCOMPLETE_STUB") {
@@ -598,4 +598,91 @@ StubLauncher::ReceiveMessageFromStub(
 
   return nullptr;  // success
 }
-}}};  // namespace triton::backend::python
+
+#ifdef TRITON_ENABLE_GPU
+TRITONSERVER_Error*
+StubLauncher::ShareCUDAMemoryPool(
+    TRITONBACKEND_MemoryManager* triton_mem_manager)
+{
+  // Create a dummy BackendMemory object to get the start address of the CUDA
+  // memory pool.
+  BackendMemory* backend_memory;
+  std::unique_ptr lbackend_memory;
+
+  TRITONSERVER_Error* error = BackendMemory::Create(
+      triton_mem_manager, BackendMemory::AllocationType::GPU_POOL, device_id_,
+      1 /* byte size*/, &backend_memory);
+  if (error != nullptr) {
+    LOG_MESSAGE(TRITONSERVER_LOG_ERROR, TRITONSERVER_ErrorMessage(error));
+  }
+  lbackend_memory.reset(backend_memory);
+
+  CUDAHandler& cuda_api = CUDAHandler::getInstance();
+  CUdeviceptr cuda_pool_address = 0;
+  cuda_api.PointerGetAttribute(
+      &cuda_pool_address, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
+      reinterpret_cast(lbackend_memory->MemoryPtr()));
+
+  shm_pool_->SetCUDAPoolAddress(reinterpret_cast(cuda_pool_address));
+  shm_pool_->SetTritonMemoryManager(
+      reinterpret_cast(triton_mem_manager));
+
+  // Get the memory handle from the CUDA memory pool.
+  AllocatedSharedMemory memory_data_shm =
+      shm_pool_->Construct();
+  CUDAMemPoolMessage* memory_data_ptr = memory_data_shm.data_.get();
+  {
+    ScopedSetDevice scoped_set_device(device_id_);
+    THROW_IF_CUDA_ERROR(cudaIpcGetMemHandle(
+        reinterpret_cast(&memory_data_ptr->cuda_handle),
+        reinterpret_cast(shm_pool_->CUDAPoolAddress())));
+  }
+
+  // Share the CUDA memory pool with the stub process.
+  std::unique_ptr cuda_memory_pool_message =
+      IPCMessage::Create(shm_pool_, false /* inline_response */);
+  cuda_memory_pool_message->Command() = PYTHONSTUB_CUDAPoolInitializeRequest;
+  cuda_memory_pool_message->Args() = memory_data_shm.handle_;
+
+  stub_message_queue_->Push(cuda_memory_pool_message->ShmHandle());
+
+  bi::managed_external_buffer::handle_t message;
+  RETURN_IF_ERROR(ReceiveMessageFromStub(message));
+
+  std::unique_ptr response_message =
+      IPCMessage::LoadFromSharedMemory(shm_pool_, message);
+
+  if (response_message->Command() != PYTHONSTUB_CUDAPoolInitializeResponse) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INTERNAL,
+        (std::string(
+             "Received unexpected response from Python backend stub: ") +
+         model_instance_name_)
+            .c_str());
+  }
+
+  auto response =
+      std::move(
+          (shm_pool_->Load(response_message->Args())))
+          .data_;
+
+  if (response->response_has_error) {
+    if (response->response_is_error_set) {
+      std::unique_ptr error_message =
+          PbString::LoadFromSharedMemory(shm_pool_, response->response_error);
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL, error_message->String().c_str());
+    } else {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL,
+          (std::string("Error when sharing CUDA memory pool with stub process "
+                       "on model ") +
+           model_name_)
+              .c_str());
+    }
+  }
+
+  return nullptr;
+}
+#endif  // TRITON_ENABLE_GPU
+}}};  // namespace triton::backend::python
diff --git a/src/stub_launcher.h b/src/stub_launcher.h
index 89f35422..7dceca68 100644
--- a/src/stub_launcher.h
+++ b/src/stub_launcher.h
@@ -151,6 +151,12 @@ class StubLauncher {
   TRITONSERVER_Error* ReceiveMessageFromStub(
       bi::managed_external_buffer::handle_t& message);
 
+#ifdef TRITON_ENABLE_GPU
+  // Share CUDA memory pool with stub process
+  TRITONSERVER_Error* ShareCUDAMemoryPool(
+      TRITONBACKEND_MemoryManager* triton_mem_manager);
+#endif  // TRITON_ENABLE_GPU
+
  private:
  pid_t parent_pid_;
  pid_t stub_pid_;
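Once the handshake implemented above has run, per-tensor GPU sharing reduces to pointer arithmetic against the mapped pool base. A minimal sketch of that steady state, with hypothetical helper names; the real code keeps the offset in MemoryShm::cuda_pool_offset and the base in SharedMemoryManager::CUDAPoolAddress().

```cpp
#include <cstdint>

// Both processes have mapped the same CUDA pool allocation; `base` is that
// mapping's address in the current process.
struct PoolView {
  char* base;
};

// Writer side: describe a device pointer by its offset from the pool base.
inline uint64_t OffsetInPool(const PoolView& pool, const char* device_ptr)
{
  return static_cast<uint64_t>(device_ptr - pool.base);
}

// Reader side: rebuild a usable device pointer from the locally mapped base.
inline char* PointerFromOffset(const PoolView& pool, uint64_t offset)
{
  return pool.base + offset;
}
```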