
Commit d1b1bea

Wrap Fsa to Python. (k2-fsa#193)
* wrap Fsa to Python.
* add more documentation.
* resolve some comments.
* fix style issues.
* resolve some comments.
* Make Fsa a Python class. Now the API is more flexible and is much simpler.
* fix style issues.
* fix a typo.
* add more documentation.
* fix style issues.
* wrap Array2<T> to Python.
* fix style issues.
* Wrap DenseFsaVec to Python.
* add visualization support for FSA.
1 parent 866c6cf commit d1b1bea

25 files changed (+1113 -36 lines)

k2/csrc/CMakeLists.txt (+2 -2)

@@ -76,6 +76,6 @@ function(k2_add_cuda_test name)
   )
 endfunction()
 
-foreach (name IN LISTS cuda_tests)
+foreach(name IN LISTS cuda_tests)
   k2_add_cuda_test(${name})
-endforeach ()
+endforeach()

k2/csrc/array_ops_inl.h (+7 -7)

@@ -241,7 +241,7 @@ Array1<T> Append(int32_t num_arrays, const Array1<T> **src) {
   std::vector<int32_t> row_splits_vec(num_arrays + 1);
   int32_t sum = 0, max_dim = 0;
   row_splits_vec[0] = sum;
-  for (int32_t i = 0; i < num_arrays; i++) {
+  for (int32_t i = 0; i < num_arrays; ++i) {
     int32_t dim = src[i]->Dim();
     if (dim > max_dim) max_dim = dim;
     sum += dim;
@@ -256,7 +256,7 @@ Array1<T> Append(int32_t num_arrays, const Array1<T> **src) {
   // a simple loop is faster, although the other branches should still work on
   // CPU.
   int32_t elem_size = src[0]->ElementSize();
-  for (int32_t i = 0; i < num_arrays; i++) {
+  for (int32_t i = 0; i < num_arrays; ++i) {
     int32_t this_dim = src[i]->Dim();
     const T *this_src_data = src[i]->Data();
     memcpy(static_cast<void *>(ans_data),
@@ -268,7 +268,7 @@ Array1<T> Append(int32_t num_arrays, const Array1<T> **src) {
   Array1<int32_t> row_splits(c, row_splits_vec);
   const int32_t *row_splits_data = row_splits.Data();
   std::vector<const T *> src_ptrs_vec(num_arrays);
-  for (int32_t i = 0; i < num_arrays; i++) src_ptrs_vec[i] = src[i]->Data();
+  for (int32_t i = 0; i < num_arrays; ++i) src_ptrs_vec[i] = src[i]->Data();
   Array1<const T *> src_ptrs(c, src_ptrs_vec);
   const T **src_ptrs_data = src_ptrs.Data();
   int32_t avg_input_size = ans_size / num_arrays;
@@ -305,10 +305,10 @@ Array1<T> Append(int32_t num_arrays, const Array1<T> **src) {
   // them on CPU.
   std::vector<uint64_t> index_map;
   index_map.reserve((2 * ans_size) / block_dim);
-  for (int32_t i = 0; i < num_arrays; i++) {
+  for (int32_t i = 0; i < num_arrays; ++i) {
     int32_t this_array_size = src[i]->Dim();
     int32_t this_num_blocks = NumBlocks(this_array_size, block_dim);
-    for (int32_t j = 0; j < this_num_blocks; j++) {
+    for (int32_t j = 0; j < this_num_blocks; ++j) {
       index_map.push_back((static_cast<uint64_t>(j) << 32) +
                           static_cast<uint64_t>(i));
     }
@@ -362,10 +362,10 @@ void ApplyOpPerSublist(Ragged<T> &src, T default_value, Array1<T> *dst) {
 
   if (c->GetDeviceType() == kCpu) {
     int32_t j = row_splits[0];
-    for (int32_t i = 0; i < num_rows; i++) {
+    for (int32_t i = 0; i < num_rows; ++i) {
       T val = default_value;
       int32_t row_end = row_splits[i + 1];
-      for (; j < row_end; j++) {
+      for (; j < row_end; ++j) {
         T elem = values_data[j];
         val = op(elem, val);
       }
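One detail worth calling out in the `index_map` hunk above: each entry packs two indices into a single `uint64_t`, the block index within a source array in the high 32 bits and the source-array index in the low 32 bits. A minimal sketch of that packing and its inverse (plain Python; the helper names `pack`/`unpack` are chosen here for illustration and are not part of the code):

```python
def pack(block_idx, array_idx):
    # Mirrors index_map.push_back((uint64_t(j) << 32) + uint64_t(i)):
    # high 32 bits hold the block index j within the source array,
    # low 32 bits hold the index i of the source array itself.
    assert 0 <= block_idx < 2**32 and 0 <= array_idx < 2**32
    return (block_idx << 32) + array_idx

def unpack(packed):
    return packed >> 32, packed & 0xFFFFFFFF

assert unpack(pack(7, 3)) == (7, 3)
```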

k2/csrc/dtype.h (+5 -5)

@@ -18,11 +18,11 @@
 
 namespace k2 {
 
-enum BaseType {      // BaseType is the *general type*
-  kUnknownBase = 0,  // e.g. can use this for structs
-  kFloatBase = 1,
-  kIntBase = 2,   // signed int
-  kUintBase = 3,  // unsigned int
+enum BaseType : int8_t {  // BaseType is the *general type*
+  kUnknownBase = 0,       // e.g. can use this for structs
+  kFloatBase = 1,         // real numbers, e.g., float or double
+  kIntBase = 2,           // signed int, e.g., int8_t, int32_t
+  kUintBase = 3,          // unsigned int, e.g., uint32_t, uint64_t
 };
 
 class DtypeTraits {

k2/csrc/fsa_utils.cu (+3 -1)

@@ -1,5 +1,7 @@
 /**
- * @brief Utilities for reading, writing and creating FSAs.
+ * @brief Utilities for creating FSAs.
+ *
+ * Note that serializations are done in Python.
  *
  * @copyright
  * Copyright (c) 2020 Mobvoi Inc. (authors: Fangjun Kuang)

k2/csrc/fsa_utils.h (+5 -1)

@@ -1,5 +1,7 @@
 /**
- * @brief Utilities for reading, writing and creating FSAs.
+ * @brief Utilities for creating FSAs.
+ *
+ * Note that serializations are done in Python.
  *
  * @copyright
  * Copyright (c) 2020 Mobvoi Inc. (authors: Fangjun Kuang)
@@ -45,6 +47,8 @@ namespace k2 {
 
   CAUTION: We assume that `final_state` has the largest state number.
 
+  CAUTION: The first column has to be in non-decreasing order.
+
   @param [in] s  The input string. See the above description for its format.
   @param [in] negate_scores
                  If true, the string form has the weights as costs,
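To make the two CAUTIONs concrete, here is a hedged sketch of an acceptor string that satisfies both, assuming the usual k2 layout of one `src_state dest_state label score` line per arc with the final state alone on the last line (the full format is described earlier in this header and is not reproduced in the diff):

```python
# The first column (src_state) is non-decreasing, and the final
# state, 3, has the largest state number; arcs entering the final
# state conventionally carry the label -1 in k2.
s = """
0 1 10 0.1
0 2 20 0.2
1 3 -1 0.3
2 3 -1 0.4
3
"""
```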

k2/csrc/ragged.h (+1 -1)

@@ -462,7 +462,7 @@ struct Ragged {
     return Ragged<T>(new_shape, values);
   }
 
-  Ragged<T> To(ContextPtr ctx) {
+  Ragged<T> To(ContextPtr ctx) const {
     RaggedShape new_shape = shape.To(ctx);
     Array1<T> new_values = values.To(ctx);
     return Ragged<T>(new_shape, new_values);

k2/python/csrc/torch.cu (+9 -1)

@@ -12,9 +12,17 @@
 
 #if defined(K2_USE_PYTORCH)
 
+#include "k2/python/csrc/torch/arc.h"
 #include "k2/python/csrc/torch/array.h"
+#include "k2/python/csrc/torch/fsa.h"
+#include "k2/python/csrc/torch/ragged.h"
 
-void PybindTorch(py::module &m) { PybindArray(m); }
+void PybindTorch(py::module &m) {
+  PybindArc(m);
+  PybindArray(m);
+  PybindRagged(m);
+  PybindFsa(m);
+}
 
 #else
 
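With the four registration calls in place, a quick smoke test from Python is just an import plus attribute checks (a sketch; it assumes the compiled extension is importable as `_k2`, as in k2's packaging):

```python
import _k2

# Names registered by PybindArc and PybindArray below; the leading
# underscore marks classes users are not meant to touch directly.
for name in ("Arc", "_FloatArray1", "_Int32Array1", "_FloatArray2"):
    assert hasattr(_k2, name), name
```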

k2/python/csrc/torch/CMakeLists.txt (+3)

@@ -1,6 +1,9 @@
 # please keep the list sorted
 set(torch_srcs
+  arc.cu
   array.cu
+  fsa.cu
+  ragged.cu
   torch_util.cu
 )
 

k2/python/csrc/torch/arc.cu (+62, new file)

@@ -0,0 +1,62 @@
+/**
+ * @brief python wrappers for Arc.
+ *
+ * @copyright
+ * Copyright (c) 2020 Mobvoi Inc. (authors: Fangjun Kuang)
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ */
+
+#include <string>
+
+#include "k2/csrc/fsa.h"
+#include "k2/python/csrc/torch/arc.h"
+#include "k2/python/csrc/torch/torch_util.h"
+#include "torch/extension.h"
+
+namespace k2 {
+
+static void PybindArcImpl(py::module &m) {
+  using PyClass = Arc;
+  py::class_<PyClass> pyclass(m, "Arc");
+  pyclass.def(py::init<>());
+  pyclass.def(py::init<int32_t, int32_t, int32_t, float>(),
+              py::arg("src_state"), py::arg("dest_state"), py::arg("symbol"),
+              py::arg("score"));
+
+  pyclass.def_readwrite("src_state", &PyClass::src_state)
+      .def_readwrite("dest_state", &PyClass::dest_state)
+      .def_readwrite("symbol", &PyClass::symbol)
+      .def_readwrite("score", &PyClass::score);
+
+  pyclass.def("__str__", [](const PyClass &self) -> std::string {
+    std::ostringstream os;
+    os << self;
+    return os.str();
+  });
+
+  m.def("_float_as_int",
+        [](float f) -> int32_t { return *reinterpret_cast<int32_t *>(&f); });
+
+  m.def("_int_as_float",
+        [](int32_t i) -> float { return *reinterpret_cast<float *>(&i); });
+
+  m.def("_as_int", [](torch::Tensor tensor) -> torch::Tensor {
+    auto scalar_type = ToScalarType<int32_t>::value;
+    return torch::from_blob(
+        tensor.data_ptr(), tensor.sizes(), tensor.strides(),
+        [tensor](void *p) {}, tensor.options().dtype(scalar_type));
+  });
+
+  m.def("_as_float", [](torch::Tensor tensor) -> torch::Tensor {
+    auto scalar_type = ToScalarType<float>::value;
+    return torch::from_blob(
+        tensor.data_ptr(), tensor.sizes(), tensor.strides(),
+        [tensor](void *p) {}, tensor.options().dtype(scalar_type));
+  });
+}
+
+}  // namespace k2
+
+void PybindArc(py::module &m) { k2::PybindArcImpl(m); }
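A hedged usage sketch of these bindings (again assuming the extension imports as `_k2`). `_float_as_int` and `_int_as_float` reinterpret the four bytes of a float as an int32 and back rather than converting the value, which is what lets float scores travel inside int32 arc tensors; `_as_int`/`_as_float` do the same per element for a whole tensor without copying:

```python
import _k2

arc = _k2.Arc(src_state=0, dest_state=1, symbol=2, score=0.25)
print(arc)  # uses the __str__ binding, printing the arc's fields

# Round-trip a score through its bit pattern; the value survives
# exactly because the bytes are reinterpreted, never converted.
bits = _k2._float_as_int(0.25)
assert _k2._int_as_float(bits) == 0.25
```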

k2/python/csrc/torch/arc.h (+18, new file)

@@ -0,0 +1,18 @@
+/**
+ * @brief python wrappers for Arc.
+ *
+ * @copyright
+ * Copyright (c) 2020 Mobvoi Inc. (authors: Fangjun Kuang)
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ */
+
+#ifndef K2_PYTHON_CSRC_TORCH_ARC_H_
+#define K2_PYTHON_CSRC_TORCH_ARC_H_
+
+#include "k2/python/csrc/k2.h"
+
+void PybindArc(py::module &m);
+
+#endif  // K2_PYTHON_CSRC_TORCH_ARC_H_

k2/python/csrc/torch/array.cu (+71 -7)

@@ -9,26 +9,54 @@
  */
 
 #include <type_traits>
+#include <vector>
 
 #include "c10/core/ScalarType.h"
 #include "k2/csrc/array.h"
+#include "k2/csrc/fsa.h"
 #include "k2/csrc/pytorch_context.h"
 #include "k2/python/csrc/torch/array.h"
 #include "k2/python/csrc/torch/torch_util.h"
 #include "torch/extension.h"
 
 namespace k2 {
 
+template <typename T>
+static void PybindArray2Tpl(py::module &m, const char *name) {
+  using PyClass = Array2<T>;
+  py::class_<PyClass> pyclass(m, name);
+  pyclass.def("tensor",
+              [](PyClass &self) -> torch::Tensor { return ToTensor(self); });
+
+  pyclass.def_static(
+      "from_tensor",
+      [](torch::Tensor &tensor) -> PyClass {
+        return FromTensor<T>(tensor, Array2Tag{});
+      },
+      py::arg("tensor"));
+
+  // the following functions are for testing only
+  pyclass.def(
+      "get", [](PyClass &self, int32_t i) -> Array1<T> { return self[i]; },
+      py::arg("i"));
+
+  pyclass.def("__str__", [](const PyClass &self) {
+    std::ostringstream os;
+    os << self;
+    return os.str();
+  });
+}
+
 template <typename T>
 static void PybindArray1Tpl(py::module &m, const char *name) {
   using PyClass = Array1<T>;
   py::class_<PyClass> pyclass(m, name);
-  pyclass.def(py::init<>());
-  pyclass.def("tensor", [](PyClass &self) { return ToTensor(self); });
+  pyclass.def("tensor",
+              [](PyClass &self) -> torch::Tensor { return ToTensor(self); });
 
   pyclass.def_static(
       "from_tensor",
-      [](torch::Tensor &tensor) { return FromTensor<T>(tensor); },
+      [](torch::Tensor &tensor) -> PyClass { return FromTensor<T>(tensor); },
       py::arg("tensor"));
 
   // the following functions are for testing only
@@ -46,30 +74,66 @@ static void PybindArrayImpl(py::module &m) {
   // users should not use classes with prefix `_` in Python.
   PybindArray1Tpl<float>(m, "_FloatArray1");
   PybindArray1Tpl<int>(m, "_Int32Array1");
+  PybindArray1Tpl<Arc>(m, "_ArcArray1");
+
+  PybindArray2Tpl<float>(m, "_FloatArray2");
+  PybindArray2Tpl<int>(m, "_Int32Array2");
 
   // the following functions are for testing purposes
   // and they can be removed later.
-  m.def("get_cpu_float_array1", []() {
+  m.def("get_cpu_float_array1", []() -> Array1<float> {
     return Array1<float>(GetCpuContext(), {1, 2, 3, 4});
   });
 
-  m.def("get_cpu_int_array1", []() {
+  m.def("get_cpu_int_array1", []() -> Array1<int32_t> {
     return Array1<int32_t>(GetCpuContext(), {1, 2, 3, 4});
   });
 
   m.def(
       "get_cuda_float_array1",
-      [](int32_t gpu_id = -1) {
+      [](int32_t gpu_id = -1) -> Array1<float> {
        return Array1<float>(GetCudaContext(gpu_id), {0, 1, 2, 3});
       },
       py::arg("gpu_id") = -1);
 
   m.def(
       "get_cuda_int_array1",
-      [](int32_t gpu_id = -1) {
+      [](int32_t gpu_id = -1) -> Array1<int32_t> {
        return Array1<int32_t>(GetCudaContext(gpu_id), {0, 1, 2, 3});
       },
       py::arg("gpu_id") = -1);
+
+  m.def("get_cpu_arc_array1", []() -> Array1<Arc> {
+    std::vector<Arc> arcs = {
+        {1, 2, 3, 1.5},
+        {10, 20, 30, 2.5},
+    };
+    return Array1<Arc>(GetCpuContext(), arcs);
+  });
+
+  m.def(
+      "get_cuda_arc_array1",
+      [](int32_t gpu_id = -1) -> Array1<Arc> {
+        std::vector<Arc> arcs = {
+            {1, 2, 3, 1.5},
+            {10, 20, 30, 2.5},
+        };
+        return Array1<Arc>(GetCudaContext(gpu_id), arcs);
+      },
+      py::arg("gpu_id") = -1);
+
+  m.def("get_cpu_int_array2", []() -> Array2<int32_t> {
+    Array1<int32_t> array1(GetCpuContext(), {1, 2, 3, 4, 5, 6});
+    return Array2<int32_t>(array1, 2, 3);
+  });
+
+  m.def(
+      "get_cuda_float_array2",
+      [](int32_t gpu_id = -1) -> Array2<float> {
+        Array1<float> array1(GetCudaContext(gpu_id), {1, 2, 3, 4, 5, 6});
+        return Array2<float>(array1, 2, 3);
+      },
+      py::arg("gpu_id") = -1);
 }
 
 }  // namespace k2
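A hedged sketch of the `from_tensor`/`tensor` round trip that `PybindArray1Tpl` exposes (assuming the extension imports as `_k2` and that, as the `from_blob` usage elsewhere in this commit suggests, the array aliases the tensor's storage rather than copying it):

```python
import torch
import _k2

t = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
a = _k2._Int32Array1.from_tensor(t)

# If the storage is shared, a write through the tensor is visible
# in the view returned by a.tensor().
t[0] = 10
assert a.tensor()[0].item() == 10
```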
