Skip to content

Commit

Permalink
Adds support for large num items to DeviceMerge (#3530)
Browse files Browse the repository at this point in the history
* adds support for large num items

* re-enable vsmem tests

* rephrases test description
  • Loading branch information
elstehle authored Jan 30, 2025
1 parent 0c17dbd commit c02e845
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 114 deletions.
18 changes: 12 additions & 6 deletions cub/cub/device/device_merge.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,19 @@ struct DeviceMerge
void* d_temp_storage,
std::size_t& temp_storage_bytes,
KeyIteratorIn1 keys_in1,
int num_keys1,
::cuda::std::int64_t num_keys1,
KeyIteratorIn2 keys_in2,
int num_keys2,
::cuda::std::int64_t num_keys2,
KeyIteratorOut keys_out,
CompareOp compare_op = {},
cudaStream_t stream = nullptr)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergeKeys");

using offset_t = ::cuda::std::int64_t;

return detail::merge::
dispatch_t<KeyIteratorIn1, NullType*, KeyIteratorIn2, NullType*, KeyIteratorOut, NullType*, int, CompareOp>::
dispatch_t<KeyIteratorIn1, NullType*, KeyIteratorIn2, NullType*, KeyIteratorOut, NullType*, offset_t, CompareOp>::
dispatch(
d_temp_storage,
temp_storage_bytes,
Expand Down Expand Up @@ -161,24 +164,27 @@ struct DeviceMerge
std::size_t& temp_storage_bytes,
KeyIteratorIn1 keys_in1,
ValueIteratorIn1 values_in1,
int num_pairs1,
::cuda::std::int64_t num_pairs1,
KeyIteratorIn2 keys_in2,
ValueIteratorIn2 values_in2,
int num_pairs2,
::cuda::std::int64_t num_pairs2,
KeyIteratorOut keys_out,
ValueIteratorOut values_out,
CompareOp compare_op = {},
cudaStream_t stream = nullptr)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergePairs");

using offset_t = ::cuda::std::int64_t;

return detail::merge::dispatch_t<
KeyIteratorIn1,
ValueIteratorIn1,
KeyIteratorIn2,
ValueIteratorIn2,
KeyIteratorOut,
ValueIteratorOut,
int,
offset_t,
CompareOp>::dispatch(d_temp_storage,
temp_storage_bytes,
keys_in1,
Expand Down
129 changes: 21 additions & 108 deletions cub/test/catch2_test_device_merge.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,103 +20,8 @@
DECLARE_LAUNCH_WRAPPER(cub::DeviceMerge::MergePairs, merge_pairs);
DECLARE_LAUNCH_WRAPPER(cub::DeviceMerge::MergeKeys, merge_keys);

// TODO(bgruber): replace the following by the CUB device API directly, once we have figured out how to handle different
// offset types
namespace detail
{
// Test helper: keys-only merge whose item counts use a caller-chosen offset type (`Offset`).
// It forwards all arguments to cub::detail::merge::dispatch_t, filling every value-iterator
// slot with cub::NullType* / nullptr, and returns the cudaError_t produced by that dispatch.
template <typename KeyIteratorIn1,
          typename KeyIteratorIn2,
          typename KeyIteratorOut,
          typename Offset,
          typename CompareOp = ::cuda::std::less<>>
CUB_RUNTIME_FUNCTION static cudaError_t merge_keys_custom_offset_type(
  void* d_temp_storage,
  std::size_t& temp_storage_bytes,
  KeyIteratorIn1 keys_in1,
  Offset num_keys1,
  KeyIteratorIn2 keys_in2,
  Offset num_keys2,
  KeyIteratorOut keys_out,
  CompareOp compare_op = {},
  cudaStream_t stream  = 0)
{
  CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergeKeys");

  // Name the dispatcher once so the forwarding call below stays readable.
  using keys_only_dispatch_t = cub::detail::merge::dispatch_t<
    KeyIteratorIn1,
    cub::NullType*,
    KeyIteratorIn2,
    cub::NullType*,
    KeyIteratorOut,
    cub::NullType*,
    Offset,
    CompareOp>;

  return keys_only_dispatch_t::dispatch(
    d_temp_storage,
    temp_storage_bytes,
    keys_in1,
    nullptr, // no values for input range 1
    num_keys1,
    keys_in2,
    nullptr, // no values for input range 2
    num_keys2,
    keys_out,
    nullptr, // no values written
    compare_op,
    stream);
}

// Test helper: key-value merge whose item counts use a caller-chosen offset type (`Offset`).
// It forwards all arguments unchanged to cub::detail::merge::dispatch_t and returns the
// cudaError_t produced by that dispatch.
template <typename KeyIteratorIn1,
          typename ValueIteratorIn1,
          typename KeyIteratorIn2,
          typename ValueIteratorIn2,
          typename KeyIteratorOut,
          typename ValueIteratorOut,
          typename Offset,
          typename CompareOp = ::cuda::std::less<>>
CUB_RUNTIME_FUNCTION static cudaError_t merge_pairs_custom_offset_type(
  void* d_temp_storage,
  std::size_t& temp_storage_bytes,
  KeyIteratorIn1 keys_in1,
  ValueIteratorIn1 values_in1,
  Offset num_pairs1,
  KeyIteratorIn2 keys_in2,
  ValueIteratorIn2 values_in2,
  Offset num_pairs2,
  KeyIteratorOut keys_out,
  ValueIteratorOut values_out,
  CompareOp compare_op = {},
  cudaStream_t stream  = 0)
{
  CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergePairs");

  // Name the dispatcher once so the forwarding call below stays readable.
  using pairs_dispatch_t = cub::detail::merge::dispatch_t<
    KeyIteratorIn1,
    ValueIteratorIn1,
    KeyIteratorIn2,
    ValueIteratorIn2,
    KeyIteratorOut,
    ValueIteratorOut,
    Offset,
    CompareOp>;

  return pairs_dispatch_t::dispatch(
    d_temp_storage,
    temp_storage_bytes,
    keys_in1,
    values_in1,
    num_pairs1,
    keys_in2,
    values_in2,
    num_pairs2,
    keys_out,
    values_out,
    compare_op,
    stream);
}
} // namespace detail

DECLARE_LAUNCH_WRAPPER(detail::merge_keys_custom_offset_type, merge_keys_custom_offset_type);
DECLARE_LAUNCH_WRAPPER(detail::merge_pairs_custom_offset_type, merge_pairs_custom_offset_type);

using types = c2h::type_list<std::uint8_t, std::int16_t, std::uint32_t, double>;

// gevtushenko: there is no code path in CUB and Thrust that leads to unsigned offsets, so let's save some compile time
using offset_types = c2h::type_list<std::int32_t, std::int64_t>;

template <typename Key,
typename Offset,
typename CompareOp = ::cuda::std::less<Key>,
Expand Down Expand Up @@ -223,11 +128,27 @@ C2H_TEST("DeviceMerge::MergeKeys large key types", "[merge][device]", c2h::type_
});
}

C2H_TEST("DeviceMerge::MergeKeys offset types", "[merge][device]", offset_types)
C2H_TEST("DeviceMerge::MergeKeys works for large number of items", "[merge][device]")

try
{
using key_t = char;
using offset_t = int64_t;

// Use 64-bit offset type problem sizes around the int32 limit (2^31 - 1), up to just slightly larger than 2^31 items
const auto num_items_int_max = static_cast<offset_t>(::cuda::std::numeric_limits<std::int32_t>::max());

// Generate the input sizes to test for
const offset_t num_items_lhs =
GENERATE_COPY(values({num_items_int_max + offset_t{1000000}, num_items_int_max - 1, offset_t{3}}));
const offset_t num_items_rhs =
GENERATE_COPY(values({num_items_int_max + offset_t{1000000}, num_items_int_max, offset_t{3}}));

test_keys<key_t, offset_t>(num_items_lhs, num_items_rhs, ::cuda::std::less<>{});
}
catch (const std::bad_alloc&)
{
using key_t = int;
using offset_t = c2h::get<0, TestType>;
test_keys<key_t, offset_t>(3623, 6346, ::cuda::std::less<>{}, merge_keys_custom_offset_type);
// allocation failure is not a test failure, so we can run tests on smaller GPUs
}

C2H_TEST("DeviceMerge::MergeKeys input sizes", "[merge][device]")
Expand Down Expand Up @@ -385,14 +306,6 @@ C2H_TEST("DeviceMerge::MergePairs value types", "[merge][device]", types)
test_pairs<key_t, value_t, offset_t>();
}

C2H_TEST("DeviceMerge::MergePairs offset types", "[merge][device]", offset_types)
{
using key_t = int;
using value_t = int;
using offset_t = c2h::get<0, TestType>;
test_pairs<key_t, value_t, offset_t>(3623, 6346, ::cuda::std::less<>{}, merge_pairs_custom_offset_type);
}

C2H_TEST("DeviceMerge::MergePairs input sizes", "[merge][device]")
{
using key_t = int;
Expand All @@ -410,7 +323,7 @@ try
using key_t = char;
using value_t = char;
const auto size = std::int64_t{1} << GENERATE(30, 31, 32, 33);
test_pairs<key_t, value_t>(size, size, ::cuda::std::less<>{}, merge_pairs_custom_offset_type);
test_pairs<key_t, value_t>(size, size, ::cuda::std::less<>{});
}
catch (const std::bad_alloc&)
{
Expand Down

0 comments on commit c02e845

Please sign in to comment.