Skip to content

Commit c0be6f5

Browse files
authored
add tests with random large size for RowSplits<->RowIds conversion (k2-fsa#172)
1 parent fde770d commit c0be6f5

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

k2/csrc/utils_test.cu

+63
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,39 @@ void TestRowSplitsToRowIds() {
159159
cpu_array.Data() + cpu_array.Dim());
160160
EXPECT_EQ(cpu_data, row_ids_vec);
161161
}
162+
163+
{
164+
// test with random large size
165+
const int32_t min_num_elements = 2000;
166+
RaggedShape shape = RandomRaggedShape(true, 2, 2, min_num_elements, 10000);
167+
ASSERT_EQ(shape.NumAxes(), 2);
168+
const auto &axes = shape.Axes();
169+
ASSERT_EQ(axes.size(), 1);
170+
// note `src_row_splits` is on CPU as it is created with RandomRaggedShape
171+
const Array1<int32_t> &src_row_splits = axes[0].row_splits;
172+
Array1<int32_t> row_splits = src_row_splits.To(context);
173+
// don't call `shape.RowIds(axis)` here, it will make the test meaningless
174+
// as `shape.RowId(axis)` calls `RowSplitsToRowIds()` internally.
175+
// note `row_ids` is on CPU as it is created with RandomRaggedShape
176+
const Array1<int32_t> &expected_row_ids = axes[0].row_ids;
177+
int32_t num_rows = row_splits.Dim() - 1;
178+
int32_t num_elements = row_splits[num_rows];
179+
ASSERT_GE(num_elements, min_num_elements);
180+
ASSERT_EQ(expected_row_ids.Dim(), num_elements);
181+
Array1<int32_t> row_ids(context, num_elements);
182+
int32_t *row_ids_data = row_ids.Data();
183+
EXPECT_EQ(row_ids.Dim(), num_elements);
184+
RowSplitsToRowIds(context, num_rows, row_splits.Data(), num_elements,
185+
row_ids_data);
186+
// copy data from CPU/GPU to CPU
187+
Array1<int32_t> cpu_array = row_ids.To(cpu);
188+
std::vector<int32_t> cpu_data(cpu_array.Data(),
189+
cpu_array.Data() + cpu_array.Dim());
190+
std::vector<int32_t> expected_row_ids_data(
191+
expected_row_ids.Data(),
192+
expected_row_ids.Data() + expected_row_ids.Dim());
193+
EXPECT_EQ(cpu_data, expected_row_ids_data);
194+
}
162195
}
163196

164197
TEST(UtilsTest, RowSplitsToRowIds) {
@@ -230,6 +263,36 @@ void TestRowIdsToRowSplits() {
230263
cpu_array.Data() + cpu_array.Dim());
231264
EXPECT_EQ(cpu_data, row_splits_vec);
232265
}
266+
267+
{
268+
// test with random large size
269+
const int32_t min_num_elements = 2000;
270+
RaggedShape shape = RandomRaggedShape(true, 2, 2, min_num_elements, 10000);
271+
ASSERT_EQ(shape.NumAxes(), 2);
272+
const auto &axes = shape.Axes();
273+
ASSERT_EQ(axes.size(), 1);
274+
// note `src_row_ids` is on CPU as it is created with RandomRaggedShape
275+
const Array1<int32_t> &src_row_ids = axes[0].row_ids;
276+
Array1<int32_t> row_ids = src_row_ids.To(context);
277+
// note `row_splits` is on CPU as it is created with RandomRaggedShape
278+
const Array1<int32_t> &expected_row_splits = axes[0].row_splits;
279+
int32_t num_elements = row_ids.Dim();
280+
ASSERT_GE(num_elements, min_num_elements);
281+
int32_t num_rows = expected_row_splits.Dim() - 1;
282+
Array1<int32_t> row_splits(context, num_rows + 1);
283+
EXPECT_EQ(row_splits.Dim(), num_rows + 1);
284+
int32_t *row_splits_data = row_splits.Data();
285+
RowIdsToRowSplits(context, num_elements, row_ids.Data(), false, num_rows,
286+
row_splits_data);
287+
// copy data from CPU/GPU to CPU
288+
Array1<int32_t> cpu_array = row_splits.To(cpu);
289+
std::vector<int32_t> cpu_data(cpu_array.Data(),
290+
cpu_array.Data() + cpu_array.Dim());
291+
std::vector<int32_t> expected_row_splits_data(
292+
expected_row_splits.Data(),
293+
expected_row_splits.Data() + expected_row_splits.Dim());
294+
EXPECT_EQ(cpu_data, expected_row_splits_data);
295+
}
233296
}
234297

235298
TEST(UtilsTest, RowIdsToRowSplits) {

0 commit comments

Comments
 (0)