Skip to content

Commit 8a32c4c

Browse files
authored
implement GetTransposeReordering. (k2-fsa#237)
* implement GetTransposeReordering. * remove block comments. * switch to sorting based transpose. * remove cub warnings. * add GPU version. * fix style issues. * remove duplicate code. * resolve some comments. * optimize the CPU version.
1 parent d4ca571 commit 8a32c4c

File tree

3 files changed

+121
-3
lines changed

3 files changed

+121
-3
lines changed

cmake/cub.cmake

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ function(download_cub)
88

99
include(FetchContent)
1010

11-
set(cub_URL "https://github.com/NVlabs/cub/archive/1.9.10.tar.gz")
12-
set(cub_HASH "SHA256=2bd7077a3d9741f0689e6c1eb58c6278fc96eccc27d964168bc8be1bc3a9040f")
11+
set(cub_URL "https://github.com/NVlabs/cub/archive/1.10.0.tar.gz")
12+
set(cub_HASH "SHA256=8531e09f909aa021125cffa70a250761dfc247f960d7a1a12f65e6651ffb6477")
1313

1414
FetchContent_Declare(cub
1515
URL ${cub_URL}

k2/csrc/ragged_ops.cu

+72-1
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,24 @@
44
*
55
* @copyright
66
* Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu)
7+
* Mobvoi Inc. (authors: Fangjun Kuang)
78
*
89
* @copyright
910
* See LICENSE for clarification regarding multiple authors
1011
*/
1112

12-
#include <cub/cub.cuh>
13+
#include <algorithm>
14+
#include <memory>
1315
#include <vector>
1416

17+
#include "cub/cub.cuh"
1518
#include "k2/csrc/array_ops.h"
1619
#include "k2/csrc/math.h"
20+
#include "k2/csrc/moderngpu_allocator.h"
1721
#include "k2/csrc/ragged.h"
1822
#include "k2/csrc/ragged_ops.h"
23+
#include "moderngpu/kernel_mergesort.hxx"
24+
1925
namespace {
2026

2127
/*
@@ -806,4 +812,69 @@ Ragged<int32_t> GetCountsPartitioned(Ragged<int32_t> &src,
806812
return Ragged<int32_t>(ans_ragged_shape, counts);
807813
}
808814

815+
/*
  CPU implementation of GetTransposeReordering().

  Treats each element of `src.values` as a column index and returns the
  element indexes grouped by column; inside one column the indexes keep their
  original (ascending) order, i.e. the reordering is stable.

  Implemented as a counting sort: one pass counts the elements per column, a
  prefix sum turns the counts into the first output slot of every column, and
  a second pass scatters the element indexes.  This needs a single scratch
  array of `num_cols + 1` integers instead of one growable vector per column.

    @param [in] src       Ragged array whose values are column indexes.  Its
                          data is read directly on the host, so it must live
                          on a CPU context.  Every value must satisfy
                          0 <= value < num_cols; this is NOT checked here.
    @param [in] num_cols  Number of columns, must be >= 0.
    @return  An array with `src.values.Dim()` elements, allocated on
             `src.Context()`, where ans[i] is the index into `src.values` of
             the i-th element in column-major order.
 */
static Array1<int32_t> GetTransposeReorderingCpu(Ragged<int32_t> &src,
                                                 int32_t num_cols) {
  const int32_t *values_data = src.values.Data();
  int32_t n = src.values.Dim();

  // After the counting pass next_slot[c + 1] holds the number of elements in
  // column c; after the prefix sum next_slot[c] is the position in the output
  // where the next element of column c has to be written.
  std::vector<int32_t> next_slot(num_cols + 1, 0);
  for (int32_t i = 0; i != n; ++i) ++next_slot[values_data[i] + 1];
  for (int32_t c = 0; c < num_cols; ++c) next_slot[c + 1] += next_slot[c];

  Array1<int32_t> ans(src.Context(), n);
  int32_t *ans_data = ans.Data();
  // Visiting the elements in increasing index order keeps the result stable.
  for (int32_t i = 0; i != n; ++i) ans_data[next_slot[values_data[i]]++] = i;
  return ans;
}
834+
835+
/*
  Returns the permutation of element indexes of `src` that orders the elements
  by column index (the element's value), then by row, then by original
  position.

  On a CPU context the work is delegated to GetTransposeReorderingCpu(); on a
  CUDA context the identity ordering is sorted on the device with a
  comparator.

    @param [in] src       Ragged array whose values are read as column
                          indexes.
    @param [in] num_cols  Number of columns.  Only the CPU path uses it; the
                          CUDA path never reads it.
    @return  An array on the same context as `src`.  It is empty if `src` has
             fewer than 2 axes or has no elements; otherwise it has
             `src.values.Dim()` elements, each an index into `src.values`.
 */
Array1<int32_t> GetTransposeReordering(Ragged<int32_t> &src, int32_t num_cols) {
  ContextPtr &context = src.Context();
  if (src.NumAxes() < 2 || src.values.Dim() == 0) {
    // src is empty: there is nothing to reorder, so skip the sort entirely.
    return Array1<int32_t>(context, 0);
  }

  DeviceType device_type = context->GetDeviceType();
  if (device_type == kCpu) return GetTransposeReorderingCpu(src, num_cols);

  K2_CHECK_EQ(device_type, kCuda);

  const int32_t *row_ids1_data = src.RowIds(src.NumAxes() - 1).Data();
  const int32_t *value_data = src.values.Data();
  int32_t n = src.values.Dim();
  // The array that gets sorted holds element indexes, not the values.
  Array1<int32_t> ans = Range(context, n, 0);

  auto lambda_comp = [=] __device__(int32_t a_idx01, int32_t b_idx01) -> bool {
    // primary key: column index
    int32_t a_col_index = value_data[a_idx01];
    int32_t b_col_index = value_data[b_idx01];
    if (a_col_index != b_col_index) return a_col_index < b_col_index;

    // secondary key: row index; only loaded when the columns are equal
    int32_t a_idx0 = row_ids1_data[a_idx01];
    int32_t b_idx0 = row_ids1_data[b_idx01];
    if (a_idx0 != b_idx0) return a_idx0 < b_idx0;

    // Same row and same column: this entry is duplicated in the sparse
    // matrix.  Fall back to the original position so that the order is total;
    // the result then matches the CPU implementation without depending on
    // the stability of the sort.
    return a_idx01 < b_idx01;
  };

  std::unique_ptr<mgpu::context_t> mgpu_context =
      GetModernGpuAllocator(context->GetDeviceId());

  K2_CUDA_SAFE_CALL(mgpu::mergesort(ans.Data(), n, lambda_comp, *mgpu_context));

  return ans;
}
879+
809880
} // namespace k2

k2/csrc/ragged_test.cu

+47
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,53 @@ TEST(RaggedShapeOpsTest, TestRenumber) {
10151015
TestRenumber<kCpu>();
10161016
TestRenumber<kCuda>();
10171017
}
1018+
TEST(GetTransposeReordering, NoDuplicates) {
  // Sparse matrix under test; each non-zero entry appears exactly once:
  //
  //   0 0 0 9 2
  //   5 8 0 0 1
  //   0 0 3 0 0
  //   0 6 0 0 0
  //
  // Row by row, the column indexes of the non-zero entries are:
  std::vector<int32_t> col_indexes{3, 4, 0, 1, 4, 2, 1};
  // and row i owns entries [host_row_splits[i], host_row_splits[i + 1]).
  std::vector<int32_t> host_row_splits{0, 2, 5, 6, 7};

  // Run the same check on both the CPU and the CUDA implementation.
  for (auto &ctx : {GetCpuContext(), GetCudaContext()}) {
    Array1<int32_t> values(ctx, col_indexes);
    Array1<int32_t> row_splits(ctx, host_row_splits);
    RaggedShape shape = RaggedShape2(&row_splits, nullptr, -1);
    Ragged<int32_t> ragged(shape, values);

    Array1<int32_t> order = GetTransposeReordering(ragged, 5);

    // entry index:        0 1 2 3 4 5 6
    // row-major entries:  9 2 5 8 1 3 6
    // column-major order: 5 8 6 3 9 2 1
    // expected indexes:   2 3 6 5 0 1 4
    CheckArrayData(order, {2, 3, 6, 5, 0, 1, 4});
    // The answer has to live on the same device as the input.
    EXPECT_TRUE(ctx->IsCompatible(*order.Context()));
  }
}
1041+
1042+
TEST(GetTransposeReordering, WithDuplicates) {
  // Sparse matrix under test; parenthesised groups are entries that are
  // stored several times at the same (row, column) position:
  //
  //   0 0 0     (9,9,9)
  //   5 8 0     0
  //   0 0 (3,3) 0
  //   0 6 0     0
  //
  // Row by row, the column indexes of the stored entries are:
  std::vector<int32_t> col_indexes{3, 3, 3, 0, 1, 2, 2, 1};
  // and row i owns entries [host_row_splits[i], host_row_splits[i + 1]).
  std::vector<int32_t> host_row_splits{0, 3, 5, 7, 8};

  // Run the same check on both the CPU and the CUDA implementation.
  for (auto &ctx : {GetCpuContext(), GetCudaContext()}) {
    Array1<int32_t> values(ctx, col_indexes);
    Array1<int32_t> row_splits(ctx, host_row_splits);
    RaggedShape shape = RaggedShape2(&row_splits, nullptr, -1);
    Ragged<int32_t> ragged(shape, values);

    Array1<int32_t> order = GetTransposeReordering(ragged, 4);

    // entry index:        0 1 2 3 4 5 6 7
    // row-major entries:  9 9 9 5 8 3 3 6
    // column-major order: 5 8 6 3 3 9 9 9
    // expected indexes:   3 4 7 5 6 0 1 2
    // Duplicated entries are expected to keep their original relative order,
    // i.e. the reordering is stable.
    CheckArrayData(order, {3, 4, 7, 5, 6, 0, 1, 2});
    // The answer has to live on the same device as the input.
    EXPECT_TRUE(ctx->IsCompatible(*order.Context()));
  }
}
10181065

10191066
template <DeviceType d>
10201067
void TestGetCountsPartitioned() {

0 commit comments

Comments
 (0)