Use CUDA shared memory pool to optimize tensor transfer between processes
krishung5 committed Sep 1, 2023
1 parent 6f369ef commit 6dc4fa6
Showing 13 changed files with 427 additions and 51 deletions.
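At a high level, this change gives GPU tensors a second transport path between the Triton server process and the Python stub: when a buffer already lives in (or can be copied into) a CUDA memory pool that both processes have mapped, only a byte offset into that pool is exchanged, instead of a per-tensor cudaIpcMemHandle_t that the peer must open. The sketch below is an illustrative simplification inferred from the fields this diff touches (use_cuda_shared_pool, cuda_pool_offset, the IPC handle); it is not a struct from the repository.

// Illustrative only: a simplified view of what travels through CPU shared
// memory for each GPU tensor after this commit. Exactly one of the two
// variants is meaningful for a given record.
#include <cuda_runtime.h>
#include <cstdint>

struct GpuTensorRecordSketch {
  uint64_t byte_size;
  bool use_cuda_shared_pool;       // New flag added by this commit.
  uint64_t cuda_pool_offset;       // Pool path: the peer computes
                                   // pool_base + cuda_pool_offset.
  cudaIpcMemHandle_t cuda_handle;  // Fallback path: the peer must call
                                   // cudaIpcOpenMemHandle() per tensor.
};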
4 changes: 2 additions & 2 deletions src/infer_response.cc
@@ -299,15 +299,15 @@ InferResponse::Send(
           output_buffer = PbMemory::Create(
               shm_pool, actual_memory_type, actual_memory_type_id,
               output_tensor->ByteSize(), reinterpret_cast<char*>(buffer),
-              false /* copy_gpu */));
+              false /* copy_gpu */, true /* write_back_data */));
       output_buffer->SetCudaIpcHandle(cuda_ipc_mem_handle_p);
     } else {
       SET_ERROR_AND_RETURN_IF_EXCEPTION(
           response_error,
           output_buffer = PbMemory::Create(
               shm_pool, actual_memory_type, actual_memory_type_id,
               output_tensor->ByteSize(), reinterpret_cast<char*>(buffer),
-              true /* copy_gpu */));
+              true /* copy_gpu */, true /* write_back_data */));
     }
     gpu_buffer_helper.AddBuffer(output_buffer->ShmHandle());
     output_buffers.push_back({std::move(output_buffer), buffer});
2 changes: 2 additions & 0 deletions src/ipc_message.h
@@ -41,6 +41,8 @@ typedef enum PYTHONSTUB_commandtype_enum {
   PYTHONSTUB_ExecuteResponse,
   PYTHONSTUB_InitializeRequest,
   PYTHONSTUB_InitializeResponse,
+  PYTHONSTUB_CUDAPoolInitializeRequest,
+  PYTHONSTUB_CUDAPoolInitializeResponse,
   PYTHONSTUB_FinalizeRequest,
   PYTHONSTUB_FinalizeResponse,
   PYTHONSTUB_LoadGPUBuffers,
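The two new command types suggest a one-time handshake in which the server shares its CUDA pool with the stub before any inference runs. The snippet below is a hedged sketch of the CUDA-level mechanics such a handshake typically relies on (export the pool base once, map it once on the other side); the actual message structs and shared-memory transport are specific to this repository and are not shown here.

// Sketch (assumption, not this repo's code): one-time export/import of a
// CUDA pool so that later messages can carry plain offsets.
// Error handling is omitted for brevity.
#include <cuda_runtime.h>

// Server side: export an IPC handle for the pool base address once.
// `pool_base` must have been allocated with cudaMalloc in this process.
cudaIpcMemHandle_t ExportPoolHandle(void* pool_base)
{
  cudaIpcMemHandle_t handle;
  cudaIpcGetMemHandle(&handle, pool_base);
  return handle;
}

// Stub side: map the pool once at startup; every later tensor is resolved
// as pool_base + offset with no further CUDA IPC calls.
void* ImportPool(cudaIpcMemHandle_t handle)
{
  void* pool_base = nullptr;
  cudaIpcOpenMemHandle(&pool_base, handle, cudaIpcMemLazyEnablePeerAccess);
  return pool_base;
}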
199 changes: 163 additions & 36 deletions src/pb_memory.cc
@@ -32,10 +32,9 @@ std::unique_ptr<PbMemory>
 PbMemory::Create(
     std::unique_ptr<SharedMemoryManager>& shm_pool,
     TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
-    uint64_t byte_size, char* data, bool copy_gpu)
+    uint64_t byte_size, char* data, bool copy_gpu, bool write_back_data)
 {
   size_t requested_byte_size = sizeof(MemoryShm);
-
   if (memory_type == TRITONSERVER_MEMORY_GPU) {
 #ifdef TRITON_ENABLE_GPU
     requested_byte_size += sizeof(cudaIpcMemHandle_t);
@@ -46,9 +45,11 @@ PbMemory::Create(
 
   AllocatedSharedMemory<char> memory_shm =
       shm_pool->Construct<char>(requested_byte_size);
+
+  void* backend_memory;
   PbMemory::FillShmData(
-      memory_type, memory_type_id, byte_size, data, memory_shm.data_.get(),
-      memory_shm.handle_, copy_gpu);
+      shm_pool, &backend_memory, memory_type, memory_type_id, byte_size, data,
+      memory_shm.data_.get(), memory_shm.handle_, copy_gpu, write_back_data);
 
   if (memory_type == TRITONSERVER_MEMORY_CPU) {
     data = memory_shm.data_.get() + sizeof(MemoryShm);
@@ -59,6 +60,16 @@ PbMemory::Create(
 
 #ifdef TRITON_ENABLE_GPU
   if (memory_type == TRITONSERVER_MEMORY_GPU) {
+#ifndef TRITON_PB_STUB
+    if (pb_memory->memory_shm_ptr_->use_cuda_shared_pool) {
+      pb_memory->backend_memory_.reset(
+          reinterpret_cast<BackendMemory*>(backend_memory));
+      if (write_back_data) {
+        // Store the original buffer so that we can write back to it later.
+        pb_memory->original_buffer_ = data;
+      }
+    }
+#endif
     pb_memory->memory_shm_ptr_->gpu_pointer_offset =
         pb_memory->GetGPUPointerOffset();
   }
@@ -79,16 +90,46 @@ PbMemory::Create(
   return pb_memory;
 }
+
+
+void
+PbMemory::WriteBackOutput(std::unique_ptr<SharedMemoryManager>& shm_pool)
+{
+  if (original_buffer_) {
+    cudaMemcpyKind kind = cudaMemcpyDeviceToDevice;
+    cudaError_t err;
+    err = cudaMemcpy(
+        original_buffer_, backend_memory_->MemoryPtr(),
+        memory_shm_ptr_->byte_size, kind);
+    if (err != cudaSuccess) {
+      throw PythonBackendException(
+          std::string(
+              "failed to copy data: " + std::string(cudaGetErrorString(err)))
+              .c_str());
+    }
+    err = cudaStreamSynchronize(0);
+    if (err != cudaSuccess) {
+      throw PythonBackendException(
+          std::string(
+              "failed to synchronize the default CUDA stream. error: " +
+              std::string(cudaGetErrorString(err)))
+              .c_str());
+    }
+  }
+}
 #endif
 
 std::unique_ptr<PbMemory>
 PbMemory::Create(
     std::unique_ptr<SharedMemoryManager>& shm_pool,
     TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
     uint64_t byte_size, char* data, char* data_shm,
     bi::managed_external_buffer::handle_t handle, bool copy_gpu)
 {
+  void* backend_memory;
   PbMemory::FillShmData(
-      memory_type, memory_type_id, byte_size, data, data_shm, handle, copy_gpu);
+      shm_pool, &backend_memory, memory_type, memory_type_id, byte_size, data,
+      data_shm, handle, copy_gpu);
 
   if (memory_type == TRITONSERVER_MEMORY_CPU) {
     data = data_shm + sizeof(MemoryShm);
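WriteBackOutput above covers the case where the stub's output was copied into a pool buffer owned by the backend: once the response is assembled, the bytes must be copied back to the original buffer Triton handed out. In isolation the primitive looks like the sketch below (an assumption-level distillation, not repository code); note that a device-to-device cudaMemcpy is asynchronous with respect to the host, which is why the commit synchronizes the default stream before returning.

// Distilled write-back step (sketch): copy device-to-device, then block
// until the copy completes so the caller may safely reuse the pool buffer.
#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

void WriteBackSketch(void* original, const void* pool_copy, size_t byte_size)
{
  cudaError_t err =
      cudaMemcpy(original, pool_copy, byte_size, cudaMemcpyDeviceToDevice);
  if (err != cudaSuccess) {
    throw std::runtime_error(
        std::string("failed to copy data: ") + cudaGetErrorString(err));
  }
  // D2D copies may return before completion; synchronize the default stream.
  err = cudaStreamSynchronize(0);
  if (err != cudaSuccess) {
    throw std::runtime_error(
        std::string("failed to synchronize the default CUDA stream: ") +
        cudaGetErrorString(err));
  }
}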
@@ -99,6 +140,13 @@ PbMemory::Create(
 
 #ifdef TRITON_ENABLE_GPU
   if (memory_type == TRITONSERVER_MEMORY_GPU) {
+#ifndef TRITON_PB_STUB
+    if (pb_memory->memory_shm_ptr_->use_cuda_shared_pool) {
+      pb_memory->backend_memory_.reset(
+          reinterpret_cast<BackendMemory*>(backend_memory));
+    }
+#endif
+
     pb_memory->memory_shm_ptr_->gpu_pointer_offset =
         pb_memory->GetGPUPointerOffset();
   }
@@ -176,14 +224,17 @@ PbMemory::CopyBuffer(
 
 void
 PbMemory::FillShmData(
+    std::unique_ptr<SharedMemoryManager>& shm_pool, void** backend_memory,
     TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
     uint64_t byte_size, char* data, char* data_shm,
-    bi::managed_external_buffer::handle_t handle, bool copy_gpu)
+    bi::managed_external_buffer::handle_t handle, bool copy_gpu,
+    bool write_back_data)
 {
   char* memory_data_shm = data_shm + sizeof(MemoryShm);
   MemoryShm* memory_shm_ptr = reinterpret_cast<MemoryShm*>(data_shm);
   memory_shm_ptr->is_cuda_handle_set = copy_gpu;
   memory_shm_ptr->memory_release_id = 0;
-
+  bool use_cuda_shared_pool = false;
+  *backend_memory = nullptr;
 
   if (memory_type == TRITONSERVER_MEMORY_GPU) {
 #ifdef TRITON_ENABLE_GPU
@@ -193,8 +244,62 @@ PbMemory::FillShmData(
       THROW_IF_CUDA_ERROR(cudaIpcGetMemHandle(
           reinterpret_cast<cudaIpcMemHandle_t*>(memory_data_shm), data));
     }
+#ifndef TRITON_PB_STUB
+    // Check if the data is already in the pool by checking the base address.
+    CUDAHandler& cuda_api = CUDAHandler::getInstance();
+    CUdeviceptr cuda_pool_address = 0;
+    cuda_api.PointerGetAttribute(
+        &cuda_pool_address, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
+        reinterpret_cast<CUdeviceptr>(data));
+    if (shm_pool->CUDAPoolAddress() ==
+        reinterpret_cast<void*>(cuda_pool_address)) {
+      use_cuda_shared_pool = true;
+      memory_shm_ptr->cuda_pool_offset =
+          data - reinterpret_cast<char*>(shm_pool->CUDAPoolAddress());
+    } else {
+      TRITONSERVER_Error* error = BackendMemory::Create(
+          reinterpret_cast<TRITONBACKEND_MemoryManager*>(
+              shm_pool->TritonMemoryManager()),
+          BackendMemory::AllocationType::GPU_POOL, memory_type_id, byte_size,
+          reinterpret_cast<BackendMemory**>(backend_memory));
+      if (error != nullptr) {
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_WARN,
+            (std::string(
+                 "Failed to allocate memory from CUDA memory pool: ") +
+             TRITONSERVER_ErrorMessage(error))
+                .c_str());
+      } else {
+        // Copy the data to the new buffer in the CUDA pool.
+        cudaMemcpyKind kind = cudaMemcpyDeviceToDevice;
+        cudaError_t err;
+        err = cudaMemcpy(
+            (reinterpret_cast<BackendMemory*>(*backend_memory))->MemoryPtr(),
+            data, byte_size, kind);
+        if (err != cudaSuccess) {
+          throw PythonBackendException(
+              std::string(
+                  "failed to copy data: " +
+                  std::string(cudaGetErrorString(err)))
+                  .c_str());
+        }
+        err = cudaStreamSynchronize(0);
+        if (err != cudaSuccess) {
+          throw PythonBackendException(
+              std::string(
+                  "failed to synchronize the default CUDA stream. error: " +
+                  std::string(cudaGetErrorString(err)))
+                  .c_str());
+        }
+        use_cuda_shared_pool = true;
+        memory_shm_ptr->cuda_pool_offset =
+            (reinterpret_cast<BackendMemory*>(*backend_memory))->MemoryPtr() -
+            reinterpret_cast<char*>(shm_pool->CUDAPoolAddress());
+      }
+    }
+#endif  // Not TRITON_PB_STUB
   }
-#endif
+#endif  // TRITON_ENABLE_GPU
   } else {
     if (data != nullptr) {
       std::copy(data, data + byte_size, memory_data_shm);
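The membership test above hinges on CU_POINTER_ATTRIBUTE_RANGE_START_ADDR: for any device pointer, the CUDA driver API can report the base address of the allocation containing it, which is then compared against the known pool base. A minimal standalone version of that test, using the raw driver API rather than this repo's CUDAHandler wrapper, might look like this:

// Minimal sketch of the pool-membership test using the CUDA driver API
// directly (the commit goes through CUDAHandler::PointerGetAttribute).
#include <cuda.h>
#include <cstddef>

// Returns true and sets *offset if `ptr` lies inside the allocation that
// starts at `pool_base`; returns false otherwise, in which case the caller
// falls back to allocating from the pool and copying, as in the diff above.
bool InPool(char* ptr, char* pool_base, std::ptrdiff_t* offset)
{
  CUdeviceptr range_start = 0;
  if (cuPointerGetAttribute(
          &range_start, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
          reinterpret_cast<CUdeviceptr>(ptr)) != CUDA_SUCCESS) {
    return false;
  }
  if (reinterpret_cast<char*>(range_start) != pool_base) {
    return false;
  }
  *offset = ptr - pool_base;
  return true;
}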
@@ -204,10 +309,12 @@ PbMemory::FillShmData(
   memory_shm_ptr->byte_size = byte_size;
   memory_shm_ptr->memory_type_id = memory_type_id;
   memory_shm_ptr->memory_type = memory_type;
+  memory_shm_ptr->use_cuda_shared_pool = use_cuda_shared_pool;
 }
 
 std::unique_ptr<PbMemory>
 PbMemory::LoadFromSharedMemory(
+    std::unique_ptr<SharedMemoryManager>& shm_pool,
     bi::managed_external_buffer::handle_t handle, char* data_shm,
     bool open_cuda_handle)
 {
@@ -219,21 +326,32 @@ PbMemory::LoadFromSharedMemory(
   if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU &&
       open_cuda_handle) {
 #ifdef TRITON_ENABLE_GPU
-    cudaIpcMemHandle_t* cuda_handle =
-        reinterpret_cast<cudaIpcMemHandle_t*>(memory_data_shm);
+    if (memory_shm_ptr->use_cuda_shared_pool) {
+#ifdef TRITON_PB_STUB
+      // When CUDA shared memory pool is used, the stub will retrieve the
+      // data pointer using the offset.
+      data_ptr =
+          (reinterpret_cast<char*>(shm_pool->CUDAPoolAddress()) +
+           memory_shm_ptr->cuda_pool_offset);
+#endif  // TRITON_PB_STUB
+    } else {
+      cudaIpcMemHandle_t* cuda_handle =
+          reinterpret_cast<cudaIpcMemHandle_t*>(memory_data_shm);
 
-    // The pointer opened by the cudaIpcOpenMemHandle will refer to the base
-    // address. We need to manually correct the offset.
-    void* data_ptr_base;
-    CUDAHandler& cuda_handler = CUDAHandler::getInstance();
-    cuda_handler.OpenCudaHandle(
-        memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base);
+      // The pointer opened by the cudaIpcOpenMemHandle will refer to the base
+      // address. We need to manually correct the offset.
+      void* data_ptr_base;
+      CUDAHandler& cuda_handler = CUDAHandler::getInstance();
+      cuda_handler.OpenCudaHandle(
+          memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base);
 
-    data_ptr =
-        (reinterpret_cast<char*>(data_ptr_base) +
-         memory_shm_ptr->gpu_pointer_offset);
-    opened_cuda_ipc_handle = true;
-#endif
+      data_ptr =
+          (reinterpret_cast<char*>(data_ptr_base) +
+           memory_shm_ptr->gpu_pointer_offset);
+      opened_cuda_ipc_handle = true;
+    }
+
+#endif  // TRITON_ENABLE_GPU
   } else {
     data_ptr = memory_data_shm;
   }
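On the consumer side the pool path is intentionally cheap: resolving a tensor is pure pointer arithmetic against the already-mapped pool base, with no per-tensor cudaIpcOpenMemHandle call, which is presumably the optimization the commit title refers to. A sketch of that lookup, assuming pool_base was obtained once at stub startup:

// Sketch: per-tensor resolution under the pool scheme. Mirrors
// `shm_pool->CUDAPoolAddress() + memory_shm_ptr->cuda_pool_offset` above.
#include <cstdint>

inline char* ResolvePooledPointer(char* pool_base, uint64_t cuda_pool_offset)
{
  return pool_base + cuda_pool_offset;  // No CUDA API call needed.
}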
@@ -258,21 +376,30 @@ PbMemory::LoadFromSharedMemory(
   if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU) {
     if (memory_shm_ptr->byte_size > 0 && open_cuda_handle) {
 #ifdef TRITON_ENABLE_GPU
-      cudaIpcMemHandle_t* cuda_handle =
-          reinterpret_cast<cudaIpcMemHandle_t*>(memory_data_shm);
-
-      // The pointer opened by the cudaIpcOpenMemHandle will refer to the base
-      // address. We need to manually correct the offset.
-
-      void* data_ptr_base;
-      CUDAHandler& cuda_handler = CUDAHandler::getInstance();
-      cuda_handler.OpenCudaHandle(
-          memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base);
-
-      data_ptr =
-          (reinterpret_cast<char*>(data_ptr_base) +
-           memory_shm_ptr->gpu_pointer_offset);
-      opened_cuda_ipc_handle = true;
+      if (memory_shm_ptr->use_cuda_shared_pool) {
+#ifdef TRITON_PB_STUB
+        // When CUDA shared memory pool is used, the stub will retrieve the
+        // data pointer using the offset.
+        data_ptr =
+            (reinterpret_cast<char*>(shm_pool->CUDAPoolAddress()) +
+             memory_shm_ptr->cuda_pool_offset);
+#endif  // TRITON_PB_STUB
+      } else {
+        cudaIpcMemHandle_t* cuda_handle =
+            reinterpret_cast<cudaIpcMemHandle_t*>(memory_data_shm);
+
+        // The pointer opened by the cudaIpcOpenMemHandle will refer to the base
+        // address. We need to manually correct the offset.
+        void* data_ptr_base;
+        CUDAHandler& cuda_handler = CUDAHandler::getInstance();
+        cuda_handler.OpenCudaHandle(
+            memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base);
+
+        data_ptr =
+            (reinterpret_cast<char*>(data_ptr_base) +
+             memory_shm_ptr->gpu_pointer_offset);
+        opened_cuda_ipc_handle = true;
+      }
 #endif
     }
   } else {