16
16
#include < memory>
17
17
18
18
#include " c10/cuda/CUDACachingAllocator.h"
19
+ #include " c10/cuda/CUDAFunctions.h"
19
20
#include " k2/csrc/context.h"
20
21
#include " k2/csrc/log.h"
21
22
#include " torch/torch.h"
22
23
23
24
namespace k2 {
24
25
26
+ class ManagedTensor {
27
+ public:
28
+ explicit ManagedTensor (torch::Tensor &tensor) : handle_(tensor) {}
29
+
30
+ private:
31
+ torch::Tensor handle_; // retain a copy of the tensor passed from Python
32
+ };
33
+
25
34
class PytorchCpuContext : public Context {
26
35
private:
27
36
PytorchCpuContext () {
@@ -46,12 +55,18 @@ class PytorchCpuContext : public Context {
46
55
47
56
void *Allocate (std::size_t bytes, void **deleter_context) override {
48
57
void *p = allocator_->raw_allocate (bytes);
49
- if (deleter_context) *deleter_context = nullptr ;
58
+ if (deleter_context != nullptr ) *deleter_context = nullptr ;
50
59
return p;
51
60
}
52
61
53
- void Deallocate (void *data, void * /* deleter_context*/ ) override {
54
- allocator_->raw_deallocate (data);
62
+ void Deallocate (void *data, void *deleter_context) override {
63
+ if (deleter_context != nullptr ) {
64
+ // a non-empty `deleter_context` indicates that
65
+ // the memory is passed from a `torch::Tensor`
66
+ delete reinterpret_cast <ManagedTensor *>(deleter_context);
67
+ } else {
68
+ allocator_->raw_deallocate (data);
69
+ }
55
70
}
56
71
57
72
bool IsCompatible (const Context &other) const override {
@@ -94,12 +109,18 @@ class PytorchCudaContext : public Context {
94
109
95
110
void *Allocate (std::size_t bytes, void **deleter_context) override {
96
111
void *p = allocator_->raw_allocate (bytes);
97
- if (deleter_context) *deleter_context = nullptr ;
112
+ if (deleter_context != nullptr ) *deleter_context = nullptr ;
98
113
return p;
99
114
}
100
115
101
- void Deallocate (void *data, void * /* deleter_context*/ ) override {
102
- allocator_->raw_deallocate (data);
116
+ void Deallocate (void *data, void *deleter_context) override {
117
+ if (deleter_context != nullptr ) {
118
+ // a non-empty `deleter_context` indicates that
119
+ // the memory is passed from a `torch::Tensor`
120
+ delete reinterpret_cast <ManagedTensor *>(deleter_context);
121
+ } else {
122
+ allocator_->raw_deallocate (data);
123
+ }
103
124
}
104
125
105
126
bool IsCompatible (const Context &other) const override {
@@ -116,6 +137,12 @@ class PytorchCudaContext : public Context {
116
137
int32_t gpu_id_;
117
138
};
118
139
140
+ // Construct a region from a `torch::Tensor`.
141
+ //
142
+ // The resulting region shares the underlying memory with
143
+ // the given tensor.
144
+ RegionPtr NewRegion (torch::Tensor &tensor);
145
+
119
146
} // namespace k2
120
147
121
148
#endif // K2_CSRC_PYTORCH_CONTEXT_H_
0 commit comments