Skip to content

Commit

Permalink
Add Python wrappers for c.parallel merge_sort API (#3763)
Browse files Browse the repository at this point in the history
* Fix issue with converting types to strings in c.parallel merge_sort

* Add option to specify prefix for iterator methods to avoid name collisions

* Return error if output iterators are passed to c.parallel merge_sort

* Use `launcher_factory.PtxVersion()` in dispatch merge sort due to cudaErrorUnsupportedPtxVersion error

* Remove `cccl_type_enum_to_string` and replace with `cccl_type_enum_to_name` due to inconsistencies with the datatype being returned for INT64 and UINT64
  • Loading branch information
NaderAlAwar authored Feb 19, 2025
1 parent 2e89ed3 commit d7a1d6a
Show file tree
Hide file tree
Showing 15 changed files with 731 additions and 178 deletions.
6 changes: 3 additions & 3 deletions c/parallel/src/for/for_op_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

static std::string get_for_kernel_iterator(cccl_iterator_t iter)
{
const auto input_it_value_t = cccl_type_enum_to_string(iter.value_type.type);
const auto offset_t = cccl_type_enum_to_string(cccl_type_enum::UINT64);
const auto input_it_value_t = cccl_type_enum_to_name(iter.value_type.type);
const auto offset_t = cccl_type_enum_to_name(cccl_type_enum::UINT64);

constexpr std::string_view stateful_iterator =
R"XXX(
Expand Down Expand Up @@ -74,7 +74,7 @@ using for_each_iterator_t = input_iterator_state_t;

static std::string get_for_kernel_user_op(cccl_op_t user_op, cccl_iterator_t iter)
{
auto value_t = cccl_type_enum_to_string(iter.value_type.type);
auto value_t = cccl_type_enum_to_name(iter.value_type.type);

constexpr std::string_view op_format =
R"XXX(
Expand Down
23 changes: 16 additions & 7 deletions c/parallel/src/merge_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,11 @@ extern "C" CCCL_C_API CUresult cccl_device_merge_sort_build(
const int cc = cc_major * 10 + cc_minor;
const auto policy = merge_sort::get_policy(cc, output_keys_it.value_type.size);

const auto input_keys_it_value_t = cccl_type_enum_to_string(input_keys_it.value_type.type);
const auto input_items_it_value_t = cccl_type_enum_to_string(input_items_it.value_type.type);
const auto output_keys_it_value_t = cccl_type_enum_to_string(output_keys_it.value_type.type);
const auto output_items_it_value_t = cccl_type_enum_to_string(output_items_it.value_type.type);
const auto offset_t = cccl_type_enum_to_string(cccl_type_enum::INT64);
const auto input_keys_it_value_t = cccl_type_enum_to_name(input_keys_it.value_type.type);
const auto input_items_it_value_t = cccl_type_enum_to_name(input_items_it.value_type.type);
const auto output_keys_it_value_t = cccl_type_enum_to_name(output_keys_it.value_type.type);
const auto output_items_it_value_t = cccl_type_enum_to_name(output_items_it.value_type.type);
const auto offset_t = cccl_type_enum_to_name(cccl_type_enum::INT64);

const std::string input_keys_iterator_src = make_kernel_input_iterator(
offset_t,
Expand Down Expand Up @@ -456,8 +456,17 @@ extern "C" CCCL_C_API CUresult cccl_device_merge_sort(
cccl_op_t op,
CUstream stream) noexcept
{
bool pushed = false;
if (cccl_iterator_kind_t::iterator == d_out_keys.type || cccl_iterator_kind_t::iterator == d_out_items.type)
{
// See https://github.com/NVIDIA/cccl/issues/3722
fflush(stderr);
printf("\nERROR in cccl_device_merge_sort(): merge sort output cannot be an iterator\n");
fflush(stdout);
return CUDA_ERROR_UNKNOWN;
}

CUresult error = CUDA_SUCCESS;
bool pushed = false;
try
{
pushed = try_push_context();
Expand Down Expand Up @@ -494,7 +503,7 @@ extern "C" CCCL_C_API CUresult cccl_device_merge_sort(
catch (const std::exception& exc)
{
fflush(stderr);
printf("\nEXCEPTION in cccl_device_reduce(): %s\n", exc.what());
printf("\nEXCEPTION in cccl_device_merge_sort(): %s\n", exc.what());
fflush(stdout);
error = CUDA_ERROR_UNKNOWN;
}
Expand Down
6 changes: 3 additions & 3 deletions c/parallel/src/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_build(
const int cc = cc_major * 10 + cc_minor;
const cccl_type_info accum_t = reduce::get_accumulator_type(op, input_it, init);
const auto policy = reduce::get_policy(cc, accum_t);
const auto accum_cpp = cccl_type_enum_to_string(accum_t.type);
const auto input_it_value_t = cccl_type_enum_to_string(input_it.value_type.type);
const auto offset_t = cccl_type_enum_to_string(cccl_type_enum::UINT64);
const auto accum_cpp = cccl_type_enum_to_name(accum_t.type);
const auto input_it_value_t = cccl_type_enum_to_name(input_it.value_type.type);
const auto offset_t = cccl_type_enum_to_name(cccl_type_enum::UINT64);

const std::string input_iterator_src =
make_kernel_input_iterator(offset_t, "input_iterator_state_t", input_it_value_t, input_it);
Expand Down
6 changes: 3 additions & 3 deletions c/parallel/src/scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ extern "C" CCCL_C_API CUresult cccl_device_scan_build(
const int cc = cc_major * 10 + cc_minor;
const cccl_type_info accum_t = scan::get_accumulator_type(op, input_it, init);
const auto policy = scan::get_policy(cc, accum_t);
const auto accum_cpp = cccl_type_enum_to_string(accum_t.type);
const auto input_it_value_t = cccl_type_enum_to_string(input_it.value_type.type);
const auto offset_t = cccl_type_enum_to_string(cccl_type_enum::UINT64);
const auto accum_cpp = cccl_type_enum_to_name(accum_t.type);
const auto input_it_value_t = cccl_type_enum_to_name(input_it.value_type.type);
const auto offset_t = cccl_type_enum_to_name(cccl_type_enum::UINT64);

const std::string input_iterator_src =
make_kernel_input_iterator(offset_t, "input_iterator_state_t", input_it_value_t, input_it);
Expand Down
47 changes: 0 additions & 47 deletions c/parallel/src/util/types.cpp

This file was deleted.

114 changes: 37 additions & 77 deletions c/parallel/src/util/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,87 +25,47 @@ std::string cccl_type_enum_to_name(cccl_type_enum type, bool is_pointer = false)
{
std::string result;

if (is_pointer)
switch (type)
{
switch (type)
{
case cccl_type_enum::INT8:

check(nvrtcGetTypeName<::cuda::std::int8_t*>(&result));
break;
case cccl_type_enum::INT16:
check(nvrtcGetTypeName<::cuda::std::int16_t*>(&result));
break;
case cccl_type_enum::INT32:
check(nvrtcGetTypeName<::cuda::std::int32_t*>(&result));
break;
case cccl_type_enum::INT64:
check(nvrtcGetTypeName<::cuda::std::int64_t*>(&result));
break;
case cccl_type_enum::UINT8:
check(nvrtcGetTypeName<::cuda::std::uint8_t*>(&result));
break;
case cccl_type_enum::UINT16:
check(nvrtcGetTypeName<::cuda::std::uint16_t*>(&result));
break;
case cccl_type_enum::UINT32:
check(nvrtcGetTypeName<::cuda::std::uint32_t*>(&result));
break;
case cccl_type_enum::UINT64:
check(nvrtcGetTypeName<::cuda::std::uint64_t*>(&result));
break;
case cccl_type_enum::FLOAT32:
check(nvrtcGetTypeName<float*>(&result));
break;
case cccl_type_enum::FLOAT64:
check(nvrtcGetTypeName<double*>(&result));
break;
case cccl_type_enum::STORAGE:
check(nvrtcGetTypeName<StorageT*>(&result));
break;
}
case cccl_type_enum::INT8:
result = "::cuda::std::int8_t";
break;
case cccl_type_enum::INT16:
result = "::cuda::std::int16_t";
break;
case cccl_type_enum::INT32:
result = "::cuda::std::int32_t";
break;
case cccl_type_enum::INT64:
result = "::cuda::std::int64_t";
break;
case cccl_type_enum::UINT8:
result = "::cuda::std::uint8_t";
break;
case cccl_type_enum::UINT16:
result = "::cuda::std::uint16_t";
break;
case cccl_type_enum::UINT32:
result = "::cuda::std::uint32_t";
break;
case cccl_type_enum::UINT64:
result = "::cuda::std::uint64_t";
break;
case cccl_type_enum::FLOAT32:
result = "float";
break;
case cccl_type_enum::FLOAT64:
result = "double";
break;
case cccl_type_enum::STORAGE:
check(nvrtcGetTypeName<StorageT>(&result));
break;
}
else

if (is_pointer)
{
switch (type)
{
case cccl_type_enum::INT8:
check(nvrtcGetTypeName<::cuda::std::int8_t>(&result));
break;
case cccl_type_enum::INT16:
check(nvrtcGetTypeName<::cuda::std::int16_t>(&result));
break;
case cccl_type_enum::INT32:
check(nvrtcGetTypeName<::cuda::std::int32_t>(&result));
break;
case cccl_type_enum::INT64:
check(nvrtcGetTypeName<::cuda::std::int64_t>(&result));
break;
case cccl_type_enum::UINT8:
check(nvrtcGetTypeName<::cuda::std::uint8_t>(&result));
break;
case cccl_type_enum::UINT16:
check(nvrtcGetTypeName<::cuda::std::uint16_t>(&result));
break;
case cccl_type_enum::UINT32:
check(nvrtcGetTypeName<::cuda::std::uint32_t>(&result));
break;
case cccl_type_enum::UINT64:
check(nvrtcGetTypeName<::cuda::std::uint64_t>(&result));
break;
case cccl_type_enum::FLOAT32:
check(nvrtcGetTypeName<float>(&result));
break;
case cccl_type_enum::FLOAT64:
check(nvrtcGetTypeName<double>(&result));
break;
case cccl_type_enum::STORAGE:
check(nvrtcGetTypeName<StorageT>(&result));
break;
}
result += "*";
}

return result;
}

std::string_view cccl_type_enum_to_string(cccl_type_enum type);
104 changes: 74 additions & 30 deletions c/parallel/test/test_merge_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,36 +254,6 @@ TEST_CASE("DeviceMergeSort::SortKeys works with input iterators", "[merge_sort]"
REQUIRE(expected_keys == std::vector<TestType>(input_keys_ptr));
}

// TEST_CASE("DeviceMergeSort::SortKeys works with output iterators", "[merge_sort]")
// {
// using TestType = int;
// const int num_items = GENERATE_COPY(take(2, random(1, 1000000)), values({500, 1000000, 2000000}));

// operation_t op = make_operation("op", get_merge_sort_op(get_type_info<TestType>().type));
// iterator_t<TestType, random_access_iterator_state_t> output_keys_it =
// make_iterator<TestType, random_access_iterator_state_t>(
// "struct random_access_iterator_state_t { int* d_input; };\n",
// {"advance",
// "extern \"C\" __device__ void advance(random_access_iterator_state_t* state, unsigned long long offset) {\n"
// " state->d_input += offset;\n"
// "}"},
// {"dereference",
// "extern \"C\" __device__ void dereference(random_access_iterator_state_t* state, int x) {\n"
// " *state->d_input = x;\n"
// "}"});
// std::vector<TestType> input_keys = make_shuffled_sequence<TestType>(num_items);
// std::vector<TestType> expected_keys = input_keys;

// pointer_t<TestType> input_keys_it(input_keys);
// pointer_t<TestType> input_items_it;
// output_keys_it.state.d_input = input_keys_it.ptr;

// merge_sort(input_keys_it, input_items_it, output_keys_it, input_items_it, num_items, op);

// std::sort(expected_keys.begin(), expected_keys.end());
// REQUIRE(expected_keys == std::vector<TestType>(input_keys_it));
// }

struct item_random_access_iterator_state_t
{
int* d_input;
Expand Down Expand Up @@ -341,3 +311,77 @@ TEST_CASE("DeviceMergeSort::SortPairs works with input iterators", "[merge_sort]
REQUIRE(expected_keys == std::vector<TestType>(input_keys_ptr));
REQUIRE(expected_items == std::vector<item_t>(input_items_ptr));
}

// These tests with output iterators are currently failing; see https://github.com/NVIDIA/cccl/issues/3722
#ifdef NEVER_DEFINED
// Keys-only merge sort where the sorted keys are written through a
// user-defined output iterator instead of a plain pointer.
TEST_CASE("DeviceMergeSort::SortKeys works with output iterators", "[merge_sort]")
{
using TestType = int;
// Problem sizes: two random sizes drawn from [1, 1000000] plus the fixed sizes 500, 1000000 and 2000000.
const int num_items = GENERATE_COPY(take(2, random(1, 1000000)), values({500, 1000000, 2000000}));

operation_t op = make_operation("op", get_merge_sort_op(get_type_info<TestType>().type));
// Output iterator described by device source strings: the state holds a single int*,
// `advance` moves that pointer forward by `offset`, and `dereference` stores the
// value `x` through it (i.e. it is write-only).
iterator_t<TestType, random_access_iterator_state_t> output_keys_it =
make_iterator<TestType, random_access_iterator_state_t>(
"struct random_access_iterator_state_t { int* d_input; };\n",
{"advance",
"extern \"C\" __device__ void advance(random_access_iterator_state_t* state, unsigned long long offset) {\n"
" state->d_input += offset;\n"
"}"},
{"dereference",
"extern \"C\" __device__ void dereference(random_access_iterator_state_t* state, int x) {\n"
" *state->d_input = x;\n"
"}"});
std::vector<TestType> input_keys = make_shuffled_key_ranks_vector<TestType>(num_items);
// Keep a host copy of the unsorted keys; it is sorted below to form the reference result.
std::vector<TestType> expected_keys = input_keys;

pointer_t<TestType> input_keys_it(input_keys);
// Default-constructed items pointer, passed for both the input and output items
// arguments (presumably denotes "no items" for a keys-only sort -- confirm against pointer_t).
pointer_t<TestType> input_items_it;
// Point the output iterator at the input key buffer, so the sorted keys overwrite the input.
output_keys_it.state.d_input = input_keys_it.ptr;

merge_sort(input_keys_it, input_items_it, output_keys_it, input_items_it, num_items, op);

// Because the output aliases the input buffer, the result is read back through input_keys_it.
std::sort(expected_keys.begin(), expected_keys.end());
REQUIRE(expected_keys == std::vector<TestType>(input_keys_it));
}

// Key/value merge sort where the sorted items are written through a
// user-defined output iterator while the keys use plain pointers.
TEST_CASE("DeviceMergeSort::SortPairs works with output iterators for items", "[merge_sort]")
{
  using TestType = int;
  using item_t   = int;
  // Problem sizes: two random sizes drawn from [1, 1000000] plus the fixed sizes 500, 1000000 and 2000000.
  const int num_items = GENERATE_COPY(take(2, random(1, 1000000)), values({500, 1000000, 2000000}));

  operation_t op                  = make_operation("op", get_merge_sort_op(get_type_info<TestType>().type));
  std::vector<TestType> input_keys = make_shuffled_sequence<TestType>(num_items);
  // Each item is its key converted to item_t, so sorting keys and items
  // independently yields the expected paired result.
  std::vector<item_t> input_items(num_items);
  std::transform(input_keys.begin(), input_keys.end(), input_items.begin(), [](TestType key) {
    return static_cast<item_t>(key);
  });
  // Host copies of the unsorted data; sorted below to form the reference results.
  std::vector<TestType> expected_keys = input_keys;
  std::vector<item_t> expected_items  = input_items;

  // Output iterator described by device source strings: the state holds a single int*,
  // `advance` moves that pointer forward by `offset`, and `dereference` stores the
  // value `x` through it. The value type is item_t on both sides of the
  // initialization so the declaration stays valid if item_t and TestType ever differ.
  iterator_t<item_t, item_random_access_iterator_state_t> output_items_it =
    make_iterator<item_t, item_random_access_iterator_state_t>(
      "struct item_random_access_iterator_state_t { int* d_input; };\n",
      {"advance",
       "extern \"C\" __device__ void advance(item_random_access_iterator_state_t* state, unsigned long long offset) "
       "{\n"
       " state->d_input += offset;\n"
       "}"},
      {"dereference",
       "extern \"C\" __device__ void dereference(item_random_access_iterator_state_t* state, int x) {\n"
       " *state->d_input = x;\n"
       "}"});

  pointer_t<TestType> input_keys_it(input_keys);
  pointer_t<item_t> input_items_it(input_items);
  // Point the output iterator at the input item buffer, so the sorted items overwrite the input.
  output_items_it.state.d_input = input_items_it.ptr;

  // Keys are sorted in place: input_keys_it is passed as both key input and key output.
  merge_sort(input_keys_it, input_items_it, input_keys_it, output_items_it, num_items, op);

  // Both outputs alias the input buffers, so results are read back through the input pointers.
  std::sort(expected_keys.begin(), expected_keys.end());
  std::sort(expected_items.begin(), expected_items.end());
  REQUIRE(expected_keys == std::vector<TestType>(input_keys_it));
  REQUIRE(expected_items == std::vector<item_t>(input_items_it));
}

#endif
2 changes: 1 addition & 1 deletion cub/cub/device/dispatch/dispatch_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ struct DispatchMergeSort
{
// Get PTX version
int ptx_version = 0;
error = CubDebug(PtxVersion(ptx_version));
error = CubDebug(launcher_factory.PtxVersion(ptx_version));
if (cudaSuccess != error)
{
break;
Expand Down
Loading

0 comments on commit d7a1d6a

Please sign in to comment.