Skip to content

Commit 471da78

Browse files
authored
Add ChangeSublistSize() and LinearFsas() (k2-fsa#265)
1 parent b32c893 commit 471da78

File tree

5 files changed

+218
-1
lines changed

5 files changed

+218
-1
lines changed

k2/csrc/context.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ class ParallelRunner {
544544
// destructor of `this` You can pass this into the Eval() and Eval2()
545545
// functions, or invoke kernels directly with it; but if you want it
546546
// to be used in called functions you should do something like
547-
// With(pr.NewStream) w;
547+
// With w(pr.NewStream());
548548
// with that object alive in the scope where you want the stream to be
549549
// used.
550550
//

k2/csrc/fsa_algo.cu

+90
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,96 @@ void Intersect(FsaOrVec &a_fsas, FsaOrVec &b_fsas, FsaVec *out,
147147
*out = creator.GetFsaVec();
148148
}
149149

150+
/*
  Create a linear FSA that accepts exactly the symbol sequence `symbols`,
  with zero score on every arc.

  If `symbols.Dim() == n` the result has n+2 states and n+1 arcs: arc i goes
  from state i to state i+1; arcs 0..n-1 carry symbols[i] and arc n is the
  final-arc carrying kFinalSymbol (-1).  The last state has no leaving arc.

    @param [in] symbols  Input symbol sequence; must not contain -1
                         (checked per element in the kernel).
    @return  The linear FSA, on the same context as `symbols`.
 */
Fsa LinearFsa(Array1<int32_t> &symbols) {
  ContextPtr c = symbols.Context();
  int32_t n = symbols.Dim(),
      num_states = n + 2,
      num_arcs = n + 1;
  // State i (for i <= n) has exactly one leaving arc, namely arc i, so an
  // arithmetic sequence starting from 0 is right for row_ids1 and for all but
  // the last element of row_splits1 (presumably what Range(c, dim, 0)
  // produces -- confirm against its declaration).  The last element of
  // row_splits1 is patched in the kernel below.
  Array1<int32_t> row_splits1 = Range(c, num_states + 1, 0),
      row_ids1 = Range(c, num_arcs, 0);
  int32_t *row_splits1_data = row_splits1.Data();
  Array1<Arc> arcs(c, num_arcs);
  Arc *arcs_data = arcs.Data();
  const int32_t *symbols_data = symbols.Data();
  auto lambda_set_arcs = [=] __host__ __device__ (int32_t arc_idx01) -> void {
    int32_t src_state = arc_idx01,
        dest_state = arc_idx01 + 1,
        symbol = -1;  // -1 == kFinalSymbol, used for the final-arc only.
    if (arc_idx01 < n) {
      // Index by the arc, not by `n` (which would be one past the end).
      symbol = symbols_data[arc_idx01];
      // Only user-supplied symbols are forbidden from being -1; the
      // final-arc legitimately carries it.
      K2_CHECK_NE(symbol, -1);
    }
    float score = 0.0;
    arcs_data[arc_idx01] = Arc(src_state, dest_state, symbol, score);
    // The final state has no leaving arcs, so the total number of arcs is
    // num_arcs == num_states - 1, not num_states.  num_arcs >= 1 always, so
    // arc 0 always exists and this write always happens.
    if (arc_idx01 == 0) row_splits1_data[num_states] = num_arcs;
  };
  Eval(c, num_arcs, lambda_set_arcs);
  return Ragged<Arc>(RaggedShape2(&row_splits1, &row_ids1, num_arcs),
                     arcs);
}
173+
174+
175+
/*
  Create a vector of linear FSAs, one per row of `symbols`.

  For a row with n symbols the corresponding FSA has n+2 states and n+1 arcs
  (n symbol arcs followed by one final-arc with symbol -1 == kFinalSymbol);
  all scores are zero.

    @param [in] symbols  Ragged array with exactly 2 axes; row i is the
                         symbol sequence of the i'th FSA.
    @return  Ragged<Arc> with 3 axes (fsa, state, arc), on the same context
             as `symbols`.
 */
Fsa LinearFsas(Ragged<int32_t> &symbols) {
  K2_CHECK(symbols.NumAxes() == 2);
  ContextPtr c = symbols.Context();

  // if there are n symbols, there are n+2 states and n+1 arcs.
  RaggedShape states_shape = ChangeSublistSize(symbols.shape, 2);

  int32_t num_states = states_shape.NumElements(),
      num_arcs = symbols.NumElements() + symbols.Dim0();

  // row_splits2 maps from state_idx01 to arc_idx012; row_ids2 does the reverse.
  // We'll set them in the lambda below.  A row_splits array over num_states
  // rows has exactly num_states + 1 elements.
  // NOTE(review): if symbols.Dim0() == 0 the kernel below runs zero times and
  // row_splits2[0] is never written -- confirm whether callers can pass an
  // empty `symbols`.
  Array1<int32_t> row_splits2(c, num_states + 1),
      row_ids2(c, num_arcs);

  int32_t *row_ids2_data = row_ids2.Data(),
      *row_splits2_data = row_splits2.Data();
  const int32_t *row_ids1_data = states_shape.RowIds(1).Data(),
      *row_splits1_data = states_shape.RowSplits(1).Data(),
      *symbols_data = symbols.values.Data();
  Array1<Arc> arcs(c, num_arcs);
  Arc *arcs_data = arcs.Data();
  auto lambda = [=] __host__ __device__ (int32_t state_idx01) -> void {
    int32_t fsa_idx0 = row_ids1_data[state_idx01],
        state_idx0x = row_splits1_data[fsa_idx0],
        next_state_idx0x = row_splits1_data[fsa_idx0 + 1],
        idx1 = state_idx01 - state_idx0x;

    // the following works because each FSA has one fewer arcs than states.
    int32_t arc_idx0xx = state_idx0x - fsa_idx0,
        next_arc_idx0xx = next_state_idx0x - (fsa_idx0 + 1),
        // the following may look a bit wrong.. here, the idx1 is the same as
        // the idx12 if the arc exists, because each state has one arc leaving
        // it (except the last state).
        arc_idx012 = arc_idx0xx + idx1;
    // the following works because each FSA has one fewer symbols than arcs
    // (however it doesn't work for the last arc of each FSA; we check below.)
    int32_t symbol_idx01 = arc_idx012 - fsa_idx0;
    if (arc_idx012 < next_arc_idx0xx) {
      int32_t src_state = idx1,
          dest_state = idx1 + 1,
          symbol = (arc_idx012 + 1 < next_arc_idx0xx ?
                    symbols_data[symbol_idx01] : -1);  // kFinalSymbol
      float score = 0.0;
      arcs_data[arc_idx012] = Arc(src_state, dest_state, symbol, score);
      row_ids2_data[arc_idx012] = state_idx01;
    } else {
      // We are the last state of this FSA, which has no leaving arc; here
      // arc_idx012 == next_arc_idx0xx, i.e. there is no arc with index
      // `arc_idx012` belonging to this FSA.
      // The following ensures that the last element of row_splits2
      // (i.e. row_splits2[num_states]) is set to num_arcs.  It also writes
      // something unnecessary for the last state of each FSA but the last
      // one, which will cause 2 threads to write the same value to the same
      // location.
      row_splits2_data[state_idx01 + 1] = arc_idx012;
    }
    row_splits2_data[state_idx01] = arc_idx012;
  };
  Eval(c, num_states, lambda);

  // Pass row_ids2 (not row_splits2 a second time) as the row-ids of axis 2.
  return Ragged<Arc>(RaggedShape3(&states_shape.RowSplits(1),
                                  &states_shape.RowIds(1), num_states,
                                  &row_splits2, &row_ids2, num_arcs),
                     arcs);
}
239+
150240
namespace {
151241
struct ArcComparer {
152242
__host__ __device__ __forceinline__ bool operator()(const Arc &lhs,

k2/csrc/fsa_algo.h

+29
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,35 @@ void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, float beam,
131131
void Intersect(FsaOrVec &a_fsas, FsaOrVec &b_fsas, FsaVec *out,
132132
Array1<int32_t> *arc_map_a, Array1<int32_t> *arc_map_b);
133133

134+
/*
135+
Create a linear FSA from a sequence of symbols
136+
137+
@param [in] symbols Input symbol sequence (must not contain
138+
kFinalSymbol == -1).
139+
140+
@return Returns an FSA that accepts only this symbol
141+
sequence, with zero score. Note: if
142+
`symbols.Dim() == n`, the returned FSA
143+
will have n+1 arcs (including the final-arc) and
144+
n+2 states.
145+
*/
146+
Fsa LinearFsa(Array1<int32_t> &symbols);
147+
148+
/*
149+
Create an FsaVec containing linear FSAs, given a list of sequences of
150+
symbols
151+
152+
@param [in] symbols Input symbol sequences (must not contain
153+
kFinalSymbol == -1).
154+
155+
@return Returns an FsaVec with `ans.Dim0() == symbols.Dim0()`. Note: if
156+
the i'th row of `symbols` has n elements, the i'th returned FSA
157+
will have n+1 arcs (including the final-arc) and n+2 states.
158+
*/
159+
Fsa LinearFsas(Ragged<int32_t> &symbols);
160+
161+
162+
134163
} // namespace k2
135164

136165
#endif // K2_CSRC_FSA_ALGO_H_

k2/csrc/ragged_ops.cu

+72
Original file line numberDiff line numberDiff line change
@@ -939,4 +939,76 @@ Array1<int32_t> GetTransposeReordering(Ragged<int32_t> &src, int32_t num_cols) {
939939
return ans;
940940
}
941941

942+
/*
  Return a copy of `src` in which every sub-list on the last axis has its
  size changed by `size_delta` (which may be negative).  All axes other than
  the last are copied from `src` unchanged; the row_splits/row_ids of the
  last axis are newly allocated.

    @param [in] src         Source shape; must have NumAxes() >= 2.
    @param [in] size_delta  Amount added to the size of each sub-list on the
                            last axis.
    @return  The modified shape, with the same number of axes as `src`.
 */
RaggedShape ChangeSublistSize(RaggedShape &src, int32_t size_delta) {
  K2_CHECK(src.NumAxes() >= 2);
  // the result will have the same num-axes as `src` (the NumAxes() of the
  // object is not the same as the number of RaggedShapeDim axes).
  std::vector<RaggedShapeDim> ans_axes(src.NumAxes() - 1);
  int32_t last_axis = src.NumAxes() - 1;
  // The following will only do something if src.NumAxes() > 2.
  for (int32_t i = 0; i + 1 < last_axis; i++)
    ans_axes[i] = src.Axes()[i];

  ContextPtr c = src.Context();
  int32_t num_rows = src.TotSize(last_axis - 1),
      src_num_elems = src.TotSize(last_axis),
      num_elems = src_num_elems + size_delta * num_rows;
  // The axis we rebuild is the *last* one, i.e. ans_axes.back(); indexing
  // with [0] would only be correct when src.NumAxes() == 2.
  RaggedShapeDim &ans_last = ans_axes.back();
  ans_last.row_splits = Array1<int32_t>(c, num_rows + 1);
  ans_last.row_ids = Array1<int32_t>(c, num_elems);
  ans_last.cached_tot_size = num_elems;
  const int32_t *src_row_splits_data = src.RowSplits(last_axis).Data(),
      *src_row_ids_data = src.RowIds(last_axis).Data();
  int32_t *row_splits_data = ans_last.row_splits.Data(),
      *row_ids_data = ans_last.row_ids.Data();

  {
    ParallelRunner pr(c);
    {
      With w(pr.NewStream());
      // New row i starts where the old one did, shifted by the total size
      // change of the i rows preceding it.
      auto lambda_set_row_splits = [=] __host__ __device__ (int32_t idx0) -> void {
        row_splits_data[idx0] = src_row_splits_data[idx0] + size_delta * idx0;
      };
      Eval(c, num_rows + 1, lambda_set_row_splits);
    }

    {
      With w(pr.NewStream());
      // Sets the row-ids of the elements that correspond to elements of
      // `src`.  This runs over the elements of *src*, hence src_num_elems.
      auto lambda_set_row_ids1 = [=] __host__ __device__ (int32_t src_idx01) -> void {
        int32_t src_idx0 = src_row_ids_data[src_idx01],
            src_idx0x = src_row_splits_data[src_idx0],
            src_idx0x_next = src_row_splits_data[src_idx0 + 1],
            src_idx1 = src_idx01 - src_idx0x,
            // Same formula as in lambda_set_row_splits: we recompute the new
            // row_splits from the source rather than reading row_splits_data,
            // since that is being written on a different stream and we do
            // not wait for that kernel to finish.
            new_idx0x = src_idx0x + size_delta * src_idx0,
            new_idx0x_next = src_idx0x_next + size_delta * (src_idx0 + 1),
            new_idx01 = new_idx0x + src_idx1;
        // it's only necessary to guard the next statement with an 'if' because
        // size_delta might be negative.
        if (new_idx01 < new_idx0x_next)
          row_ids_data[new_idx01] = src_idx0;
      };
      Eval(c, src_num_elems, lambda_set_row_ids1);
    }
    if (size_delta > 0) {
      // This sets the row-ids that are not set by lambda_set_row_ids1, i.e.
      // the last `size_delta` elements of each new row.
      With w(pr.NewStream());
      auto lambda_set_row_ids2 = [=] __host__ __device__ (int32_t i) -> void {
        int32_t idx0 = i / size_delta, n = i % size_delta,
            next_idx0 = idx0 + 1;
        // The following formula is the same as the one in
        // lambda_set_row_splits; we want to compute the new value of
        // row_splits_data[next_idx0] without waiting for that kernel to
        // terminate.
        int32_t next_idx0x = src_row_splits_data[next_idx0] +
            size_delta * next_idx0;
        row_ids_data[next_idx0x - 1 - n] = idx0;
      };
      Eval(c, num_rows * size_delta, lambda_set_row_ids2);
    }
    // make the ParallelRunner go out of scope (should do this before any
    // validation code that gets invoked by the constructor of RaggedShape
    // below).
  }
  return RaggedShape(ans_axes);
}
1012+
1013+
9421014
} // namespace k2

k2/csrc/ragged_ops.h

+26
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,32 @@ void SortSublists(Ragged<T> &src, Array1<int32_t> *order);
139139
*/
140140
RaggedShape Stack(int32_t axis, int32_t src_size, RaggedShape **src);
141141

142+
143+
/*
144+
Return a modified version of `src` in which all sub-lists on the last axis of
145+
the tensor have size modified by `size_delta`. `size_delta` may have either
146+
sign. If for a sub-list of size `cur_size`, `cur_size + size_delta < 0`, that
147+
sub-list's size will be changed to 0 but the sub-list will be kept.
148+
149+
150+
@param [in] src Source tensor; must have NumAxes() >= 2, i.e. be valid.
151+
Only the last axis, i.e. the last RowSplits/RowIds(),
152+
will be affected by this.
153+
@param [in] size_delta Amount by which to change the size of sub-lists.
154+
May be either sign; if negative, we'll reduce the
155+
sub-list size by this amount, possibly leaving empty
156+
sub-lists (but it's an error if this would reduce any sub-list
157+
size below zero).
158+
@return Returns the modified RaggedShape. The RowSplits()
159+
and RowIds() of its last axis will not be shared
160+
with `src`.
161+
162+
Example: ChangeSublistSize( [ [ x x ] [ x x x ] ], 1) returns
163+
[ [ x x x ] [ x x x x ] ]
164+
(using the x as placeholders for the values since these are unknown).
165+
*/
166+
RaggedShape ChangeSublistSize(RaggedShape &src, int32_t size_delta);
167+
142168
/*
143169
Insert a new axis at position `axis`, with 0 <= axis <= src.NumAxes(), for
144170
which the only allowed index will be 0 (which is another way of saying: all

0 commit comments

Comments
 (0)