Skip to content

Commit d4b9943

Browse files
authored
Construct a CfsaVec from torch::Tensor via DLPack (k2-fsa#49)
* WIP: implement cfsa. * finish CfsaVec. * fix clang-tidy warnings and fix Windows compilation. * fix memory leak. * fix clang-tidy warnings. * add python wrappers for FsaVec and some python tests. * fix python3.5. * fix cpplint. * resolve various comments. * add python wrappers for Cfsa. * add python wrappers for CfsaVec and support DLPack. Now we can construct a CfsaVec from `torch::Tensor`. * fix cpplint. * disable Windows. * fix an error: torch::Tensor should be contiguous.
1 parent eaffcf1 commit d4b9943

33 files changed

+1420
-136
lines changed

.clang-tidy

+13-2
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,28 @@ Checks: >
5252
-google-readability-braces-around-statements,
5353
-google-runtime-references,
5454
cppcoreguidelines-*,
55+
-cppcoreguidelines-avoid-c-arrays,
5556
-cppcoreguidelines-avoid-magic-numbers,
57+
-cppcoreguidelines-macro-usage,
58+
-cppcoreguidelines-no-malloc,
5659
-cppcoreguidelines-non-private-member-variables-in-classes,
5760
-cppcoreguidelines-owning-memory,
5861
-cppcoreguidelines-pro-bounds-pointer-arithmetic,
59-
-cppcoreguidelines-special-member-functions,
62+
-cppcoreguidelines-pro-type-const-cast,
63+
-cppcoreguidelines-pro-type-member-init,
64+
-cppcoreguidelines-pro-type-reinterpret-cast,
6065
-cppcoreguidelines-pro-type-vararg,
66+
-cppcoreguidelines-special-member-functions,
6167
modernize-*,
68+
-modernize-avoid-c-arrays,
69+
-modernize-deprecated-headers,
70+
-modernize-use-default-member-init,
6271
-modernize-use-trailing-return-type,
6372
readability-*,
64-
-readability-magic-numbers,
6573
-readability-braces-around-statements,
74+
-readability-isolate-declaration,
75+
-readability-magic-numbers,
76+
-readability-static-definition-in-anonymous-namespace,
6677
-readability-uppercase-literal-suffix,
6778
performance-*,
6879

.github/workflows/build.yml

+14-3
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,30 @@ on:
1515
- master
1616

1717
env:
18-
BUILD_TYPE: Release
18+
BUILD_TYPE: Debug
1919

2020
jobs:
2121
build:
2222
runs-on: ${{ matrix.os }}
2323
strategy:
2424
matrix:
25-
os: [ubuntu-latest, macOS-latest, windows-latest]
25+
os: [ubuntu-latest, macOS-latest] #, windows-latest]
26+
python-version: [3.5, 3.6, 3.7, 3.8]
2627

2728
steps:
2829
# refer to https://github.com/actions/checkout
2930
- uses: actions/checkout@v2
3031

32+
- name: Setup Python ${{ matrix.python-version }}
33+
uses: actions/setup-python@v1
34+
with:
35+
python-version: ${{ matrix.python-version }}
36+
37+
- name: Install Python dependencies
38+
run: |
39+
python3 -m pip install --upgrade pip
40+
python3 -m pip install torch==1.5.0
41+
3142
- name: Create Build Directory
3243
run: cmake -E make_directory ${{runner.workspace}}/build
3344

@@ -44,4 +55,4 @@ jobs:
4455
- name: Test
4556
shell: bash
4657
working-directory: ${{runner.workspace}}/build
47-
run: CTEST_OUTPUT_ON_FAILURE=1 ctest --build-config $BUILD_TYPE
58+
run: ctest --verbose --build-config $BUILD_TYPE

.github/workflows/style_check.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ jobs:
4141
run: |
4242
# stop the build if there are Python syntax errors or undefined names
4343
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
44-
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
45-
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
44+
# exit-zero treats all errors as warnings.
45+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=79 --statistics
4646
4747
# - name: Install cppcheck
4848
# run: |
@@ -54,7 +54,7 @@ jobs:
5454
# cmake ..
5555
# make -j
5656
# sudo make install
57-
57+
5858
- name: Create Build Directory
5959
run: cmake -E make_directory ${{runner.workspace}}/build
6060

.style.yapf

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[style]
22
based_on_style = google
3+
column_limit = 79

k2/csrc/CMakeLists.txt

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# please sort the source files alphabetically
22
add_library(fsa
33
aux_labels.cc
4+
determinize.cc
5+
fsa.cc
46
fsa_algo.cc
57
fsa_equivalent.cc
68
fsa_renderer.cc
79
fsa_util.cc
810
properties.cc
11+
util.cc
912
weights.cc
10-
determinize.cc
1113
)
1214

1315
target_include_directories(fsa PUBLIC ${CMAKE_SOURCE_DIR})
@@ -40,6 +42,7 @@ set(fsa_tests
4042
fsa_algo_test
4143
fsa_equivalent_test
4244
fsa_renderer_test
45+
fsa_test
4346
fsa_util_test
4447
properties_test
4548
weights_test

k2/csrc/determinize.cc

+8-6
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ void TraceBack(std::unordered_set<LogSumTracebackState *> *cur_states,
7474
for (int32_t i = 0; i < num_steps; i++) {
7575
for (LogSumTracebackState *state_ptr : *cur_states) {
7676
double backward_prob = state_ptr->backward_prob;
77-
for (auto link : state_ptr->prev_elements) {
78-
float arc_log_posterior = link.forward_prob + backward_prob;
77+
for (const auto &link : state_ptr->prev_elements) {
78+
auto arc_log_posterior =
79+
static_cast<float>(link.forward_prob + backward_prob);
7980
deriv_out->push_back(
8081
std::pair<int32_t, float>(link.arc_index, expf(arc_log_posterior)));
8182
LogSumTracebackState *prev_state = link.prev_state.get();
@@ -96,7 +97,7 @@ void TraceBack(std::unordered_set<LogSumTracebackState *> *cur_states,
9697
// algorithm.
9798
CHECK_EQ(cur_states->size(), 1);
9899
double prev_forward_prob = (*(cur_states->begin()))->forward_prob;
99-
*weight_out = cur_forward_prob - prev_forward_prob;
100+
*weight_out = static_cast<float>(cur_forward_prob - prev_forward_prob);
100101
// The following is mostly for ease of interpretability of the output;
101102
// conceptually the order makes no difference.
102103
// TODO(dpovey): maybe remove this, for efficiency?
@@ -105,20 +106,21 @@ void TraceBack(std::unordered_set<LogSumTracebackState *> *cur_states,
105106

106107
void TraceBack(std::unordered_set<MaxTracebackState *> *cur_states,
107108
int32_t num_steps,
108-
const float *, // arc_weights_in, unused.
109+
const float *unused, // arc_weights_in, unused.
109110
float *weight_out, std::vector<int32_t> *deriv_out) {
111+
(void)unused;
110112
CHECK_EQ(cur_states->size(), 1);
111113
MaxTracebackState *state = *(cur_states->begin());
112114
double cur_forward_prob = state->forward_prob;
113115
deriv_out->resize(num_steps);
114-
for (int32_t i = num_steps - 1; i >= 0; i--) {
116+
for (int32_t i = num_steps - 1; i >= 0; --i) {
115117
// `deriv_out` is just a list of arc indexes in the input FSA
116118
// that this output arc depends on (it's their sum).
117119
(*deriv_out)[i] = state->arc_id;
118120
state = state->prev_state.get();
119121
}
120122
double prev_forward_prob = state->forward_prob;
121-
*weight_out = cur_forward_prob - prev_forward_prob;
123+
*weight_out = static_cast<float>(cur_forward_prob - prev_forward_prob);
122124
}
123125

124126
template <>

k2/csrc/determinize.h

+14-14
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ namespace k2 {
8686
arcs in the output FSA.
8787
8888
89-
*The problem with differentability
89+
*The problem with differentiability
9090
9191
Consider how to differentiate the weights of the output weighted FSA
9292
w.r.t. those of the input. The problem with differentiability if we use the
@@ -113,7 +113,7 @@ namespace k2 {
113113
114114
*Different normalization
115115
116-
Our form of "normalization" of this representation is differen too. The
116+
Our form of "normalization" of this representation is different too. The
117117
normalization is to make `symbol_sequence` as short as possible, and advance
118118
`base_state` to compensate. For instance, if `symbol_sequence` is `a b c
119119
d`, but the weighted subset of states we can reach by this symbol sequence
@@ -177,7 +177,7 @@ struct MaxTracebackState {
177177

178178
/**
179179
@param [in] state_id State in input FSA that this corresponds to
180-
@param [in] src Previous LogSumTracebackState that we'll point back
180+
@param [in] src Previous MaxTracebackState that we'll point back
181181
to, or NULL
182182
@param [in] incoming_arc_index Arc-index in input FSA.
183183
Its src_state will equal src->state_id,
@@ -213,7 +213,7 @@ class LogSumTracebackState;
213213
/*
214214
This struct is used inside LogSumTracebackState; it represents an
215215
arc that traces back to a previous LogSumTracebackState.
216-
A LogSumTracebackState represents a weighted colletion of paths
216+
A LogSumTracebackState represents a weighted collection of paths
217217
terminating in a specific state.
218218
*/
219219
struct LogSumTracebackLink {
@@ -364,7 +364,7 @@ void TraceBack(std::unordered_set<LogSumTracebackState *> *cur_states,
364364
// for LogSumTracebackState, above. This version is simpler.
365365
void TraceBack(std::unordered_set<MaxTracebackState *> *cur_states,
366366
int32_t num_steps,
367-
const float *, // arc_weights_in, unused.
367+
const float *unused, // arc_weights_in, unused.
368368
float *weight_out, std::vector<int32_t> *deriv_out);
369369

370370
template <class TracebackState>
@@ -633,7 +633,7 @@ void DetState<TracebackState>::Normalize(const WfsaWithFbWeights &wfsa_in,
633633
std::unordered_set<TracebackState *> cur_states;
634634

635635
double fb_prob = -std::numeric_limits<double>::infinity();
636-
for (auto p : elements) {
636+
for (const auto &p : elements) {
637637
TracebackState *state = p.second.get();
638638
fb_prob = LogSumOrMax<TracebackState>(
639639
fb_prob,
@@ -701,19 +701,19 @@ class DetStateMap {
701701
if (inserted) {
702702
a->state_id = cur_output_state_++;
703703
return true;
704-
} else {
705-
a->state_id = p.first->second;
706-
return false;
707704
}
705+
706+
a->state_id = p.first->second;
707+
return false;
708708
}
709709

710710
int32_t size() const { return cur_output_state_; }
711711

712712
private:
713713
// simple hashing function that just takes the first element of the pair.
714714
struct PairHasher {
715-
size_t operator()(const std::pair<uint64_t, uint64_t> &p) const {
716-
return static_cast<size_t>(p.first);
715+
std::size_t operator()(const std::pair<uint64_t, uint64_t> &p) const {
716+
return static_cast<std::size_t>(p.first);
717717
}
718718
};
719719

@@ -781,7 +781,7 @@ class DetStateMap {
781781
}
782782

783783
struct DetStateHasher {
784-
size_t operator()(const std::pair<uint64_t, uint64_t> &p) const {
784+
std::size_t operator()(const std::pair<uint64_t, uint64_t> &p) const {
785785
return p.first;
786786
}
787787
};
@@ -837,9 +837,9 @@ float DeterminizePrunedTpl(
837837
arc_derivs_out->begin());
838838
if (!queue.empty()) { // We stopped early due to max_step
839839
return total_prob - queue.top()->forward_backward_prob;
840-
} else {
841-
return beam;
842840
}
841+
842+
return beam;
843843
}
844844
} // namespace k2
845845

0 commit comments

Comments
 (0)