Skip to content

Commit 071b0b3

Browse files
qindazhudanpovey
andauthored
Introduce Array to template-ize current code
* Added interface for auxiliary labels * Add notes on Python interface * Fix typos * Fix conflicts, remove some typedefs * Fixes from review * Small fixes in determinization code * Progress on determinization code; add new declarations of un-pruned functions * Fix compile error * Resolve conflicts * Add LogAdd * Fix compile errors in util.h * More progress on determinization code * More progress on determinizaton draft. * More work on Determinize code. * Draft of ConstFsa interface and CfsaVec * Small fixes to ConstFsa interface * Changes from review * Fix style issues * Add itf for DenseFsa (not compiled) * Draft of Array2/Array3 * Some notes on how this would work in Python * Fix conflict * [src] More drafts in array stuff, RE interface of functions. * Further changes * merge Dan's PR about Array2 Co-authored-by: Daniel Povey <[email protected]>
1 parent 5565bd4 commit 071b0b3

File tree

5 files changed

+255
-52
lines changed

5 files changed

+255
-52
lines changed

k2/csrc/array.h

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// k2/csrc/array.h
2+
3+
// Copyright (c) 2020 Xiaomi Corporation (author: Daniel Povey)
4+
5+
// See ../../LICENSE for clarification regarding multiple authors
6+
7+
#ifndef K2_CSRC_ARRAY_H_
8+
#define K2_CSRC_ARRAY_H_
9+
10+
#include <functional>
11+
#include <limits>
12+
#include <memory>
13+
#include <vector>
14+
15+
namespace k2 {
16+
17+
/*
18+
We will use e.g. StridedPtr<int32_t, T> when the stride is not 1, and
19+
otherwise just T* (which will presumably be faster).
20+
*/
21+
template <typename I, typename T>
struct StridedPtr {
  // Pointer-like view over an array in which logical element i lives at
  // data[i * stride].  I is the index/stride type, T the element type.
  // This struct has no destructor and never frees `data`.
  T *data;   // address of logical element 0
  I stride;  // distance between consecutive logical elements, counted in
             // elements of T (not bytes)

  // Returns a reference to logical element i.  The stride must be applied
  // here; indexing `data` directly would make the view behave as if the
  // stride were 1.  Const-qualified because, as with a raw `T *const`,
  // constness of the view does not imply constness of the elements.
  T &operator[](I i) const { return data[i * stride]; }

  StridedPtr(T *data, I stride) : data(data), stride(stride) {}
};
28+
29+
/* MIGHT NOT NEED THIS */
30+
template <typename I, typename Ptr>
31+
struct Array1 {
32+
// One dimensional array of something, like vector<X>, given as a range
33+
// [begin, end) into `data`, where Ptr is, or behaves like, X*.
34+
using IndexT = I;
35+
using PtrT = Ptr;
36+
37+
// 'begin' and 'end' are the first and one-past-the-last indexes into `data`
38+
// that we are allowed to use, i.e. data[begin] through data[end - 1].
39+
IndexT begin;
40+
IndexT end;
41+
42+
PtrT data;
43+
};
44+
45+
/*
46+
This struct stores the size of an Array2 object; it will generally be used as
47+
an output argument by functions that work out this size.
48+
*/
49+
template <typename I>
50+
struct Array2Size {
51+
using IndexT = I;
52+
// `size1` is the top-level size of the array, equal to the object's .size
53+
// member
54+
I size1;
55+
// `size2` is the number of elements in the array, equal to
56+
// o->indexes[o->size] - o->indexes[0] (if the Array2 object o is
57+
// initialized).
58+
I size2;
59+
};
60+
61+
template <typename I, typename Ptr>
struct Array2 {
  // Irregular two dimensional array of something, like vector<vector<X> >
  // where Ptr is, or behaves like, X*.
  using IndexT = I;
  using PtrT = Ptr;

  // Number of rows (the top-level size).
  IndexT size;

  // Row offsets into `data`: row i occupies data[indexes[i]] through
  // data[indexes[i+1] - 1].  indexes[0,1,...size] should be defined; note,
  // this means the array must have at least size+1 elements.  We require
  // that indexes[i] <= indexes[i+1], but it is not required that
  // indexes[0] == 0, it may be greater than 0.
  const IndexT *indexes;

  // `data` might be an actual pointer, or might be some object supporting
  // operator [].  data[indexes[0]] through data[indexes[size] - 1] must be
  // accessible through this object.
  PtrT data;

  /* initialized definition:

        An Array2 object is initialized if its `size` member is set and its
        `indexes` and `data` pointer allocated, and the values of its `indexes`
        array are set for indexes[0] and indexes[size].
  */
};

template <typename I, typename Ptr>
struct Array3 {
  // Irregular three dimensional array of something, like
  // vector<vector<vector<X> > >, where Ptr is, or behaves like, X*.
  using IndexT = I;
  using PtrT = Ptr;

  // Number of sub-arrays (the top-level size).
  IndexT size;

  // Offsets into `indexes2`: sub-array i uses the row-offset entries
  // indexes2[indexes1[i]] through indexes2[indexes1[i+1]].
  // indexes1[0,1,...size] should be defined; note, this means the array must
  // have at least size+1 elements.  We require that
  // indexes1[i] <= indexes1[i+1], but it is not required that
  // indexes1[0] == 0, it may be greater than 0.
  const IndexT *indexes1;

  // Row offsets into `data`.  indexes2[indexes1[0]] through
  // indexes2[indexes1[size]] (inclusive) should be defined: each sub-array is
  // an Array2, whose row-offset table needs one entry more than it has rows.
  const IndexT *indexes2;

  // `data` might be an actual pointer, or might be some object supporting
  // operator [].  data[indexes2[indexes1[0]]] through
  // data[indexes2[indexes1[size]] - 1] must be accessible through this
  // object.
  PtrT data;

  // Returns a view of the i-th sub-array as an Array2.  Nothing is copied:
  // the result shares `indexes2` and `data` with this object, so it is valid
  // only while they are.  `data` is passed through unshifted because the
  // entries of `indexes2` are absolute offsets into it (an Array2 does not
  // require indexes[0] == 0).
  // The caller must ensure 0 <= i < size; this is not checked.
  Array2<I, Ptr> operator[](I i) const {
    Array2<I, Ptr> array;
    array.size = indexes1[i + 1] - indexes1[i];
    array.indexes = indexes2 + indexes1[i];
    array.data = data;
    return array;
  }
};
116+
117+
// Note: we can create Array4 later if we need it.
118+
119+
} // namespace k2
120+
121+
#endif // K2_CSRC_ARRAY_H_

k2/csrc/aux_labels.h

+37
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <vector>
1111

12+
#include "k2/csrc/array.h"
1213
#include "k2/csrc/fsa.h"
1314
#include "k2/csrc/fsa_util.h"
1415
#include "k2/csrc/properties.h"
@@ -46,6 +47,9 @@ struct AuxLabels {
4647
std::vector<int32_t> labels;
4748
};
4849

50+
// TODO(haowen): replace AuxLabels above with below definition
51+
using AuxLabels_ = Array2<int32_t, int32_t>;
52+
4953
// Swap AuxLabels; it's cheap to do this as we are actually doing a shallow swap.
5054
void Swap(AuxLabels *labels1, AuxLabels *labels2);
5155

@@ -95,6 +99,39 @@ void MapAuxLabels2(const AuxLabels &labels_in,
9599
void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out,
96100
AuxLabels *aux_labels_out);
97101

102+
class FstInverter {
103+
/* Constructor. Lightweight. */
104+
FstInverter(const Fsa &fsa_in, const AuxLabels &labels_in);
105+
106+
/*
107+
Do enough work that know now much memory will be needed, and output
108+
that information
109+
@param [out] fsa_size The num-states and num-arcs of the FSA
110+
will be written to here
111+
@param [out] aux_size The number of lists in the AuxLabels
112+
output (==num-arcs) and the number of
113+
elements will be written to here.
114+
*/
115+
void GetSizes(Array2Size<int32_t> *fsa_size, Array2Size<int32_t> *aux_size);
116+
117+
/*
118+
Finish the operation and output inverted FSA to `fsa_out` and
119+
auxiliary labels to `labels_out`.
120+
@param [out] fsa_out The inverted FSA will be written to
121+
here. Must be initialized; search for
122+
'initialized definition' in class Array2
123+
in array.h for meaning.
124+
@param [out] labels_out The auxiliary labels will be written to
125+
here. Must be initialized; search for
126+
'initialized definition' in class Array2
127+
in array.h for meaning.
128+
*/
129+
void GetOutput(Fsa *fsa_out, AuxLabels *labels_out);
130+
131+
private:
132+
// ...
133+
};
134+
98135
} // namespace k2
99136

100137
#endif // K2_CSRC_AUX_LABELS_H_

k2/csrc/fsa.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <vector>
1414

1515
#include "glog/logging.h"
16+
#include "k2/csrc/array.h"
1617
#include "k2/csrc/util.h"
1718

1819
namespace k2 {
@@ -129,6 +130,10 @@ struct Fsa {
129130
}
130131
};
131132

133+
// TODO(haowen): replace Cfsa and CfsaVec with below definitions
134+
using Cfsa_ = Array2<int32_t, Arc>;
135+
using CfsaVec_ = Array3<int32_t, Arc>;
136+
132137
/*
133138
Cfsa is a 'const' FSA, which we'll use as the input to operations. It is
134139
designed in such a way that the storage underlying it may either be an Fsa
@@ -157,7 +162,7 @@ struct Cfsa {
157162
// are valid. CAUTION: arc_indexes[0] may be
158163
// greater than zero.
159164

160-
Arc *arcs; // Note: arcs[BeginArcIndex()] through arcs[EndArcIndex() - 1]
165+
Arc *arcs; // Note: arcs[begin_arc] through arcs[end_arc - 1]
161166
// are valid.
162167

163168
Cfsa();

notes/array.txt

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
3+
4+
5+
6+
# defining type k2.Array
7+
8+
9+
class Array:
10+
11+
# `indexes` is a Tensor with one axis (a 1-D tensor).
12+
Tensor indexes;
13+
14+
# `data` is either:
15+
# - of type Tensor (if this corresponds to Array2 == 2-dimensional
16+
# array in C++)
17+
# - of type Array (if this corresponds to Array3 or higher-dimensional
18+
# array in C++)
19+
# The Python code is structured a bit differently from the C++ code,
20+
# due to the differences in the languages.
21+
# When we dispatch things to C++ code there would be some
22+
# big switch statement or if-statement to select the right
23+
# template instantiation.
24+
data;
25+
26+
def __len__(self):
27+
return indexes.shape[0] - 1
28+
29+
@property
30+
def shape(self):
31+
# e.g. if indexes.shape is (15,) and
32+
# data.shape is (150) -> this.shape would be (15,None)
33+
# If data.shape is (150,4), this.shape would be (15,4)
34+
# If data.shape is (150,None) (since data is an Array), this.shape
35+
# would be (15,None,None).
36+
# The Nones are for dimensions where the shape is not known
37+
# because it is variable.
38+
return (indexes.shape[0] - 1, None, *data.shape[1:])
39+
40+
41+
42+
class Fsa(Array):
43+
44+
# Think of this as a vector of vector of Arc, or in C++,
45+
# an Array2<Arc>.
46+
# An Arc has 3 int32_t's, so this.data is a Tensor with
47+
# dtype int32 and shape (_, 3).
48+
49+
50+
51+
class FsaVec(Array):
52+
53+
# Think of this as a vector of vector of vector of Arc, or in C++,
54+
# an Array3<Arc>.
55+
#
56+
# this.data is an Array, and this.data.data is a Tensor with
57+
# dtype int32 and shape (_, 3).

notes/python.txt

+34-51
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,19 @@
207207
# note: fsas_det still has `phone_syms` and `nnet_arc_indexes`, now as sequences.
208208

209209

210+
class PseudoTensor:
    """This is a class that behaves like a torch.Tensor but in fact only supports one kind of
       operation, namely indexing with another torch.Tensor"""

    def __init__(self, t, divisor):
        """ Constructor.
        Parameters:
             t: torch.LongTensor
                 The underlying values that lookups are served from.
             divisor: int
                 Every incoming index is floor-divided by this before
                 being used to index `t`.
        """
        self.t = t
        self.divisor = divisor

    def __getitem__(self, indexes):
        """Return self.t[indexes // self.divisor].

        Floor division is used so that the result is still a valid integer
        index; `/` is true division in Python 3 and would produce
        non-integer indexes.
        """
        return self.t[indexes // self.divisor]
210223

211224
def DenseFsaVec:
212225

@@ -246,66 +259,36 @@ def DenseFsaVec:
246259
# loglikes, one per arc of the CfsaVec object. This is
247260
# a repeat of `loglikes` but possibly in a different
248261
# order.
262+
263+
249264
pass
250265

251266
@property
252267
def loglikes(self):
253268
return self.arc_loglikes
254269

255270

256-
def seg_frames_for_arcs(self, arc_indexes):
257-
"""
258-
Returns the frame-indexes relative to the start of each segment
259-
for each of a provided list of arc indexes, as a torch.LongTensor.
260-
"""
261-
262-
# Note: self.seg_frame_indexes will be a torch.IntTensor containing
263-
# the frame index for each arc. Later we'll address not being
264-
# able to index with IntTensor but only LongTensor.
265-
return self.seg_frame_indexes[arc_indexes / self.num_symbols]
266-
267-
def seq_frames_for_arcs(self, arc_indexes):
268-
"""
269-
Returns the frame-indexes relative to the start of each sequence
270-
for each of a provided list of arc indexes, as a torch.LongTensor.
271-
272-
Note: if a returned frame-index equals num_frames, then that
273-
frame was a `final-arc` (a special arc going to the final state),
274-
which cannot be used to index the `loglikes` array provided to
275-
the constructor because it's out-of-range.
276-
"""
277-
278-
# Note: self.seq_frame_indexes will be a torch.IntTensor containing
279-
# the frame index for each arc. Later we'll address not being
280-
# able to index with IntTensor but only LongTensor.
281-
return self.seq_frame_indexes[arc_indexes / self.num_symbols]
282-
283-
def segments_for_arcs(self, arc_indexes):
284-
"""
285-
Return the segment-indexes for each of a provided list of arcs,
286-
which tells you which segment it was a part of.
287-
"""
288-
return self.segment_indexes[arc_indexes / self.num_symbols]
289-
290-
def seqs_for_arcs(self, arc_indexes):
291-
"""
292-
Return the segment-indexes for each of a provided list of arcs,
293-
which tells you which sequence it was a part of.
271+
@property
272+
def seg_frames(self):
273+
"""Return something that 'acts' like a tensor, indexed by arc, of
274+
the frame-index relative to the segment start corresponding to that
275+
arc. NOTE: self.frame_loglikes will actually be a sub-Tensor
276+
of the Tensor created at the C++ level as the DenseFsaVecMeta object.
294277
"""
295-
return self.input_seq_indexes[self.segments_for_arcs(arc_indexes)]
296-
297-
298-
299-
300-
301-
302-
# compute posteriors..
303-
first_pass_posts =
304-
305-
306-
278+
return PseudoTensor(self.frame_loglikes, self.num_symbols)
307279

308280

281+
@property
282+
def seq_frames(self):
283+
""" as for seg_frames"""
284+
pass
309285

286+
@property
287+
def seq_ids(self):
288+
""" as for seg_frames"""
289+
pass
310290

311-
nnet_post = log_softmax(nnet_output) # might use this later for something..
291+
@property
292+
def seg_ids(self):
293+
""" as for seg_frames"""
294+
pass

0 commit comments

Comments
 (0)