Skip to content

Commit 68a5d00

Browse files
authored
add input/output symbol table. (k2-fsa#206)
* add input/output symbol table. * add validation for SymbolTable.
1 parent b4848b7 commit 68a5d00

File tree

5 files changed

+200
-29
lines changed

5 files changed

+200
-29
lines changed

k2/csrc/fsa_utils.cu

+7-10
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,8 @@ static Fsa AcceptorFromStream(std::string first_line, std::istringstream &is,
159159
K2_CHECK_EQ(original_final_states.size(), original_final_weights.size());
160160
int32_t super_final_state = max_state + 1;
161161
for (std::size_t i = 0; i != original_final_states.size(); ++i) {
162-
arcs.emplace_back(original_final_states[i],
163-
super_final_state,
164-
-1, // kFinalSymbol
162+
arcs.emplace_back(original_final_states[i], super_final_state,
163+
-1, // kFinalSymbol
165164
scale * original_final_weights[i]);
166165
}
167166
}
@@ -178,8 +177,7 @@ static Fsa AcceptorFromStream(std::string first_line, std::istringstream &is,
178177
}
179178

180179
static Fsa TransducerFromStream(std::string first_line, std::istringstream &is,
181-
bool openfst,
182-
Array1<int32_t> *aux_labels) {
180+
bool openfst, Array1<int32_t> *aux_labels) {
183181
K2_CHECK(aux_labels != nullptr);
184182

185183
std::vector<int32_t> state_aux_labels;
@@ -240,16 +238,15 @@ static Fsa TransducerFromStream(std::string first_line, std::istringstream &is,
240238
K2_CHECK_EQ(openfst, true);
241239
K2_CHECK_EQ(original_final_states.size(), original_final_weights.size());
242240
int32_t super_final_state = max_state + 1;
243-
for (std::size_t = 0; i != original_final_states.size(); ++i) {
244-
arcs.emplace_back(original_final_states[i],
245-
super_final_state,
246-
-1, // kFinalSymbol
241+
for (std::size_t i = 0; i != original_final_states.size(); ++i) {
242+
arcs.emplace_back(original_final_states[i], super_final_state,
243+
-1, // kFinalSymbol
247244
scale * original_final_weights[i]);
248245
// TODO(guoguo) We are not sure yet what to put as the auxiliary label for
249246
// arcs entering the super final state. The only real choices
250247
// are kEpsilon or kFinalSymbol. We are using kEpsilon for
251248
// now.
252-
state_aux_labels.push_back(0); // kEpsilon
249+
state_aux_labels.push_back(0); // kEpsilon
253250
}
254251
}
255252

k2/python/k2/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from .array import Array
22
from .fsa import Fsa
3+
from .symbol_table import SymbolTable
34
from _k2 import Arc
45

56
# please keep the list sorted
67
__all__ = [
78
'Arc',
89
'Array',
910
'Fsa',
11+
'SymbolTable',
1012
]

k2/python/k2/fsa.py

+38-7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99

10+
from .symbol_table import SymbolTable
1011
from _k2 import _Fsa
1112
from _k2 import _as_float
1213
from _k2 import _fsa_from_str
@@ -17,7 +18,7 @@
1718

1819
class Fsa(object):
1920

20-
def __init__(self, s: str, negate_scores: bool = False):
21+
def __init__(self, s: str, openfst: bool = False):
2122
'''Create an Fsa from a string.
2223
2324
The given string `s` consists of lines with the following format:
@@ -49,14 +50,14 @@ def __init__(self, s: str, negate_scores: bool = False):
4950
Args:
5051
s:
5152
The input string. Refer to the above comment for its format.
52-
negate_scores:
53+
openfst:
5354
Optional. If true, the string form has the weights as costs,
5455
not scores, so we negate as we read.
5556
'''
5657
fsa: _Fsa
5758
aux_labels: Optional[torch.Tensor]
5859

59-
fsa, aux_labels = _fsa_from_str(s, negate_scores)
60+
fsa, aux_labels = _fsa_from_str(s, openfst)
6061

6162
self._fsa = fsa
6263
self._aux_labels = aux_labels
@@ -106,20 +107,43 @@ def from_tensor(cls,
106107
ans._aux_labels = aux_labels
107108
return ans
108109

109-
def to_str(self, negate_scores: bool = False) -> str:
110+
def set_isymbol(self, isym: SymbolTable) -> None:
    '''Attach an input symbol table to this Fsa.

    The table is stored on the instance as the attribute ``isym``.
    :func:`to_dot` checks for that attribute and, when it is present,
    uses it to turn the integer label of an arc into a symbol.

    Args:
      isym:
        The symbol table for the (input) labels of the arcs.

    Returns:
      None.
    '''
    # NOTE: the attribute does not exist until this method has been called;
    # readers such as `to_dot` test for it with `hasattr`.
    self.isym = isym
120+
121+
def set_osymbol(self, osym: SymbolTable) -> None:
    '''Attach an output symbol table to this Fsa.

    The table is stored on the instance as the attribute ``osym``.
    :func:`to_dot` checks for that attribute and, when it is present,
    uses it to turn the auxiliary label of an arc into a symbol.

    Args:
      osym:
        The symbol table for the auxiliary (output) labels of the arcs.

    Returns:
      None.
    '''
    # NOTE: the attribute does not exist until this method has been called;
    # readers such as `to_dot` test for it with `hasattr`.
    self.osym = osym
132+
133+
def to_str(self, openfst: bool = False) -> str:
    '''Return the textual form of this Fsa.

    Note:
      The returned string can be used to construct an Fsa.

    Args:
      openfst:
        Optional. If true, the scores are negated during the conversion.

    Returns:
      A string representation of the Fsa.
    '''
    # Auxiliary labels are None for an acceptor; they are forwarded as is.
    fsa, aux_labels = self._fsa, self._aux_labels
    return _fsa_to_str(fsa, openfst, aux_labels)
123147

124148
@property
125149
def arcs(self) -> torch.Tensor:
@@ -221,6 +245,7 @@ def to_dot(self) -> Digraph:
221245
src_state, dst_state, label = arc.tolist()
222246
src_state = str(src_state)
223247
dst_state = str(dst_state)
248+
label = int(label)
224249
if label == -1:
225250
final_state = dst_state
226251
if src_state not in seen:
@@ -235,10 +260,16 @@ def to_dot(self) -> Digraph:
235260
dot.node(dst_state, label=dst_state, **default_node_attr)
236261
seen.add(dst_state)
237262
if self._aux_labels is not None:
238-
aux_label = f':{self._aux_labels[i]}'
263+
aux_label = int(self._aux_labels[i])
264+
if hasattr(self, 'osym'):
265+
aux_label = self.osym.get(aux_label)
266+
aux_label = f':{aux_label}'
239267
else:
240268
aux_label = ''
241269

270+
if hasattr(self, 'isym') and label != -1:
271+
label = self.isym.get(label)
272+
242273
dot.edge(src_state,
243274
dst_state,
244275
label=f'{label}{aux_label}/{weight:.2f}')

k2/python/k2/symbol_table.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright (c) 2020 Mobvoi Inc. (authors: Fangjun Kuang)
2+
#
3+
# See ../../../LICENSE for clarification regarding multiple authors
4+
5+
from typing import Dict
6+
from typing import Union
7+
from dataclasses import dataclass
8+
9+
10+
@dataclass(frozen=True)
class SymbolTable(object):
    '''A two-way mapping between symbols (strings) and integer ids.

    Instances are usually created with :meth:`from_str` or
    :meth:`from_file`. Id 0 is reserved for the epsilon symbol ``<eps>``;
    the pair is added automatically if the given mappings lack it.
    '''

    _id2sym: Dict[int, str]
    '''Map an integer to a symbol.
    '''

    _sym2id: Dict[str, int]
    '''Map a symbol to an integer.
    '''

    def __post_init__(self):
        '''Validate both mappings and make sure ``<eps>`` is bound to id 0.

        Raises:
          AssertionError:
            If an id is negative, if the two mappings disagree, or if
            ``<eps>`` and id 0 are not mapped to each other.
          KeyError:
            If an entry of one mapping has no counterpart in the other.
        '''
        for idx, sym in self._id2sym.items():
            assert self._sym2id[sym] == idx
            assert idx >= 0

        for sym, idx in self._sym2id.items():
            assert idx >= 0
            assert self._id2sym[idx] == sym

        eps_sym = '<eps>'
        if 0 not in self._id2sym:
            # Without this check a table such as `<eps> 5` would pass: the
            # insertion below would rebind `<eps>` to 0 while id 5 kept
            # pointing at `<eps>`, leaving the two mappings inconsistent.
            assert eps_sym not in self._sym2id, \
                f'{eps_sym} must have id 0. Given: {self._sym2id[eps_sym]}'
            # The dataclass is frozen, which only forbids rebinding the
            # attributes; the dicts themselves can still be updated in place.
            self._id2sym[0] = eps_sym
            self._sym2id[eps_sym] = 0
        else:
            assert self._id2sym[0] == eps_sym
            assert self._sym2id[eps_sym] == 0

    @staticmethod
    def from_str(s: str) -> 'SymbolTable':
        '''Build a symbol table from a string.

        The string consists of lines. Every line has two fields separated
        by space(s), tab(s) or both. The first field is the symbol and the
        second the integer id of the symbol. Empty lines are skipped.

        Args:
          s:
            The input string with the format described above.

        Returns:
          An instance of :class:`SymbolTable`.

        Raises:
          AssertionError:
            If a non-empty line does not have exactly two fields, or if a
            symbol or an id appears more than once.
          ValueError:
            If the second field of a line is not an integer.
        '''
        id2sym: Dict[int, str] = dict()
        sym2id: Dict[str, int] = dict()

        for line in s.split('\n'):
            fields = line.split()
            if len(fields) == 0:
                continue  # skip empty lines
            assert len(fields) == 2, \
                    f'Expect a line with 2 fields. Given: {len(fields)}'
            sym, idx = fields[0], int(fields[1])
            assert sym not in sym2id, f'Duplicated symbol {sym}'
            assert idx not in id2sym, f'Duplicated id {idx}'
            id2sym[idx] = sym
            sym2id[sym] = idx

        return SymbolTable(_id2sym=id2sym, _sym2id=sym2id)

    @staticmethod
    def from_file(filename: str) -> 'SymbolTable':
        '''Build a symbol table from file.

        Every line in the symbol table file has two fields separated by
        space(s), tab(s) or both. The following is an example file:

        .. code-block::

            <eps> 0
            a 1
            b 2
            c 3

        Args:
          filename:
            Name of the symbol table file. Its format is documented above.

        Returns:
          An instance of :class:`SymbolTable`.

        '''
        with open(filename, 'r') as f:
            return SymbolTable.from_str(f.read().strip())

    def get(self, k: Union[int, str]) -> Union[str, int]:
        '''Get a symbol for an id or get an id for a symbol

        Args:
          k:
            If it is an id, it tries to find the symbol corresponding
            to the id; if it is a symbol, it tries to find the id
            corresponding to the symbol.

        Returns:
          An id or a symbol depending on the given ``k``.

        Raises:
          KeyError:
            If ``k`` is not present in the table.
          ValueError:
            If ``k`` is neither an ``int`` nor a ``str``.
        '''
        if isinstance(k, int):
            return self._id2sym[k]
        elif isinstance(k, str):
            # A dict has to be indexed; calling it raises TypeError.
            return self._sym2id[k]
        else:
            raise ValueError(f'Unsupported type {type(k)}.')

k2/python/tests/fsa_test.py

+43-12
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def test_acceptor_from_str(self):
2525
s = '''
2626
0 1 2 -1.2
2727
0 2 10 -2.2
28-
1 3 3 -3.2
29-
1 6 -1 -4.2
28+
1 6 -1 -3.2
29+
1 3 3 -4.2
3030
2 6 -1 -5.2
3131
2 4 2 -6.2
3232
3 6 -1 -7.2
@@ -39,8 +39,8 @@ def test_acceptor_from_str(self):
3939
expected_str = '''
4040
0 1 2 -1.2
4141
0 2 10 -2.2
42-
1 3 3 -3.2
43-
1 6 -1 -4.2
42+
1 6 -1 -3.2
43+
1 3 3 -4.2
4444
2 6 -1 -5.2
4545
2 4 2 -6.2
4646
3 6 -1 -7.2
@@ -53,16 +53,16 @@ def test_acceptor_from_str(self):
5353
expected_str = '''
5454
0 1 2 1.2
5555
0 2 10 2.2
56-
1 3 3 3.2
57-
1 6 -1 4.2
56+
1 6 -1 3.2
57+
1 3 3 4.2
5858
2 6 -1 5.2
5959
2 4 2 6.2
6060
3 6 -1 7.2
6161
5 0 1 8.2
6262
6
6363
'''
6464
assert _remove_leading_spaces(expected_str) == _remove_leading_spaces(
65-
fsa.to_str(negate_scores=True))
65+
fsa.to_str(openfst=True))
6666

6767
arcs = fsa.arcs
6868
assert isinstance(arcs, torch.Tensor)
@@ -94,8 +94,8 @@ def test_transducer_from_str(self):
9494
s = '''
9595
0 1 2 22 -1.2
9696
0 2 10 100 -2.2
97-
1 3 3 33 -3.2
9897
1 6 -1 16 -4.2
98+
1 3 3 33 -3.2
9999
2 6 -1 26 -5.2
100100
2 4 2 22 -6.2
101101
3 6 -1 36 -7.2
@@ -107,13 +107,13 @@ def test_transducer_from_str(self):
107107
assert fsa.aux_labels.device.type == 'cpu'
108108
assert torch.allclose(
109109
fsa.aux_labels,
110-
torch.tensor([22, 100, 33, 16, 26, 22, 36, 50], dtype=torch.int32))
110+
torch.tensor([22, 100, 16, 33, 26, 22, 36, 50], dtype=torch.int32))
111111

112112
expected_str = '''
113113
0 1 2 22 -1.2
114114
0 2 10 100 -2.2
115-
1 3 3 33 -3.2
116115
1 6 -1 16 -4.2
116+
1 3 3 33 -3.2
117117
2 6 -1 26 -5.2
118118
2 4 2 22 -6.2
119119
3 6 -1 36 -7.2
@@ -126,16 +126,47 @@ def test_transducer_from_str(self):
126126
expected_str = '''
127127
0 1 2 22 1.2
128128
0 2 10 100 2.2
129-
1 3 3 33 3.2
130129
1 6 -1 16 4.2
130+
1 3 3 33 3.2
131131
2 6 -1 26 5.2
132132
2 4 2 22 6.2
133133
3 6 -1 36 7.2
134134
5 0 1 50 8.2
135135
6
136136
'''
137137
assert _remove_leading_spaces(expected_str) == _remove_leading_spaces(
138-
fsa.to_str(negate_scores=True))
138+
fsa.to_str(openfst=True))
139+
140+
def test_symbol_table_and_dot(self):
    # Symbol table for the labels of the arcs.
    isym = k2.SymbolTable.from_str('''
<eps> 0
a 1
b 2
c 3
''')

    # Symbol table for the auxiliary labels of the arcs.
    osym = k2.SymbolTable.from_str('''
<eps> 0
x 1
y 2
z 3
''')

    # A small transducer; the last line names the final state.
    rules = '''
0 1 1 1 0.5
0 1 2 2 1.5
1 2 3 3 2.5
2 3 -1 0 3.5
3
'''
    fsa = k2.Fsa(_remove_leading_spaces(rules))
    fsa.set_isymbol(isym)
    fsa.set_osymbol(osym)

    # the fsa is saved to /tmp/fsa.pdf
    fsa.to_dot().render('/tmp/fsa', format='pdf')
139170

140171

141172
if __name__ == '__main__':

0 commit comments

Comments
 (0)