Skip to content

Commit c13d1b8

Browse files
authored
eliminate conditional jumps in PytorchContext. (k2-fsa#163)
1 parent f784a2c commit c13d1b8

File tree

3 files changed

+76
-44
lines changed

3 files changed

+76
-44
lines changed

k2/csrc/default_context.cu

+9-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,15 @@ static constexpr std::size_t kAlignment = 64;
2121

2222
// TODO(haowen): most of implementations below should be updated later.
2323
class CpuContext : public Context {
24+
private:
25+
CpuContext() = default;
26+
2427
public:
25-
ContextPtr GetCpuContext() override { return nullptr; }
28+
static ContextPtr Make() {
29+
auto p = new CpuContext();
30+
return ContextPtr{p};
31+
}
32+
ContextPtr GetCpuContext() override { return shared_from_this(); }
2633
ContextPtr GetPinnedContext() override { return nullptr; }
2734
DeviceType GetDeviceType() const override { return kCpu; }
2835

@@ -98,7 +105,7 @@ class CudaContext : public Context {
98105
cudaStream_t stream_;
99106
};
100107

101-
ContextPtr GetCpuContext() { return std::make_shared<CpuContext>(); }
108+
ContextPtr GetCpuContext() { return CpuContext::Make(); }
102109

103110
ContextPtr GetCudaContext(int32_t gpu_id /*= -1*/) {
104111
return std::make_shared<CudaContext>(gpu_id);

k2/csrc/pytorch_context.cu

+4-3
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@
1111

1212
#include <memory>
1313

14+
#include "c10/cuda/CUDAFunctions.h"
1415
#include "k2/csrc/pytorch_context.h"
1516

1617
namespace k2 {
1718

18-
ContextPtr GetCpuContext() { return std::make_shared<PytorchContext>(-1); }
19+
ContextPtr GetCpuContext() { return PytorchCpuContext::Make(); }
1920

2021
ContextPtr GetCudaContext(int32_t gpu_id /*= -1*/) {
21-
if (gpu_id < 0) gpu_id = 0; // TODO(fangjun): select a device
22-
return std::make_shared<PytorchContext>(gpu_id);
22+
if (gpu_id < 0) gpu_id = c10::cuda::current_device();
23+
return std::make_shared<PytorchCudaContext>(gpu_id);
2324
}
2425

2526
} // namespace k2

k2/csrc/pytorch_context.h

+63-39
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,83 @@
1313
#ifndef K2_CSRC_PYTORCH_CONTEXT_H_
1414
#define K2_CSRC_PYTORCH_CONTEXT_H_
1515

16+
#include <memory>
17+
1618
#include "c10/cuda/CUDACachingAllocator.h"
1719
#include "k2/csrc/context.h"
1820
#include "k2/csrc/log.h"
1921
#include "torch/torch.h"
2022

2123
namespace k2 {
2224

23-
class PytorchContext : public Context {
24-
public:
25-
// if device_id < 0, then this is a cpu context;
26-
// otherwise, it is a cuda context.
27-
explicit PytorchContext(int32_t device_id) : device_id_(device_id) {
28-
if (device_id_ < 0)
29-
InitCpu();
30-
else
31-
InitCuda();
25+
class PytorchCpuContext : public Context {
26+
private:
27+
PytorchCpuContext() {
28+
allocator_ = torch::GetAllocator(torch::kCPU);
29+
K2_CHECK(allocator_->raw_deleter() != nullptr);
3230
}
3331

34-
ContextPtr GetCpuContext() override {
35-
// TODO(fangjun): return `this` if it's cpu ?
36-
return nullptr;
32+
public:
33+
static ContextPtr Make() {
34+
auto p = new PytorchCpuContext();
35+
return ContextPtr{p};
3736
}
3837

38+
// since the constructor is private, the only way to create an instance
39+
// of PytorchCpuContext is via `Make`, which returns a `shared_ptr`.
40+
// Thus it is safe to call `shared_from_this`.
41+
ContextPtr GetCpuContext() override { return shared_from_this(); }
42+
3943
ContextPtr GetPinnedContext() override { return nullptr; }
4044

41-
DeviceType GetDeviceType() const override {
42-
return device_id_ >= 0 ? kCuda : kCpu;
45+
DeviceType GetDeviceType() const override { return kCpu; }
46+
47+
void *Allocate(std::size_t bytes, void **deleter_context) override {
48+
void *p = allocator_->raw_allocate(bytes);
49+
if (deleter_context) *deleter_context = nullptr;
50+
return p;
4351
}
4452

45-
int32_t GetDeviceId() const override { return device_id_; }
53+
void Deallocate(void *data, void * /*deleter_context*/) override {
54+
allocator_->raw_deallocate(data);
55+
}
56+
57+
bool IsCompatible(const Context &other) const override {
58+
return other.GetDeviceType() == kCpu;
59+
}
60+
61+
private:
62+
torch::Allocator *allocator_; // NOT owned here
63+
};
64+
65+
class PytorchCudaContext : public Context {
66+
public:
67+
explicit PytorchCudaContext(int32_t gpu_id) : gpu_id_(gpu_id) {
68+
K2_CHECK_GE(gpu_id, 0);
69+
K2_CHECK_LT(gpu_id, c10::cuda::device_count());
70+
71+
c10::cuda::set_device(gpu_id);
72+
73+
// The internals of `lazyInitCUDA` are executed only once
74+
// so it is fine to invoke lazyInitCUDA() multiple times.
75+
// The call will be inlined since it is defined in the header
76+
// aten/src/ATen/Context.h
77+
at::globalContext().lazyInitCUDA();
78+
79+
allocator_ = c10::cuda::CUDACachingAllocator::get();
80+
K2_CHECK(allocator_->raw_deleter() != nullptr);
81+
}
82+
83+
ContextPtr GetCpuContext() override { return nullptr; }
84+
85+
ContextPtr GetPinnedContext() override { return nullptr; }
86+
87+
DeviceType GetDeviceType() const override { return kCuda; }
88+
89+
int32_t GetDeviceId() const override { return gpu_id_; }
4690

4791
cudaStream_t GetCudaStream() const override {
48-
return device_id_ >= 0 ? c10::cuda::getCurrentCUDAStream(device_id_)
49-
: kCudaStreamInvalid;
92+
return c10::cuda::getCurrentCUDAStream(gpu_id_);
5093
}
5194

5295
void *Allocate(std::size_t bytes, void **deleter_context) override {
@@ -60,36 +103,17 @@ class PytorchContext : public Context {
60103
}
61104

62105
bool IsCompatible(const Context &other) const override {
63-
return other.GetDeviceType() == GetDeviceType() &&
64-
other.GetDeviceId() == device_id_;
106+
return other.GetDeviceType() == kCuda && other.GetDeviceId() == gpu_id_;
65107
}
66108

67109
void Sync() const override {
68-
if (device_id_ >= 0) {
69-
auto ret = cudaStreamSynchronize(GetCudaStream());
70-
K2_CHECK_CUDA_ERROR(ret);
71-
}
72-
}
73-
74-
private:
75-
void InitCpu() {
76-
allocator_ = torch::GetAllocator(torch::kCPU);
77-
K2_CHECK(allocator_->raw_deleter() != nullptr);
78-
}
79-
80-
void InitCuda() {
81-
auto ret = cudaSetDevice(device_id_);
110+
auto ret = cudaStreamSynchronize(GetCudaStream());
82111
K2_CHECK_CUDA_ERROR(ret);
83-
// TODO(fangjun): invoke init only once
84-
c10::cuda::CUDACachingAllocator::init(device_id_ + 1);
85-
86-
allocator_ = c10::cuda::CUDACachingAllocator::get();
87-
K2_CHECK(allocator_->raw_deleter() != nullptr);
88112
}
89113

90114
private:
91115
torch::Allocator *allocator_; // NOT owned here
92-
int32_t device_id_;
116+
int32_t gpu_id_;
93117
};
94118

95119
} // namespace k2

0 commit comments

Comments
 (0)