
Update to TensorRT 6 (#26)
guillaumekln authored Oct 18, 2019
1 parent 81eb227 commit 4421ff8
Showing 7 changed files with 62 additions and 56 deletions.
7 changes: 4 additions & 3 deletions CMakeLists.txt
@@ -185,6 +185,10 @@ if (WITH_CUDA)
message(STATUS "Found TensorRT include directory: ${TENSORRT_INCLUDE_DIR}")
endif()

+  # The TensorRT 6 header generates a lot of deprecation warnings. Ignore them.
+  set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -isystem ${TENSORRT_INCLUDE_DIR}")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -isystem ${TENSORRT_INCLUDE_DIR}")

find_path(CUB_INCLUDE_DIR NAMES cub/cub.cuh)
if(NOT CUB_INCLUDE_DIR)
message(FATAL_ERROR "CUB library not found")
@@ -201,11 +205,8 @@ if (WITH_CUDA)
message(STATUS "Found Thrust include directory: ${THRUST_INCLUDE_DIR}")
endif()

-  list(APPEND INCLUDE_DIRECTORIES ${TENSORRT_INCLUDE_DIR})
  cuda_include_directories(
    ${INCLUDE_DIRECTORIES}
    ${CUB_INCLUDE_DIR}
    ${THRUST_INCLUDE_DIR}
    )
cuda_add_library(${PROJECT_NAME}
${SOURCES}
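A note on why the commit uses -isystem rather than plain include directories: compilers treat -isystem paths as system headers and suppress diagnostics that originate inside them, so the noisy warnings emitted from within the TensorRT 6 header disappear, while deprecated-API uses in the project's own code still warn. A small illustration, not part of the commit; the function name is hypothetical:

// NvInfer.h is found via -isystem, so warnings generated *inside* the header
// are suppressed. A deprecated call made in our own code still warns:
#include <NvInfer.h>

void configure(nvinfer1::IBuilder* builder) {
  // IBuilder::setMaxWorkspaceSize is deprecated in TensorRT 6 in favor of
  // IBuilderConfig::setMaxWorkspaceSize, so this line keeps its warning.
  builder->setMaxWorkspaceSize(1 << 30);
}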
2 changes: 1 addition & 1 deletion README.md
@@ -47,7 +47,7 @@ CTranslate2 uses the following libraries for acceleration:
  * [Intel MKL-DNN](https://github.com/intel/mkl-dnn) (>=0.20,<1.0)
* GPU
  * [CUB](https://nvlabs.github.io/cub/) (>=1.8.0)
-  * [TensorRT](https://developer.nvidia.com/tensorrt) (==5.*)
+  * [TensorRT](https://developer.nvidia.com/tensorrt) (==6.*)
  * [Thrust](https://docs.nvidia.com/cuda/thrust/index.html) (==1.9.3, included in CUDA 10.0)
  * [cuBLAS](https://developer.nvidia.com/cublas) (with CUDA>=10.0)
  * [cuDNN](https://developer.nvidia.com/cudnn) (>=7.5)
4 changes: 2 additions & 2 deletions docker/Dockerfile.centos7-gpu
@@ -37,8 +37,8 @@ RUN wget https://github.com/intel/mkl-dnn/archive/v$MKLDNN_VERSION.tar.gz && \
make -j4 && make install && \
cd ../.. && rm -r mkl-dnn-*

-ENV TENSORRT_MAJOR_VERSION=5
-ENV TENSORRT_VERSION=${TENSORRT_MAJOR_VERSION}.1.5
+ENV TENSORRT_MAJOR_VERSION=6
+ENV TENSORRT_VERSION=${TENSORRT_MAJOR_VERSION}.0.1
RUN curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/libnvinfer-devel-${TENSORRT_VERSION}-1.cuda10.0.x86_64.rpm -O && \
curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/libnvinfer${TENSORRT_MAJOR_VERSION}-${TENSORRT_VERSION}-1.cuda10.0.x86_64.rpm -O && \
rpm -ivh --nodeps libnvinfer*.rpm && \
4 changes: 2 additions & 2 deletions docker/Dockerfile.ubuntu-gpu
@@ -1,8 +1,8 @@
ARG UBUNTU_VERSION=18.04
FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu${UBUNTU_VERSION} as builder

-ENV TENSORRT_MAJOR_VERSION=5
-ENV TENSORRT_VERSION=${TENSORRT_MAJOR_VERSION}.1.5
+ENV TENSORRT_MAJOR_VERSION=6
+ENV TENSORRT_VERSION=${TENSORRT_MAJOR_VERSION}.0.1

RUN apt-get update && \
apt-get install -y --no-install-recommends \
58 changes: 23 additions & 35 deletions src/cuda/utils.cc
@@ -140,14 +140,10 @@ namespace ctranslate2 {

    } g_allocator;

-   static int max_batch_size = 512;

    static nvinfer1::IBuilder* get_trt_builder() {
      static thread_local nvinfer1::IBuilder* builder = nullptr;
      if (!builder) {
        builder = nvinfer1::createInferBuilder(g_logger);
-       builder->setMaxBatchSize(max_batch_size);
-       builder->setMaxWorkspaceSize(1 << 30);
        builder->setGpuAllocator(&g_allocator);
      }
      return builder;
@@ -162,42 +158,34 @@ namespace ctranslate2 {
    }

    TensorRTLayer::~TensorRTLayer() {
      clear();
    }

-   void TensorRTLayer::build(bool force) {
-     if (!_network || force) {
-       clear();
-       auto builder = get_trt_builder();
-       _network = builder->createNetwork();
-       build_network(_network);
-       _engine = builder->buildCudaEngine(*_network);
-       _execution_context = _engine->createExecutionContext();
-     }
-   }

    void TensorRTLayer::clear() {
-     if (_network) {
-       _network->destroy();
-       _network = nullptr;
-     }
-     if (_engine) {
-       _engine->destroy();
-       _engine = nullptr;
-     }
      if (_execution_context) {
        _execution_context->destroy();
-       _execution_context = nullptr;
+       _network->destroy();
+       _engine->destroy();
+       _builder_config->destroy();
      }
    }

-   void TensorRTLayer::run(int batch_size, void** bindings) {
-     if (batch_size > max_batch_size)
-       throw std::runtime_error("Maximum batch size supported by the TensorRT engine is "
-                                + std::to_string(max_batch_size) + ", but got "
-                                + std::to_string(batch_size));
-     build();
-     _execution_context->enqueue(batch_size, bindings, get_cuda_stream(), nullptr);
-   }
+   void TensorRTLayer::build() {
+     auto builder = get_trt_builder();
+     _network = builder->createNetworkV2(
+       1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
+     build_network(_network);
+     auto profile = builder->createOptimizationProfile();
+     set_optimization_profile(profile);
+     _builder_config = builder->createBuilderConfig();
+     _builder_config->setMaxWorkspaceSize(1 << 30);
+     _builder_config->addOptimizationProfile(profile);
+     _engine = builder->buildEngineWithConfig(*_network, *_builder_config);
+     _execution_context = _engine->createExecutionContext();
+   }

+   void TensorRTLayer::run(void** bindings, const std::vector<nvinfer1::Dims>& input_dims) {
+     if (!_execution_context)
+       build();
+     for (size_t i = 0; i < input_dims.size(); ++i)
+       _execution_context->setBindingDimensions(i, input_dims[i]);
+     _execution_context->enqueueV2(bindings, get_cuda_stream(), nullptr);
+   }

  }
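For context on the migration this file implements: TensorRT 6 replaces the implicit batch dimension (setMaxBatchSize + enqueue) with explicit-batch networks, an IBuilderConfig, and optimization profiles that bound each dynamic dimension, which is exactly the shape of the new build()/run(). A self-contained sketch of that flow; the logger, the identity network, and the helper names are illustrative assumptions, not code from this commit:

#include <cuda_runtime.h>
#include <NvInfer.h>

// Sketch only: `logger` is any nvinfer1::ILogger implementation.
nvinfer1::ICudaEngine* build_dynamic_engine(nvinfer1::ILogger& logger, int depth) {
  auto builder = nvinfer1::createInferBuilder(logger);
  // Explicit batch: the batch dimension is part of the network definition.
  auto network = builder->createNetworkV2(
      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
  // -1 marks a dynamic dimension, fixed later by setBindingDimensions().
  auto input = network->addInput("x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims2(-1, depth));
  auto identity = network->addIdentity(*input);  // stand-in for a real graph
  network->markOutput(*identity->getOutput(0));
  // Dynamic shapes require at least one optimization profile.
  auto profile = builder->createOptimizationProfile();
  profile->setDimensions("x", nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims2(1, depth));
  profile->setDimensions("x", nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims2(64, depth));
  profile->setDimensions("x", nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims2(1024, depth));
  auto config = builder->createBuilderConfig();
  config->setMaxWorkspaceSize(1 << 30);  // moved from IBuilder in TensorRT 6
  config->addOptimizationProfile(profile);
  auto engine = builder->buildEngineWithConfig(*network, *config);
  network->destroy();
  config->destroy();
  builder->destroy();
  return engine;
}

// Per call: pin the dynamic dimensions, then use the V2 (explicit batch) enqueue.
void run_batch(nvinfer1::IExecutionContext* context, void** bindings,
               int batch_size, int depth, cudaStream_t stream) {
  context->setBindingDimensions(0, nvinfer1::Dims2(batch_size, depth));
  context->enqueueV2(bindings, stream, nullptr);
}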
10 changes: 7 additions & 3 deletions src/cuda/utils.h
@@ -1,6 +1,7 @@
#pragma once

#include <string>
+#include <vector>

#include <cuda_runtime.h>
#include <cublas_v2.h>
@@ -49,15 +50,18 @@ namespace ctranslate2 {
      virtual ~TensorRTLayer();

    protected:
+     void run(void** bindings, const std::vector<nvinfer1::Dims>& input_dims);
+
+     // These methods are called on the first call to run().
      virtual void build_network(nvinfer1::INetworkDefinition* network) = 0;
-     void run(int batch_size, void** bindings);
-     void build(bool force = false);
      void clear();
+     virtual void set_optimization_profile(nvinfer1::IOptimizationProfile* profile) = 0;

    private:
+     void build();
      nvinfer1::INetworkDefinition* _network = nullptr;
      nvinfer1::ICudaEngine* _engine = nullptr;
      nvinfer1::IExecutionContext* _execution_context = nullptr;
+     nvinfer1::IBuilderConfig* _builder_config = nullptr;
    };

    // Statically associate cudnnDataType_t with a C++ type.
33 changes: 23 additions & 10 deletions src/ops/topk_gpu.cu
@@ -9,32 +9,31 @@ namespace ctranslate2 {
    public:
      TopKLayer(int k)
        : _k(k)
-       , _depth(0) {
+       , _first_depth(0) {
      }

      void operator()(const StorageView& x, StorageView& values, StorageView& indices) {
        int depth = x.dim(-1);
        int batch_size = x.size() / depth;

-       if (depth != _depth) {
-         _depth = depth;
-         build(/*force=*/true);
-       }
+       if (_first_depth == 0)
+         _first_depth = depth;

        void* bindings[3] = {
          const_cast<float*>(x.data<float>()),
          values.data<float>(),
          indices.data<int32_t>()
        };

-       run(batch_size, bindings);
+       run(bindings, {nvinfer1::Dims2(batch_size, depth)});
      }

    protected:
      void build_network(nvinfer1::INetworkDefinition* network) override {
-       nvinfer1::Dims input_dim{1, {_depth}, {nvinfer1::DimensionType::kCHANNEL}};
-       nvinfer1::ITensor* input = network->addInput("x", nvinfer1::DataType::kFLOAT, input_dim);
-       nvinfer1::ITopKLayer* topk = network->addTopK(*input, nvinfer1::TopKOperation::kMAX, _k, 1);
+       nvinfer1::ITensor* input = network->addInput("x",
+                                                    nvinfer1::DataType::kFLOAT,
+                                                    nvinfer1::Dims2(-1, -1));
+       nvinfer1::ITopKLayer* topk = network->addTopK(*input, nvinfer1::TopKOperation::kMAX, _k, 2);
        nvinfer1::ITensor* values_t = topk->getOutput(0);
        nvinfer1::ITensor* indices_t = topk->getOutput(1);
        network->markOutput(*values_t);
@@ -44,9 +43,23 @@ namespace ctranslate2 {
        indices_t->setType(nvinfer1::DataType::kINT32);
      }

+     void set_optimization_profile(nvinfer1::IOptimizationProfile* profile) override {
+       // Optimize for the first seen depth which covers the standard use case
+       // of running TopK over a static vocabulary size.
+       profile->setDimensions("x",
+                              nvinfer1::OptProfileSelector::kMIN,
+                              nvinfer1::Dims2(1, _first_depth));
+       profile->setDimensions("x",
+                              nvinfer1::OptProfileSelector::kOPT,
+                              nvinfer1::Dims2(64, _first_depth));
+       profile->setDimensions("x",
+                              nvinfer1::OptProfileSelector::kMAX,
+                              nvinfer1::Dims2(1024, _first_depth));
+     }

    private:
      int _k;
-     int _depth;
+     int _first_depth;
    };

template <Device D, typename DataType, typename IndexType>
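A note on what the profile above commits to: TensorRT rejects binding dimensions that fall outside the [kMIN, kMAX] bounds, so the batch size may vary between 1 and 1024 from call to call, while the depth stays pinned to the first value seen. A sketch of the resulting contract, assuming `context` comes from an engine built with that profile; the function and variable names are illustrative:

#include <NvInfer.h>

// setBindingDimensions() returns false when the requested shape falls outside
// the optimization profile, so only the depth used at build time is accepted.
bool try_shapes(nvinfer1::IExecutionContext* context, int first_depth) {
  bool small_batch = context->setBindingDimensions(0, nvinfer1::Dims2(8, first_depth));     // true
  bool max_batch   = context->setBindingDimensions(0, nvinfer1::Dims2(1024, first_depth));  // true (== kMAX)
  bool other_depth = context->setBindingDimensions(0, nvinfer1::Dims2(8, first_depth + 1)); // false
  return small_batch && max_batch && !other_depth;
}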
