@@ -149,29 +149,24 @@ void Intersect(FsaOrVec &a_fsas, FsaOrVec &b_fsas, FsaVec *out,
149
149
150
150
Fsa LinearFsa (Array1<int32_t > &symbols) {
151
151
ContextPtr c = symbols.Context ();
152
- int32_t n = symbols.Dim (),
153
- num_states = n + 2 ,
154
- num_arcs = n + 1 ;
152
+ int32_t n = symbols.Dim (), num_states = n + 2 , num_arcs = n + 1 ;
155
153
Array1<int32_t > row_splits1 = Range (c, num_states + 1 , 0 ),
156
- row_ids1 = Range (c, num_arcs, 0 );
154
+ row_ids1 = Range (c, num_arcs, 0 );
157
155
Array1<Arc> arcs (c, num_arcs);
158
156
Arc *arcs_data = arcs.Data ();
159
157
const int32_t *symbols_data = symbols.Data ();
160
- auto lambda_set_arcs = [=] __host__ __device__ (int32_t arc_idx01) -> void {
161
- int32_t src_state = arc_idx01,
162
- dest_state = arc_idx01 + 1 ,
163
- // -1 == kFinalSymbol
164
- symbol = (arc_idx01 < n ? symbols_data[n] : -1 );
158
+ auto lambda_set_arcs = [=] __host__ __device__ (int32_t arc_idx01) -> void {
159
+ int32_t src_state = arc_idx01, dest_state = arc_idx01 + 1 ,
160
+ // -1 == kFinalSymbol
161
+ symbol = (arc_idx01 < n ? symbols_data[n] : -1 );
165
162
K2_CHECK_NE (symbol, -1 );
166
163
float score = 0.0 ;
167
164
arcs_data[arc_idx01] = Arc (src_state, dest_state, symbol, score);
168
165
};
169
166
Eval (c, num_arcs, lambda_set_arcs);
170
- return Ragged<Arc>(RaggedShape2 (&row_splits1, &row_ids1, num_arcs),
171
- arcs);
167
+ return Ragged<Arc>(RaggedShape2 (&row_splits1, &row_ids1, num_arcs), arcs);
172
168
}
173
169
174
-
175
170
Fsa LinearFsas (Ragged<int32_t > &symbols) {
176
171
K2_CHECK (symbols.NumAxes () == 2 );
177
172
ContextPtr c = symbols.Context ();
@@ -180,61 +175,59 @@ Fsa LinearFsas(Ragged<int32_t> &symbols) {
180
175
RaggedShape states_shape = ChangeSublistSize (symbols.shape , 2 );
181
176
182
177
int32_t num_states = states_shape.NumElements (),
183
- num_arcs = symbols.NumElements () + symbols.Dim0 ();
178
+ num_arcs = symbols.NumElements () + symbols.Dim0 ();
184
179
185
180
// row_splits2 maps from state_idx01 to arc_idx012; row_ids2 does the reverse.
186
181
// We'll set them in the lambda below.
187
- Array1<int32_t > row_splits2 (c, num_states + 2 ),
188
- row_ids2 (c, num_arcs);
189
-
182
+ Array1<int32_t > row_splits2 (c, num_states + 2 ), row_ids2 (c, num_arcs);
190
183
191
184
int32_t *row_ids2_data = row_ids2.Data (),
192
- *row_splits2_data = row_splits2.Data ();
185
+ *row_splits2_data = row_splits2.Data ();
193
186
const int32_t *row_ids1_data = states_shape.RowIds (1 ).Data (),
194
- *row_splits1_data = states_shape.RowSplits (1 ).Data (),
195
- *symbols_data = symbols.values .Data ();
187
+ *row_splits1_data = states_shape.RowSplits (1 ).Data (),
188
+ *symbols_data = symbols.values .Data ();
196
189
Array1<Arc> arcs (c, num_arcs);
197
190
Arc *arcs_data = arcs.Data ();
198
- auto lambda = [=] __host__ __device__ (int32_t state_idx01) -> void {
191
+ auto lambda = [=] __host__ __device__ (int32_t state_idx01) -> void {
199
192
int32_t fsa_idx0 = row_ids1_data[state_idx01],
200
- state_idx0x = row_splits1_data[fsa_idx0],
201
- next_state_idx0x = row_splits1_data[fsa_idx0 + 1 ],
202
- idx1 = state_idx01 - state_idx0x;
193
+ state_idx0x = row_splits1_data[fsa_idx0],
194
+ next_state_idx0x = row_splits1_data[fsa_idx0 + 1 ],
195
+ idx1 = state_idx01 - state_idx0x;
203
196
204
197
// the following works because each FSA has one fewer arcs than states.
205
198
int32_t arc_idx0xx = state_idx0x - fsa_idx0,
206
- next_arc_idx0xx = next_state_idx0x - (fsa_idx0 + 1 ),
207
- // the following may look a bit wrong.. here, the idx1 is the same as
208
- // the idx12 if the arc exists, because each state has one arc leaving
209
- // it (except the last state).
210
- arc_idx012 = arc_idx0xx + idx1;
199
+ next_arc_idx0xx = next_state_idx0x - (fsa_idx0 + 1 ),
200
+ // the following may look a bit wrong.. here, the idx1 is the same
201
+ // as the idx12 if the arc exists, because each state has one arc
202
+ // leaving it (except the last state).
203
+ arc_idx012 = arc_idx0xx + idx1;
211
204
// the following works because each FSA has one fewer symbols than arcs
212
205
// (however it doesn't work for the last arc of each FSA; we check below.)
213
206
int32_t symbol_idx01 = arc_idx012 - fsa_idx0;
214
207
if (arc_idx012 < next_arc_idx0xx) {
215
- int32_t src_state = idx1,
216
- dest_state = idx1 + 1 ,
217
- symbol = (arc_idx012 + 1 < next_arc_idx0xx ?
218
- symbols_data[symbol_idx01] : -1 ); // kFinalSymbol
208
+ int32_t src_state = idx1, dest_state = idx1 + 1 ,
209
+ symbol =
210
+ (arc_idx012 + 1 < next_arc_idx0xx ? symbols_data[symbol_idx01]
211
+ : -1 ); // kFinalSymbol
219
212
float score = 0.0 ;
220
213
arcs_data[arc_idx012] = Arc (src_state, dest_state, symbol, score);
221
214
row_ids2_data[arc_idx012] = state_idx01;
222
215
} else {
223
216
// The following ensures that the last element of row_splits1_data
224
- // (i.e. row_splits1[num_states]) is set to num_arcs. It also writes something
225
- // unnecessary for the last state of each FSA but the last one, which will
226
- // cause 2 threads to write the same item to the same location.
217
+ // (i.e. row_splits1[num_states]) is set to num_arcs. It also writes
218
+ // something unnecessary for the last state of each FSA but the last one,
219
+ // which will cause 2 threads to write the same item to the same location.
227
220
// Note that there is no arc with index `arc_idx01`, if you reach here.
228
- row_splits2_data[state_idx01+ 1 ] = arc_idx012;
221
+ row_splits2_data[state_idx01 + 1 ] = arc_idx012;
229
222
}
230
223
row_splits2_data[state_idx01] = arc_idx012;
231
224
};
232
225
Eval (c, num_states, lambda);
233
226
234
- return Ragged<Arc>(RaggedShape3 (&states_shape. RowSplits ( 1 ),
235
- &states_shape.RowIds (1 ), num_states ,
236
- &row_splits2, &row_splits2, num_arcs),
237
- arcs);
227
+ return Ragged<Arc>(
228
+ RaggedShape3 ( &states_shape.RowSplits (1 ), &states_shape. RowIds ( 1 ) ,
229
+ num_states, &row_splits2, &row_splits2, num_arcs),
230
+ arcs);
238
231
}
239
232
240
233
namespace {
0 commit comments