9
9
*/
10
10
11
11
#include < type_traits>
12
+ #include < vector>
12
13
13
14
#include " c10/core/ScalarType.h"
14
15
#include " k2/csrc/array.h"
16
+ #include " k2/csrc/fsa.h"
15
17
#include " k2/csrc/pytorch_context.h"
16
18
#include " k2/python/csrc/torch/array.h"
17
19
#include " k2/python/csrc/torch/torch_util.h"
18
20
#include " torch/extension.h"
19
21
20
22
namespace k2 {
21
23
24
+ template <typename T>
25
+ static void PybindArray2Tpl (py::module &m, const char *name) {
26
+ using PyClass = Array2<T>;
27
+ py::class_<PyClass> pyclass (m, name);
28
+ pyclass.def (" tensor" ,
29
+ [](PyClass &self) -> torch::Tensor { return ToTensor (self); });
30
+
31
+ pyclass.def_static (
32
+ " from_tensor" ,
33
+ [](torch::Tensor &tensor) -> PyClass {
34
+ return FromTensor<T>(tensor, Array2Tag{});
35
+ },
36
+ py::arg (" tensor" ));
37
+
38
+ // the following functions are for testing only
39
+ pyclass.def (
40
+ " get" , [](PyClass &self, int32_t i) -> Array1<T> { return self[i]; },
41
+ py::arg (" i" ));
42
+
43
+ pyclass.def (" __str__" , [](const PyClass &self) {
44
+ std::ostringstream os;
45
+ os << self;
46
+ return os.str ();
47
+ });
48
+ }
49
+
22
50
template <typename T>
23
51
static void PybindArray1Tpl (py::module &m, const char *name) {
24
52
using PyClass = Array1<T>;
25
53
py::class_<PyClass> pyclass (m, name);
26
- pyclass.def (py::init<>());
27
- pyclass. def ( " tensor " , [](PyClass &self) { return ToTensor (self); });
54
+ pyclass.def (" tensor " ,
55
+ [](PyClass &self) -> torch::Tensor { return ToTensor (self); });
28
56
29
57
pyclass.def_static (
30
58
" from_tensor" ,
31
- [](torch::Tensor &tensor) { return FromTensor<T>(tensor); },
59
+ [](torch::Tensor &tensor) -> PyClass { return FromTensor<T>(tensor); },
32
60
py::arg (" tensor" ));
33
61
34
62
// the following functions are for testing only
@@ -46,30 +74,66 @@ static void PybindArrayImpl(py::module &m) {
46
74
// users should not use classes with prefix `_` in Python.
47
75
PybindArray1Tpl<float >(m, " _FloatArray1" );
48
76
PybindArray1Tpl<int >(m, " _Int32Array1" );
77
+ PybindArray1Tpl<Arc>(m, " _ArcArray1" );
78
+
79
+ PybindArray2Tpl<float >(m, " _FloatArray2" );
80
+ PybindArray2Tpl<int >(m, " _Int32Array2" );
49
81
50
82
// the following functions are for testing purposes
51
83
// and they can be removed later.
52
- m.def (" get_cpu_float_array1" , []() {
84
+ m.def (" get_cpu_float_array1" , []() -> Array1< float > {
53
85
return Array1<float >(GetCpuContext (), {1 , 2 , 3 , 4 });
54
86
});
55
87
56
- m.def (" get_cpu_int_array1" , []() {
88
+ m.def (" get_cpu_int_array1" , []() -> Array1< int32_t > {
57
89
return Array1<int32_t >(GetCpuContext (), {1 , 2 , 3 , 4 });
58
90
});
59
91
60
92
m.def (
61
93
" get_cuda_float_array1" ,
62
- [](int32_t gpu_id = -1 ) {
94
+ [](int32_t gpu_id = -1 ) -> Array1< float > {
63
95
return Array1<float >(GetCudaContext (gpu_id), {0 , 1 , 2 , 3 });
64
96
},
65
97
py::arg (" gpu_id" ) = -1 );
66
98
67
99
m.def (
68
100
" get_cuda_int_array1" ,
69
- [](int32_t gpu_id = -1 ) {
101
+ [](int32_t gpu_id = -1 ) -> Array1< int32_t > {
70
102
return Array1<int32_t >(GetCudaContext (gpu_id), {0 , 1 , 2 , 3 });
71
103
},
72
104
py::arg (" gpu_id" ) = -1 );
105
+
106
+ m.def (" get_cpu_arc_array1" , []() -> Array1<Arc> {
107
+ std::vector<Arc> arcs = {
108
+ {1 , 2 , 3 , 1.5 },
109
+ {10 , 20 , 30 , 2.5 },
110
+ };
111
+ return Array1<Arc>(GetCpuContext (), arcs);
112
+ });
113
+
114
+ m.def (
115
+ " get_cuda_arc_array1" ,
116
+ [](int32_t gpu_id = -1 ) -> Array1<Arc> {
117
+ std::vector<Arc> arcs = {
118
+ {1 , 2 , 3 , 1.5 },
119
+ {10 , 20 , 30 , 2.5 },
120
+ };
121
+ return Array1<Arc>(GetCudaContext (gpu_id), arcs);
122
+ },
123
+ py::arg (" gpu_id" ) = -1 );
124
+
125
+ m.def (" get_cpu_int_array2" , []() -> Array2<int32_t > {
126
+ Array1<int32_t > array1 (GetCpuContext (), {1 , 2 , 3 , 4 , 5 , 6 });
127
+ return Array2<int32_t >(array1, 2 , 3 );
128
+ });
129
+
130
+ m.def (
131
+ " get_cuda_float_array2" ,
132
+ [](int32_t gpu_id = -1 ) -> Array2<float > {
133
+ Array1<float > array1 (GetCudaContext (gpu_id), {1 , 2 , 3 , 4 , 5 , 6 });
134
+ return Array2<float >(array1, 2 , 3 );
135
+ },
136
+ py::arg (" gpu_id" ) = -1 );
73
137
}
74
138
75
139
} // namespace k2
0 commit comments