@@ -147,6 +147,96 @@ void Intersect(FsaOrVec &a_fsas, FsaOrVec &b_fsas, FsaVec *out,
147
147
*out = creator.GetFsaVec ();
148
148
}
149
149
150
// Builds a single linear FSA (one chain of states) from a symbol sequence.
//
//   @param [in] symbols  Sequence of n arc labels.  It must not contain -1,
//                        which is reserved as the label of the final arc
//                        (kFinalSymbol).
//   @return  An Fsa with n+2 states and n+1 arcs.  Arc i goes from state i to
//            state i+1 with score 0; arcs 0..n-1 carry symbols[i] and the
//            last arc (state n -> state n+1) carries -1.
Fsa LinearFsa(Array1<int32_t> &symbols) {
  ContextPtr c = symbols.Context();
  int32_t n = symbols.Dim(),
          num_states = n + 2,
          num_arcs = n + 1;
  // State i owns exactly arc i, so both index arrays start out as the
  // identity range (assumes Range(c, dim, 0) yields 0, 1, ..., dim-1).  The
  // final state has no leaving arc, so the last element of row_splits1 is
  // corrected to num_arcs inside the lambda below.
  Array1<int32_t> row_splits1 = Range(c, num_states + 1, 0),
                  row_ids1 = Range(c, num_arcs, 0);
  int32_t *row_splits1_data = row_splits1.Data();
  Array1<Arc> arcs(c, num_arcs);
  Arc *arcs_data = arcs.Data();
  const int32_t *symbols_data = symbols.Data();
  auto lambda_set_arcs = [=] __host__ __device__(int32_t arc_idx01) -> void {
    int32_t src_state = arc_idx01,
            dest_state = arc_idx01 + 1,
            symbol = -1;  // -1 == kFinalSymbol, used for the last arc only.
    if (arc_idx01 < n) {
      // Index by the arc, not by n: symbols_data[n] is past the end.
      symbol = symbols_data[arc_idx01];
      // Only input symbols are validated; the final arc is -1 on purpose.
      K2_CHECK_NE(symbol, -1);
    }
    float score = 0.0f;
    arcs_data[arc_idx01] = Arc(src_state, dest_state, symbol, score);
    // num_arcs >= 1 always, so index 0 is always visited exactly once.
    if (arc_idx01 == 0) row_splits1_data[num_states] = num_arcs;
  };
  Eval(c, num_arcs, lambda_set_arcs);
  return Ragged<Arc>(RaggedShape2(&row_splits1, &row_ids1, num_arcs), arcs);
}
173
+
174
+
175
// Builds a vector of linear FSAs, one per sub-list of `symbols`.
//
//   @param [in] symbols  Ragged array with exactly 2 axes; sub-list i holds
//                        the arc labels of the i'th FSA.
//   @return  A 3-axis Ragged<Arc> indexed [fsa][state][arc].  An FSA built
//            from n symbols has n+2 states and n+1 arcs; every arc has score
//            0 and the last arc of each FSA has label -1 (kFinalSymbol).
//
// NOTE(review): when symbols.Dim0() == 0 the lambda never runs, so the single
// element of row_splits2 is never written -- confirm whether callers can pass
// an empty input.
Fsa LinearFsas(Ragged<int32_t> &symbols) {
  K2_CHECK(symbols.NumAxes() == 2);
  ContextPtr c = symbols.Context();

  // if there are n symbols, there are n+2 states and n+1 arcs.
  RaggedShape states_shape = ChangeSublistSize(symbols.shape, 2);

  int32_t num_states = states_shape.NumElements(),
          num_arcs = symbols.NumElements() + symbols.Dim0();

  // row_splits2 maps from state_idx01 to arc_idx012; row_ids2 does the
  // reverse.  We'll set them in the lambda below.  row_splits2 needs exactly
  // num_states + 1 elements: the lambda writes indexes 0..num_states and the
  // last one must hold num_arcs.
  Array1<int32_t> row_splits2(c, num_states + 1),
                  row_ids2(c, num_arcs);

  int32_t *row_ids2_data = row_ids2.Data(),
          *row_splits2_data = row_splits2.Data();
  const int32_t *row_ids1_data = states_shape.RowIds(1).Data(),
                *row_splits1_data = states_shape.RowSplits(1).Data(),
                *symbols_data = symbols.values.Data();
  Array1<Arc> arcs(c, num_arcs);
  Arc *arcs_data = arcs.Data();
  auto lambda = [=] __host__ __device__(int32_t state_idx01) -> void {
    int32_t fsa_idx0 = row_ids1_data[state_idx01],
            state_idx0x = row_splits1_data[fsa_idx0],
            next_state_idx0x = row_splits1_data[fsa_idx0 + 1],
            idx1 = state_idx01 - state_idx0x;

    // the following works because each FSA has one fewer arcs than states.
    int32_t arc_idx0xx = state_idx0x - fsa_idx0,
            next_arc_idx0xx = next_state_idx0x - (fsa_idx0 + 1),
            // the following may look a bit wrong.. here, the idx1 is the same
            // as the idx12 if the arc exists, because each state has one arc
            // leaving it (except the last state).
            arc_idx012 = arc_idx0xx + idx1;
    // the following works because each FSA has one fewer symbols than arcs
    // (however it doesn't work for the last arc of each FSA; we check below.)
    int32_t symbol_idx01 = arc_idx012 - fsa_idx0;
    if (arc_idx012 < next_arc_idx0xx) {
      int32_t src_state = idx1,
              dest_state = idx1 + 1,
              symbol = (arc_idx012 + 1 < next_arc_idx0xx
                            ? symbols_data[symbol_idx01]
                            : -1);  // kFinalSymbol
      float score = 0.0f;
      arcs_data[arc_idx012] = Arc(src_state, dest_state, symbol, score);
      row_ids2_data[arc_idx012] = state_idx01;
    } else {
      // This is the last state of its FSA, which has no leaving arc.  The
      // following ensures that the last element of row_splits2 (i.e.
      // row_splits2[num_states]) is set to num_arcs.  For the last state of
      // every FSA but the last one it also writes an entry that the next
      // state's thread writes too; both threads store the same value to the
      // same location.  Note that there is no arc with index `arc_idx012` if
      // you reach here.
      row_splits2_data[state_idx01 + 1] = arc_idx012;
    }
    row_splits2_data[state_idx01] = arc_idx012;
  };
  Eval(c, num_states, lambda);

  // The second level of the shape needs both row_splits2 and row_ids2.
  return Ragged<Arc>(RaggedShape3(&states_shape.RowSplits(1),
                                  &states_shape.RowIds(1), num_states,
                                  &row_splits2, &row_ids2, num_arcs),
                     arcs);
}
239
+
150
240
namespace {
151
241
struct ArcComparer {
152
242
__host__ __device__ __forceinline__ bool operator ()(const Arc &lhs,
0 commit comments