Skip to content

Commit eaffcf1

Browse files
authored
implement MaxAuxLabels1(2) (k2-fsa#50)
* implement MaxAuxLabels1(2) * fix some issues * documented that treat epsilon same as other symbols
1 parent 71c7fc2 commit eaffcf1

7 files changed

+276
-141
lines changed

k2/csrc/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# please sort the source files alphabetically
22
add_library(fsa
3+
aux_labels.cc
34
fsa_algo.cc
45
fsa_equivalent.cc
56
fsa_renderer.cc
@@ -35,6 +36,7 @@ endfunction()
3536

3637
# please sort the source files alphabetically
3738
set(fsa_tests
39+
aux_labels_test
3840
fsa_algo_test
3941
fsa_equivalent_test
4042
fsa_renderer_test

k2/csrc/aux_labels.cc

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// k2/csrc/aux_labels.cc
2+
3+
// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)
4+
5+
// See ../../LICENSE for clarification regarding multiple authors
6+
7+
#include "k2/csrc/aux_labels.h"
8+
9+
#include <numeric>
10+
#include <vector>
11+
12+
#include "glog/logging.h"
13+
#include "k2/csrc/fsa.h"
14+
15+
namespace k2 {
16+
17+
// Builds the auxiliary labels of an output FSA in which output arc i
// corresponds to exactly one input arc, arc_map[i].
//
// @param [in]  labels_in   Label sequences on the arcs of the input FSA.
// @param [in]  arc_map     arc_map[i] is the input-arc index whose label
//                          sequence is copied to output arc i.
// @param [out] labels_out  Must be non-null (enforced with CHECK_NOTNULL).
//                          Any previous contents are discarded. On return
//                          start_pos has arc_map.size() + 1 entries and
//                          start_pos.back() == labels.size().
//
// The entries of arc_map are not range-checked here: each one is used
// directly to index labels_in.start_pos at [arc_index] and [arc_index + 1].
void MapAuxLabels1(const AuxLabels &labels_in,
18+
const std::vector<int32_t> &arc_map, AuxLabels *labels_out) {
19+
CHECK_NOTNULL(labels_out);
20+
auto &start_pos = labels_out->start_pos;
21+
auto &labels = labels_out->labels;
22+
// Drop whatever the caller left in the output before rebuilding it.
start_pos.clear();
23+
// One start offset per output arc, plus the trailing end offset.
start_pos.reserve(arc_map.size() + 1);
24+
labels.clear();
25+
26+
// Running count of labels written so far; doubles as the next start offset.
int32_t num_labels = 0;
27+
auto labels_in_iter_begin = labels_in.labels.begin();
28+
for (const auto &arc_index : arc_map) {
29+
start_pos.push_back(num_labels);
30+
// [pos_start, pos_end) is the slice of labels_in.labels for this input arc.
int32_t pos_start = labels_in.start_pos[arc_index];
31+
int32_t pos_end = labels_in.start_pos[arc_index + 1];
32+
labels.insert(labels.end(), labels_in_iter_begin + pos_start,
33+
labels_in_iter_begin + pos_end);
34+
num_labels += pos_end - pos_start;
35+
}
36+
// Trailing entry: total number of labels, i.e. the end of the last arc.
start_pos.push_back(num_labels);
37+
}
38+
39+
// Builds the auxiliary labels of an output FSA in which output arc i
// corresponds to a sequence of input arcs, arc_map[i]. The label sequences
// of those input arcs are concatenated, in the order listed, to form the
// label sequence of output arc i.
//
// @param [in]  labels_in   Label sequences on the arcs of the input FSA.
// @param [in]  arc_map     arc_map[i] lists the input-arc indexes that make
//                          up output arc i; an empty list gives output arc i
//                          an empty label sequence.
// @param [out] labels_out  Must be non-null (enforced with CHECK_NOTNULL).
//                          Any previous contents are discarded. On return
//                          start_pos has arc_map.size() + 1 entries and
//                          start_pos.back() == labels.size().
//
// The arc indexes are not range-checked here: each one is used directly to
// index labels_in.start_pos at [arc_index] and [arc_index + 1].
void MapAuxLabels2(const AuxLabels &labels_in,
40+
const std::vector<std::vector<int32_t>> &arc_map,
41+
AuxLabels *labels_out) {
42+
CHECK_NOTNULL(labels_out);
43+
auto &start_pos = labels_out->start_pos;
44+
auto &labels = labels_out->labels;
45+
// Drop whatever the caller left in the output before rebuilding it.
start_pos.clear();
46+
// One start offset per output arc, plus the trailing end offset.
start_pos.reserve(arc_map.size() + 1);
47+
labels.clear();
48+
49+
// Running count of labels written so far; doubles as the next start offset.
int32_t num_labels = 0;
50+
auto labels_in_iter_begin = labels_in.labels.begin();
51+
for (const auto &arc_indexes : arc_map) {
52+
// Recorded once per output arc, before any of its input arcs are appended.
start_pos.push_back(num_labels);
53+
for (const auto &arc_index : arc_indexes) {
54+
// [pos_start, pos_end) is the slice of labels_in.labels for this input arc.
int32_t pos_start = labels_in.start_pos[arc_index];
55+
int32_t pos_end = labels_in.start_pos[arc_index + 1];
56+
labels.insert(labels.end(), labels_in_iter_begin + pos_start,
57+
labels_in_iter_begin + pos_end);
58+
num_labels += pos_end - pos_start;
59+
}
60+
}
61+
// Trailing entry: total number of labels, i.e. the end of the last arc.
start_pos.push_back(num_labels);
62+
}
63+
64+
} // namespace k2

k2/csrc/aux_labels.h

+7-14
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ namespace k2 {
2929
3030
*/
3131

32-
3332
/*
3433
This allows you to store auxiliary labels (e.g. olabels or ilabels)
3534
on each arc of an Fsa.
@@ -40,13 +39,13 @@ struct AuxLabels {
4039
`labels` of the label sequence on arc i. start_pos.back()
4140
equals labels.size(). */
4241
std::vector<int32_t> start_pos;
43-
/* For arc i, (labels[start_pos[i] ], labels[start_pos[i]+1], ... labels[start_pos[i+1]-1])
44-
are the list of labels on that arc. None of the elements of `labels` are
45-
expected to be zero (epsilon). */
42+
/* For arc i, (labels[start_pos[i] ], labels[start_pos[i]+1], ...
43+
labels[start_pos[i+1]-1]) are the list of labels on that arc.
44+
We treat epsilon the same as other symbols here, so there are no
45+
requirements on elements of `labels`. */
4646
std::vector<int32_t> labels;
4747
};
4848

49-
5049
/*
5150
Maps auxiliary labels after an FSA operation where each arc in the output
5251
FSA corresponds to exactly one arc in the input FSA.
@@ -57,8 +56,7 @@ struct AuxLabels {
5756
@param [out] labels_out Labels on the arcs of the output FSA
5857
*/
5958
void MapAuxLabels1(const AuxLabels &labels_in,
60-
const std::vector<int32_t> &arc_map,
61-
AuxLabels *labels_out);
59+
const std::vector<int32_t> &arc_map, AuxLabels *labels_out);
6260

6361
/*
6462
Maps auxiliary labels after an FSA operation where each arc in the output
@@ -70,10 +68,9 @@ void MapAuxLabels1(const AuxLabels &labels_in,
7068
@param [out] labels_out Labels on the arcs of the output FSA
7169
*/
7270
void MapAuxLabels2(const AuxLabels &labels_in,
73-
const std::vector<std::vector<int32_t> > &arc_map,
71+
const std::vector<std::vector<int32_t>> &arc_map,
7472
AuxLabels *labels_out);
7573

76-
7774
/*
7875
Invert an FST, swapping the symbols in the FSA with the auxiliary labels.
7976
(e.g. swap input and output symbols in FST, but you decide which is which).
@@ -92,13 +89,9 @@ void MapAuxLabels2(const AuxLabels &labels_in,
9289
`fsa_in`, although epsilons (kEpsilon, zeros) will be
9390
removed.
9491
*/
95-
void InvertFst(const Fsa &fsa_in,
96-
const AuxLabels &labels_in,
97-
Fsa *fsa_out,
92+
void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out,
9893
AuxLabels *aux_labels_out);
9994

100-
101-
10295
} // namespace k2
10396

10497
#endif // K2_CSRC_AUX_LABELS_H_

k2/csrc/aux_labels_test.cc

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// k2/csrc/aux_labels_test.cc
2+
3+
// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)
4+
5+
// See ../../LICENSE for clarification regarding multiple authors
6+
7+
#include "k2/csrc/aux_labels.h"
8+
9+
#include <utility>
10+
#include <vector>
11+
12+
#include "gmock/gmock.h"
13+
#include "gtest/gtest.h"
14+
#include "k2/csrc/fsa.h"
15+
16+
namespace k2 {
17+
18+
// Test fixture providing a small AuxLabels input shared by the tests below.
// With start_pos = {0, 1, 3, 6, 7} and labels = {1, ..., 7}, the four input
// arcs carry these label sequences:
//   arc 0 -> {1}, arc 1 -> {2, 3}, arc 2 -> {4, 5, 6}, arc 3 -> {7}
// NOTE(review): "Lables" in the class name looks like a typo for "Labels".
class AuxLablesTest : public ::testing::Test {
19+
protected:
20+
AuxLablesTest() {
21+
// start_pos[i] .. start_pos[i + 1] delimits arc i's slice of `labels`.
std::vector<int32_t> start_pos = {0, 1, 3, 6, 7};
22+
std::vector<int32_t> labels = {1, 2, 3, 4, 5, 6, 7};
23+
aux_labels_in_.start_pos = std::move(start_pos);
24+
aux_labels_in_.labels = std::move(labels);
25+
}
26+
27+
// Input labels passed as `labels_in` by every test case.
AuxLabels aux_labels_in_;
28+
};
29+
30+
// Checks MapAuxLabels1 (one input arc per output arc) on an empty map, a
// reordered subset of the input arcs, and a full permutation of them.
TEST_F(AuxLablesTest, MapAuxLabels1) {
31+
{
32+
// empty arc_map: the output must end up with no labels and start_pos == {0}
33+
std::vector<int32_t> arc_map;
34+
AuxLabels aux_labels_out;
35+
// pre-fill the output with stale data to verify that it gets cleared
36+
aux_labels_out.start_pos = {1, 2, 3};
37+
aux_labels_out.labels = {4, 5};
38+
MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out);
39+
40+
EXPECT_TRUE(aux_labels_out.labels.empty());
41+
ASSERT_EQ(aux_labels_out.start_pos.size(), 1);
42+
EXPECT_EQ(aux_labels_out.start_pos[0], 0);
43+
}
44+
45+
{
46+
// a reordered subset: arc 2 -> {4, 5, 6}, arc 0 -> {1}, arc 3 -> {7}
std::vector<int32_t> arc_map = {2, 0, 3};
47+
AuxLabels aux_labels_out;
48+
MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out);
49+
50+
ASSERT_EQ(aux_labels_out.start_pos.size(), 4);
51+
EXPECT_THAT(aux_labels_out.start_pos, ::testing::ElementsAre(0, 3, 4, 5));
52+
ASSERT_EQ(aux_labels_out.labels.size(), 5);
53+
EXPECT_THAT(aux_labels_out.labels, ::testing::ElementsAre(4, 5, 6, 1, 7));
54+
}
55+
56+
{
57+
// every arc of the input FSA is retained, in permuted order
58+
std::vector<int32_t> arc_map = {2, 0, 3, 1};
59+
AuxLabels aux_labels_out;
60+
MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out);
61+
62+
ASSERT_EQ(aux_labels_out.start_pos.size(), 5);
63+
EXPECT_THAT(aux_labels_out.start_pos,
64+
::testing::ElementsAre(0, 3, 4, 5, 7));
65+
ASSERT_EQ(aux_labels_out.labels.size(), 7);
66+
EXPECT_THAT(aux_labels_out.labels,
67+
::testing::ElementsAre(4, 5, 6, 1, 7, 2, 3));
68+
}
69+
}
70+
71+
// Checks MapAuxLabels2 (a sequence of input arcs per output arc) on an empty
// map and on a map that concatenates and reuses input arcs.
TEST_F(AuxLablesTest, MapAuxLabels2) {
72+
{
73+
// empty arc_map: the output must end up with no labels and start_pos == {0}
74+
std::vector<std::vector<int32_t>> arc_map;
75+
AuxLabels aux_labels_out;
76+
// pre-fill the output with stale data to verify that it gets cleared
77+
aux_labels_out.start_pos = {1, 2, 3};
78+
aux_labels_out.labels = {4, 5};
79+
MapAuxLabels2(aux_labels_in_, arc_map, &aux_labels_out);
80+
81+
EXPECT_TRUE(aux_labels_out.labels.empty());
82+
ASSERT_EQ(aux_labels_out.start_pos.size(), 1);
83+
EXPECT_EQ(aux_labels_out.start_pos[0], 0);
84+
}
85+
86+
{
87+
// each output arc concatenates the labels of its listed input arcs in
// order; input arcs 0 and 2 are each used by two different output arcs:
//   {2, 3} -> {4, 5, 6, 7}, {0, 1} -> {1, 2, 3}, {0} -> {1}, {2} -> {4, 5, 6}
std::vector<std::vector<int32_t>> arc_map = {{2, 3}, {0, 1}, {0}, {2}};
88+
AuxLabels aux_labels_out;
89+
MapAuxLabels2(aux_labels_in_, arc_map, &aux_labels_out);
90+
91+
ASSERT_EQ(aux_labels_out.start_pos.size(), 5);
92+
EXPECT_THAT(aux_labels_out.start_pos,
93+
::testing::ElementsAre(0, 4, 7, 8, 11));
94+
ASSERT_EQ(aux_labels_out.labels.size(), 11);
95+
EXPECT_THAT(aux_labels_out.labels,
96+
::testing::ElementsAre(4, 5, 6, 7, 1, 2, 3, 1, 4, 5, 6));
97+
}
98+
}
99+
100+
} // namespace k2

0 commit comments

Comments
 (0)