Skip to content

Commit d4ca571

Browse files
authored
Implemented GetCountsPartitioned and tested (k2-fsa#239)
1 parent 5535999 commit d4ca571

File tree

5 files changed

+82
-11
lines changed

5 files changed

+82
-11
lines changed

k2/csrc/array_ops.cu

+5-1
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,16 @@ void RowIdsToRowSplits(const Array1<int32_t> &row_ids,
271271
}
272272

273273
Array1<int32_t> GetCounts(const Array1<int32_t> &src, int32_t n) {
274-
K2_CHECK_GE(n, 1);
274+
K2_CHECK_GE(n, 0);
275275
ContextPtr c = src.Context();
276276
int32_t dim = src.Dim();
277277
const int32_t *src_data = src.Data();
278278
Array1<int32_t> ans(c, n, 0); // init with 0
279279
int32_t *ans_data = ans.Data();
280+
if (n == 0) {
281+
K2_CHECK_EQ(dim, 0);
282+
return ans;
283+
}
280284

281285
DeviceType d = c->GetDeviceType();
282286
if (d == kCpu) {

k2/csrc/array_ops_test.cu

+9
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,15 @@ void TestGetCounts() {
13551355
context = GetCudaContext();
13561356
}
13571357

1358+
{
1359+
// empty case
1360+
int32_t n = 0;
1361+
std::vector<int32_t> values;
1362+
Array1<int32_t> src(context, values);
1363+
Array1<int32_t> ans = GetCounts(src, n);
1364+
EXPECT_EQ(ans.Dim(), 0);
1365+
}
1366+
13581367
{
13591368
// simple case
13601369
int32_t n = 8;

k2/csrc/ragged_ops.cu

+13
Original file line numberDiff line numberDiff line change
@@ -793,4 +793,17 @@ RaggedShape TrivialShape(ContextPtr &c, int32_t num_elems) {
793793
return RaggedShape2(&row_splits, &row_ids, num_elems);
794794
}
795795

796+
Ragged<int32_t> GetCountsPartitioned(Ragged<int32_t> &src,
797+
RaggedShape &ans_ragged_shape) {
798+
K2_CHECK_EQ(src.NumAxes(), 2);
799+
K2_CHECK_EQ(ans_ragged_shape.NumAxes(), 2);
800+
K2_CHECK(IsCompatible(src, ans_ragged_shape));
801+
K2_CHECK_EQ(src.Dim0(), ans_ragged_shape.Dim0());
802+
const Array1<int32_t> &values = src.values;
803+
const Array1<int32_t> &row_splits = ans_ragged_shape.RowSplits(1);
804+
int32_t n = ans_ragged_shape.NumElements();
805+
Array1<int32_t> counts = GetCounts(values, n);
806+
return Ragged<int32_t>(ans_ragged_shape, counts);
807+
}
808+
796809
} // namespace k2

k2/csrc/ragged_ops.h

+11-10
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,8 @@ inline Ragged<T> RaggedFromTotSizes(ContextPtr &c,
357357
src.values.Dim() which tells us the order in which these elements would
358358
appear if sorted by column. (TODO: we can decide later whether to require
359359
sorting secondarily by row). So `src.values[ans]` will be in sorted
360-
order at exit, and `ans` will contain all numbers from 0 to `src.values.Dim() - 1`.
360+
order at exit, and `ans` will contain all numbers from 0 to `src.values.Dim()
361+
- 1`.
361362
362363
If `src` has more than 2 axes, the earlier-numbered axes do not affect
363364
the result, except for an efficiency modification: we require that the
@@ -369,21 +370,21 @@ inline Ragged<T> RaggedFromTotSizes(ContextPtr &c,
369370
TODO(dan): we may at some point make, as an optional output, row-splits and/or
370371
row-ids of the rearranged matrix.
371372
372-
This problem has some relationship to the cusparse library, specifically the csr2csc
373-
functions https://docs.nvidia.com/cuda/cusparse/index.html#csr2cscEx2).
374-
However I'm not sure what it does when there are repeated elements. It might
375-
be easiest to implement it via sorting for now.
373+
This problem has some relationship to the cusparse library, specifically the
374+
csr2csc functions
375+
https://docs.nvidia.com/cuda/cusparse/index.html#csr2cscEx2). However I'm not
376+
sure what it does when there are repeated elements. It might be easiest to
377+
implement it via sorting for now.
378+
376379
377-
378380
@param [in] src Input tensor, see above.
379-
@param [in] num_cols Number of columns in matrix to be transposed;
381+
@param [in] num_cols Number of columns in matrix to be transposed;
380382
we require 0 <= src.values[i] < num_cols.
381383
*/
382384
Array1<int32_t> GetTransposeReordering(Ragged<int32_t> &src, int32_t num_cols);
383385

384-
385386
/*
386-
This function is like GetCounts() that is declared in array_ops.h,
387+
This function is like GetCounts() that is declared in array_ops.h,
387388
but works on a partitioned problem (this should be faster).
388389
389390
@param [in] src A ragged array with src.NumAxes() == 2
@@ -403,7 +404,7 @@ Array1<int32_t> GetTransposeReordering(Ragged<int32_t> &src, int32_t num_cols);
403404
the result (with num_rows == ans_ragged_shape.NumElements()), then
404405
for each i, let ans.values[i] = row_splits[i+1]-row_splits[i] (where
405406
row_splits is the output of RowIdsToRowSplits() we just called).
406-
407+
407408
This could actually be implemented using the GetCounts() of array_ops.h,
408409
ignoring the structure; the structure should help the speed though.
409410
This equivalence should be useful for testing.

k2/csrc/ragged_test.cu

+44
Original file line numberDiff line numberDiff line change
@@ -1016,4 +1016,48 @@ TEST(RaggedShapeOpsTest, TestRenumber) {
10161016
TestRenumber<kCuda>();
10171017
}
10181018

1019+
template <DeviceType d>
1020+
void TestGetCountsPartitioned() {
1021+
ContextPtr cpu = GetCpuContext(); // will use to copy data
1022+
ContextPtr context = nullptr;
1023+
if (d == kCpu) {
1024+
context = GetCpuContext();
1025+
} else {
1026+
K2_CHECK_EQ(d, kCuda);
1027+
context = GetCudaContext();
1028+
}
1029+
1030+
// Testing with simple case is good enough as we have tested GetCounts() with
1031+
// random large size and GetCountsPartitioned just calls GetCounts.
1032+
std::vector<int32_t> src_row_splits_vec = {0, 3, 4, 6, 10};
1033+
Array1<int32_t> src_row_splits(context, src_row_splits_vec);
1034+
RaggedShape src_shape = RaggedShape2(&src_row_splits, nullptr, -1);
1035+
std::vector<int32_t> src_values_vec = {0, 1, 0, 2, 5, 5, 7, 7, 9, 7};
1036+
Array1<int32_t> src_values(context, src_values_vec);
1037+
Ragged<int32_t> src(src_shape, src_values);
1038+
1039+
std::vector<int32_t> ans_row_splits_vec = {0, 2, 4, 7, 10};
1040+
Array1<int32_t> ans_row_splits(context, ans_row_splits_vec);
1041+
RaggedShape ans_shape = RaggedShape2(&ans_row_splits, nullptr, -1);
1042+
1043+
Ragged<int32_t> result = GetCountsPartitioned(src, ans_shape);
1044+
1045+
ASSERT_EQ(result.NumAxes(), 2);
1046+
// Check row_splits
1047+
Array1<int32_t> row_splits = result.shape.RowSplits(1).To(cpu);
1048+
std::vector<int32_t> result_row_splits(row_splits.Data(),
1049+
row_splits.Data() + row_splits.Dim());
1050+
EXPECT_EQ(result_row_splits, ans_row_splits_vec);
1051+
// check values
1052+
std::vector<int32_t> expected_data = {2, 1, 1, 0, 0, 2, 0, 3, 0, 1};
1053+
Array1<int32_t> values = result.values.To(cpu);
1054+
std::vector<int32_t> data(values.Data(), values.Data() + values.Dim());
1055+
EXPECT_EQ(data, expected_data);
1056+
}
1057+
1058+
TEST(RaggedShapeOpsTest, TestGetCountsPartitioned) {
1059+
TestGetCountsPartitioned<kCpu>();
1060+
TestGetCountsPartitioned<kCuda>();
1061+
}
1062+
10191063
} // namespace k2

0 commit comments

Comments
 (0)