Skip to content

Commit 5e7fce9

Browse files
authored
Fixed bug in GetStatesBatch (complex branch) (k2-fsa#266)
1 parent 471da78 commit 5e7fce9

File tree

3 files changed

+64
-70
lines changed

3 files changed

+64
-70
lines changed

k2/csrc/fsa_algo.cu

+33-40
Original file line numberDiff line numberDiff line change
@@ -149,29 +149,24 @@ void Intersect(FsaOrVec &a_fsas, FsaOrVec &b_fsas, FsaVec *out,
149149

150150
Fsa LinearFsa(Array1<int32_t> &symbols) {
151151
ContextPtr c = symbols.Context();
152-
int32_t n = symbols.Dim(),
153-
num_states = n + 2,
154-
num_arcs = n + 1;
152+
int32_t n = symbols.Dim(), num_states = n + 2, num_arcs = n + 1;
155153
Array1<int32_t> row_splits1 = Range(c, num_states + 1, 0),
156-
row_ids1 = Range(c, num_arcs, 0);
154+
row_ids1 = Range(c, num_arcs, 0);
157155
Array1<Arc> arcs(c, num_arcs);
158156
Arc *arcs_data = arcs.Data();
159157
const int32_t *symbols_data = symbols.Data();
160-
auto lambda_set_arcs = [=] __host__ __device__ (int32_t arc_idx01) -> void {
161-
int32_t src_state = arc_idx01,
162-
dest_state = arc_idx01 + 1,
163-
// -1 == kFinalSymbol
164-
symbol = (arc_idx01 < n ? symbols_data[n] : -1);
158+
auto lambda_set_arcs = [=] __host__ __device__(int32_t arc_idx01) -> void {
159+
int32_t src_state = arc_idx01, dest_state = arc_idx01 + 1,
160+
// -1 == kFinalSymbol
161+
symbol = (arc_idx01 < n ? symbols_data[n] : -1);
165162
K2_CHECK_NE(symbol, -1);
166163
float score = 0.0;
167164
arcs_data[arc_idx01] = Arc(src_state, dest_state, symbol, score);
168165
};
169166
Eval(c, num_arcs, lambda_set_arcs);
170-
return Ragged<Arc>(RaggedShape2(&row_splits1, &row_ids1, num_arcs),
171-
arcs);
167+
return Ragged<Arc>(RaggedShape2(&row_splits1, &row_ids1, num_arcs), arcs);
172168
}
173169

174-
175170
Fsa LinearFsas(Ragged<int32_t> &symbols) {
176171
K2_CHECK(symbols.NumAxes() == 2);
177172
ContextPtr c = symbols.Context();
@@ -180,61 +175,59 @@ Fsa LinearFsas(Ragged<int32_t> &symbols) {
180175
RaggedShape states_shape = ChangeSublistSize(symbols.shape, 2);
181176

182177
int32_t num_states = states_shape.NumElements(),
183-
num_arcs = symbols.NumElements() + symbols.Dim0();
178+
num_arcs = symbols.NumElements() + symbols.Dim0();
184179

185180
// row_splits2 maps from state_idx01 to arc_idx012; row_ids2 does the reverse.
186181
// We'll set them in the lambda below.
187-
Array1<int32_t> row_splits2(c, num_states + 2),
188-
row_ids2(c, num_arcs);
189-
182+
Array1<int32_t> row_splits2(c, num_states + 2), row_ids2(c, num_arcs);
190183

191184
int32_t *row_ids2_data = row_ids2.Data(),
192-
*row_splits2_data = row_splits2.Data();
185+
*row_splits2_data = row_splits2.Data();
193186
const int32_t *row_ids1_data = states_shape.RowIds(1).Data(),
194-
*row_splits1_data = states_shape.RowSplits(1).Data(),
195-
*symbols_data = symbols.values.Data();
187+
*row_splits1_data = states_shape.RowSplits(1).Data(),
188+
*symbols_data = symbols.values.Data();
196189
Array1<Arc> arcs(c, num_arcs);
197190
Arc *arcs_data = arcs.Data();
198-
auto lambda = [=] __host__ __device__ (int32_t state_idx01) -> void {
191+
auto lambda = [=] __host__ __device__(int32_t state_idx01) -> void {
199192
int32_t fsa_idx0 = row_ids1_data[state_idx01],
200-
state_idx0x = row_splits1_data[fsa_idx0],
201-
next_state_idx0x = row_splits1_data[fsa_idx0 + 1],
202-
idx1 = state_idx01 - state_idx0x;
193+
state_idx0x = row_splits1_data[fsa_idx0],
194+
next_state_idx0x = row_splits1_data[fsa_idx0 + 1],
195+
idx1 = state_idx01 - state_idx0x;
203196

204197
// the following works because each FSA has one fewer arcs than states.
205198
int32_t arc_idx0xx = state_idx0x - fsa_idx0,
206-
next_arc_idx0xx = next_state_idx0x - (fsa_idx0 + 1),
207-
// the following may look a bit wrong.. here, the idx1 is the same as
208-
// the idx12 if the arc exists, because each state has one arc leaving
209-
// it (except the last state).
210-
arc_idx012 = arc_idx0xx + idx1;
199+
next_arc_idx0xx = next_state_idx0x - (fsa_idx0 + 1),
200+
// the following may look a bit wrong.. here, the idx1 is the same
201+
// as the idx12 if the arc exists, because each state has one arc
202+
// leaving it (except the last state).
203+
arc_idx012 = arc_idx0xx + idx1;
211204
// the following works because each FSA has one fewer symbols than arcs
212205
// (however it doesn't work for the last arc of each FSA; we check below.)
213206
int32_t symbol_idx01 = arc_idx012 - fsa_idx0;
214207
if (arc_idx012 < next_arc_idx0xx) {
215-
int32_t src_state = idx1,
216-
dest_state = idx1 + 1,
217-
symbol = (arc_idx012 + 1 < next_arc_idx0xx ?
218-
symbols_data[symbol_idx01] : -1); // kFinalSymbol
208+
int32_t src_state = idx1, dest_state = idx1 + 1,
209+
symbol =
210+
(arc_idx012 + 1 < next_arc_idx0xx ? symbols_data[symbol_idx01]
211+
: -1); // kFinalSymbol
219212
float score = 0.0;
220213
arcs_data[arc_idx012] = Arc(src_state, dest_state, symbol, score);
221214
row_ids2_data[arc_idx012] = state_idx01;
222215
} else {
223216
// The following ensures that the last element of row_splits1_data
224-
// (i.e. row_splits1[num_states]) is set to num_arcs. It also writes something
225-
// unnecessary for the last state of each FSA but the last one, which will
226-
// cause 2 threads to write the same item to the same location.
217+
// (i.e. row_splits1[num_states]) is set to num_arcs. It also writes
218+
// something unnecessary for the last state of each FSA but the last one,
219+
// which will cause 2 threads to write the same item to the same location.
227220
// Note that there is no arc with index `arc_idx01`, if you reach here.
228-
row_splits2_data[state_idx01+1] = arc_idx012;
221+
row_splits2_data[state_idx01 + 1] = arc_idx012;
229222
}
230223
row_splits2_data[state_idx01] = arc_idx012;
231224
};
232225
Eval(c, num_states, lambda);
233226

234-
return Ragged<Arc>(RaggedShape3(&states_shape.RowSplits(1),
235-
&states_shape.RowIds(1), num_states,
236-
&row_splits2, &row_splits2, num_arcs),
237-
arcs);
227+
return Ragged<Arc>(
228+
RaggedShape3(&states_shape.RowSplits(1), &states_shape.RowIds(1),
229+
num_states, &row_splits2, &row_splits2, num_arcs),
230+
arcs);
238231
}
239232

240233
namespace {

k2/csrc/fsa_utils.cu

+12-9
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ Ragged<int32_t> GetStateBatches(FsaVec &fsas, bool transpose) {
651651
*batch_starts_data = batch_starts.Data();
652652
const int32_t *fsas_row_splits1_data = fsas.RowSplits(1).Data();
653653

654-
#if 1
654+
#if 0
655655
// This is a simple version of the kernel that demonstrates what we're trying
656656
// to do with the more complex code.
657657
auto lambda_set_batch_info_simple = [=] __host__ __device__(int32_t fsa_idx) {
@@ -710,15 +710,16 @@ Ragged<int32_t> GetStateBatches(FsaVec &fsas, bool transpose) {
710710

711711
int32_t begin_state_idx01 = fsas_row_splits1_data[fsa_idx],
712712
end_state_idx01 = fsas_row_splits1_data[fsa_idx + 1];
713+
int32_t num_states_this_fsa = end_state_idx01 - begin_state_idx01;
713714
int32_t i = 0, cur_state_idx01 = begin_state_idx01;
714715

715-
if (task_idx >= end_state_idx01 - begin_state_idx01) return;
716+
if (task_idx >= num_states_this_fsa) return;
716717

717718
// The next loop advances `cur_state_idx01` by
718719
// a number of steps equal to `task_idx`.
719720
for (int32_t m = 0; m < log_power; ++m) {
720721
int32_t n = 1 << m;
721-
if (task_idx % n != 0) {
722+
if ((task_idx & n) != 0) {
722723
i += n;
723724
int32_t next = dest_states_powers_acc(m, cur_state_idx01);
724725
if (next >= end_state_idx01) return;
@@ -728,18 +729,20 @@ Ragged<int32_t> GetStateBatches(FsaVec &fsas, bool transpose) {
728729
K2_CHECK_EQ(i, task_idx);
729730

730731
while (1) {
732+
if (i >= num_states_this_fsa) return;
731733
batch_starts_data[begin_state_idx01 + i] = cur_state_idx01;
732-
int32_t next_state_idx01 =
733-
dest_states_powers_acc(log_power, cur_state_idx01);
734+
int32_t next_state_idx01 = dest_states_powers_acc(
735+
log_power,
736+
cur_state_idx01); // advance jobs_per_fsa = (1 << log_power) steps
734737
if (next_state_idx01 >= end_state_idx01) {
735738
// if exactly one step would also be enough to take us past the
736739
// boundary...
737-
if (dest_states_powers_acc(0, cur_state_idx01) >= next_state_idx01) {
740+
if (dest_states_powers_acc(0, cur_state_idx01) >= end_state_idx01) {
738741
num_batches_per_fsa_data[fsa_idx] = i + 1;
739742
}
740743
return;
741744
} else {
742-
i += cur_state_idx01;
745+
i += jobs_per_fsa;
743746
cur_state_idx01 = next_state_idx01;
744747
}
745748
}
@@ -757,7 +760,7 @@ Ragged<int32_t> GetStateBatches(FsaVec &fsas, bool transpose) {
757760
int32_t *ans_row_splits2_data = ans_row_splits2.Data();
758761
ans_row_splits2.Range(num_batches, 1) = num_states; // The kernel below won't
759762
// set this last element
760-
auto lambda_set_ans_row_ids2 =
763+
auto lambda_set_ans_row_splits2 =
761764
[=] __host__ __device__(int32_t idx01) -> void {
762765
int32_t idx0 = ans_row_ids1_data[idx01], // Fsa index
763766
idx0x = ans_row_splits1_data[idx0], idx1 = idx01 - idx0x,
@@ -770,7 +773,7 @@ Ragged<int32_t> GetStateBatches(FsaVec &fsas, bool transpose) {
770773
this_batch_start = batch_starts_data[fsas_idx01];
771774
ans_row_splits2_data[idx01] = this_batch_start;
772775
};
773-
Eval(c, num_batches, lambda_set_ans_row_ids2);
776+
Eval(c, num_batches, lambda_set_ans_row_splits2);
774777

775778
RaggedShape ans_shape =
776779
RaggedShape3(&ans_row_splits1, &ans_row_ids1, num_batches,

k2/csrc/ragged_ops.cu

+19-21
Original file line numberDiff line numberDiff line change
@@ -946,59 +946,58 @@ RaggedShape ChangeSublistSize(RaggedShape &src, int32_t size_delta) {
946946
std::vector<RaggedShapeDim> ans_axes(src.NumAxes() - 1);
947947
int32_t last_axis = src.NumAxes() - 1;
948948
// The following will only do something if src.NumAxes() > 2.
949-
for (int32_t i = 0; i + 1 < last_axis; i++)
950-
ans_axes[i] = src.Axes()[i];
949+
for (int32_t i = 0; i + 1 < last_axis; i++) ans_axes[i] = src.Axes()[i];
951950

952951
ContextPtr c = src.Context();
953952
int32_t num_rows = src.TotSize(last_axis - 1),
954-
src_num_elems = src.TotSize(last_axis),
955-
num_elems = src_num_elems + size_delta * num_rows;
953+
src_num_elems = src.TotSize(last_axis),
954+
num_elems = src_num_elems + size_delta * num_rows;
956955
ans_axes[0].row_splits = Array1<int32_t>(c, num_rows + 1);
957956
ans_axes[0].row_ids = Array1<int32_t>(c, num_elems);
958957
ans_axes[0].cached_tot_size = num_elems;
959958
const int32_t *src_row_splits_data = src.RowSplits(last_axis).Data(),
960-
*src_row_ids_data = src.RowIds(last_axis).Data();
959+
*src_row_ids_data = src.RowIds(last_axis).Data();
961960
int32_t *row_splits_data = ans_axes[0].row_splits.Data(),
962-
*row_ids_data = ans_axes[0].row_ids.Data();
961+
*row_ids_data = ans_axes[0].row_ids.Data();
963962

964963
{
965964
ParallelRunner pr(c);
966965
{
967966
With w(pr.NewStream());
968-
auto lambda_set_row_splits = [=] __host__ __device__ (int32_t idx0) -> void {
967+
auto lambda_set_row_splits =
968+
[=] __host__ __device__(int32_t idx0) -> void {
969969
row_splits_data[idx0] = src_row_splits_data[idx0] + size_delta * idx0;
970970
};
971971
Eval(c, num_rows + 1, lambda_set_row_splits);
972972
}
973973

974974
{
975975
With w(pr.NewStream());
976-
auto lambda_set_row_ids1 = [=] __host__ __device__ (int32_t src_idx01) -> void {
976+
auto lambda_set_row_ids1 =
977+
[=] __host__ __device__(int32_t src_idx01) -> void {
977978
int32_t src_idx0 = src_row_ids_data[src_idx01],
978-
src_idx0x = src_row_splits_data[src_idx0],
979-
src_idx1 = src_idx01 - src_idx0x,
980-
new_idx0x = row_splits_data[src_idx0],
981-
new_idx0x_next = row_splits_data[src_idx0 + 1],
982-
new_idx01 = new_idx0x + src_idx1;
979+
src_idx0x = src_row_splits_data[src_idx0],
980+
src_idx1 = src_idx01 - src_idx0x,
981+
new_idx0x = row_splits_data[src_idx0],
982+
new_idx0x_next = row_splits_data[src_idx0 + 1],
983+
new_idx01 = new_idx0x + src_idx1;
983984
// it's only necessary to guard the next statement with in 'if' because
984985
// size_delta might be negative.
985-
if (new_idx01 < new_idx0x_next)
986-
row_ids_data[new_idx01] = src_idx0;
986+
if (new_idx01 < new_idx0x_next) row_ids_data[new_idx01] = src_idx0;
987987
};
988988
Eval(c, num_elems, lambda_set_row_ids1);
989989
}
990990
if (size_delta > 0) {
991991
// This sets the row-ids that are not set by lambda_set_row_ids1.
992992
With w(pr.NewStream());
993-
auto lambda_set_row_ids2 = [=] __host__ __device__ (int32_t i) -> void {
994-
int32_t idx0 = i / size_delta, n = i % size_delta,
995-
next_idx0 = idx0 + 1;
993+
auto lambda_set_row_ids2 = [=] __host__ __device__(int32_t i) -> void {
994+
int32_t idx0 = i / size_delta, n = i % size_delta, next_idx0 = idx0 + 1;
996995
// The following formula is the same as the one in
997996
// lambda_set_row_splits; we want to compute the new value of
998997
// row_splits_data[next_idx0] without waiting for that kernel to
999998
// terminate.
1000-
int32_t next_idx0x = src_row_splits_data[next_idx0] +
1001-
size_delta * next_idx0;
999+
int32_t next_idx0x =
1000+
src_row_splits_data[next_idx0] + size_delta * next_idx0;
10021001
row_ids_data[next_idx0x - 1 - n] = idx0;
10031002
};
10041003
Eval(c, num_rows * size_delta, lambda_set_row_ids2);
@@ -1010,5 +1009,4 @@ RaggedShape ChangeSublistSize(RaggedShape &src, int32_t size_delta) {
10101009
return RaggedShape(ans_axes);
10111010
}
10121011

1013-
10141012
} // namespace k2

0 commit comments

Comments
 (0)