13
13
#ifndef K2_CSRC_PYTORCH_CONTEXT_H_
14
14
#define K2_CSRC_PYTORCH_CONTEXT_H_
15
15
16
+ #include < memory>
17
+
16
18
#include " c10/cuda/CUDACachingAllocator.h"
17
19
#include " k2/csrc/context.h"
18
20
#include " k2/csrc/log.h"
19
21
#include " torch/torch.h"
20
22
21
23
namespace k2 {
22
24
23
- class PytorchContext : public Context {
24
- public:
25
- // if device_id < 0, then this is a cpu context;
26
- // otherwise, it is a cuda context.
27
- explicit PytorchContext (int32_t device_id) : device_id_(device_id) {
28
- if (device_id_ < 0 )
29
- InitCpu ();
30
- else
31
- InitCuda ();
25
+ class PytorchCpuContext : public Context {
26
+ private:
27
+ PytorchCpuContext () {
28
+ allocator_ = torch::GetAllocator (torch::kCPU );
29
+ K2_CHECK (allocator_->raw_deleter () != nullptr );
32
30
}
33
31
34
- ContextPtr GetCpuContext () override {
35
- // TODO(fangjun): return `this` if it's cpu ?
36
- return nullptr ;
32
+ public:
33
+ static ContextPtr Make () {
34
+ auto p = new PytorchCpuContext ();
35
+ return ContextPtr{p};
37
36
}
38
37
38
+ // since the constructor is private, the only way to create an instance
39
+ // of PytorchCpuContext is via `Make`, which returns a `shared_ptr`.
40
+ // Thus it is safe to call `shared_from_this`.
41
+ ContextPtr GetCpuContext () override { return shared_from_this (); }
42
+
39
43
ContextPtr GetPinnedContext () override { return nullptr ; }
40
44
41
- DeviceType GetDeviceType () const override {
42
- return device_id_ >= 0 ? kCuda : kCpu ;
45
+ DeviceType GetDeviceType () const override { return kCpu ; }
46
+
47
+ void *Allocate (std::size_t bytes, void **deleter_context) override {
48
+ void *p = allocator_->raw_allocate (bytes);
49
+ if (deleter_context) *deleter_context = nullptr ;
50
+ return p;
43
51
}
44
52
45
- int32_t GetDeviceId () const override { return device_id_; }
53
+ void Deallocate (void *data, void * /* deleter_context*/ ) override {
54
+ allocator_->raw_deallocate (data);
55
+ }
56
+
57
+ bool IsCompatible (const Context &other) const override {
58
+ return other.GetDeviceType () == kCpu ;
59
+ }
60
+
61
+ private:
62
+ torch::Allocator *allocator_; // NOT owned here
63
+ };
64
+
65
+ class PytorchCudaContext : public Context {
66
+ public:
67
+ explicit PytorchCudaContext (int32_t gpu_id) : gpu_id_(gpu_id) {
68
+ K2_CHECK_GE (gpu_id, 0 );
69
+ K2_CHECK_LT (gpu_id, c10::cuda::device_count ());
70
+
71
+ c10::cuda::set_device (gpu_id);
72
+
73
+ // The internals of `lazyInitCUDA` are executed only once
74
+ // so it is fine to invoke lazyInitCUDA() multiple times.
75
+ // The call will be inlined since it is defined in the header
76
+ // aten/src/ATen/Context.h
77
+ at::globalContext ().lazyInitCUDA ();
78
+
79
+ allocator_ = c10::cuda::CUDACachingAllocator::get ();
80
+ K2_CHECK (allocator_->raw_deleter () != nullptr );
81
+ }
82
+
83
+ ContextPtr GetCpuContext () override { return nullptr ; }
84
+
85
+ ContextPtr GetPinnedContext () override { return nullptr ; }
86
+
87
+ DeviceType GetDeviceType () const override { return kCuda ; }
88
+
89
+ int32_t GetDeviceId () const override { return gpu_id_; }
46
90
47
91
cudaStream_t GetCudaStream () const override {
48
- return device_id_ >= 0 ? c10::cuda::getCurrentCUDAStream (device_id_)
49
- : kCudaStreamInvalid ;
92
+ return c10::cuda::getCurrentCUDAStream (gpu_id_);
50
93
}
51
94
52
95
void *Allocate (std::size_t bytes, void **deleter_context) override {
@@ -60,36 +103,17 @@ class PytorchContext : public Context {
60
103
}
61
104
62
105
bool IsCompatible (const Context &other) const override {
63
- return other.GetDeviceType () == GetDeviceType () &&
64
- other.GetDeviceId () == device_id_;
106
+ return other.GetDeviceType () == kCuda && other.GetDeviceId () == gpu_id_;
65
107
}
66
108
67
109
void Sync () const override {
68
- if (device_id_ >= 0 ) {
69
- auto ret = cudaStreamSynchronize (GetCudaStream ());
70
- K2_CHECK_CUDA_ERROR (ret);
71
- }
72
- }
73
-
74
- private:
75
- void InitCpu () {
76
- allocator_ = torch::GetAllocator (torch::kCPU );
77
- K2_CHECK (allocator_->raw_deleter () != nullptr );
78
- }
79
-
80
- void InitCuda () {
81
- auto ret = cudaSetDevice (device_id_);
110
+ auto ret = cudaStreamSynchronize (GetCudaStream ());
82
111
K2_CHECK_CUDA_ERROR (ret);
83
- // TODO(fangjun): invoke init only once
84
- c10::cuda::CUDACachingAllocator::init (device_id_ + 1 );
85
-
86
- allocator_ = c10::cuda::CUDACachingAllocator::get ();
87
- K2_CHECK (allocator_->raw_deleter () != nullptr );
88
112
}
89
113
90
114
private:
91
115
torch::Allocator *allocator_; // NOT owned here
92
- int32_t device_id_ ;
116
+ int32_t gpu_id_ ;
93
117
};
94
118
95
119
} // namespace k2
0 commit comments