Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for large num items to DeviceMerge #3530

Merged
merged 6 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions cub/cub/device/device_merge.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,19 @@ struct DeviceMerge
void* d_temp_storage,
std::size_t& temp_storage_bytes,
KeyIteratorIn1 keys_in1,
int num_keys1,
::cuda::std::int64_t num_keys1,
KeyIteratorIn2 keys_in2,
int num_keys2,
::cuda::std::int64_t num_keys2,
KeyIteratorOut keys_out,
CompareOp compare_op = {},
cudaStream_t stream = nullptr)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergeKeys");

using offset_t = ::cuda::std::int64_t;

return detail::merge::
dispatch_t<KeyIteratorIn1, NullType*, KeyIteratorIn2, NullType*, KeyIteratorOut, NullType*, int, CompareOp>::
dispatch_t<KeyIteratorIn1, NullType*, KeyIteratorIn2, NullType*, KeyIteratorOut, NullType*, offset_t, CompareOp>::
dispatch(
d_temp_storage,
temp_storage_bytes,
Expand Down Expand Up @@ -161,24 +164,27 @@ struct DeviceMerge
std::size_t& temp_storage_bytes,
KeyIteratorIn1 keys_in1,
ValueIteratorIn1 values_in1,
int num_pairs1,
::cuda::std::int64_t num_pairs1,
KeyIteratorIn2 keys_in2,
ValueIteratorIn2 values_in2,
int num_pairs2,
::cuda::std::int64_t num_pairs2,
KeyIteratorOut keys_out,
ValueIteratorOut values_out,
CompareOp compare_op = {},
cudaStream_t stream = nullptr)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergePairs");

using offset_t = ::cuda::std::int64_t;

return detail::merge::dispatch_t<
KeyIteratorIn1,
ValueIteratorIn1,
KeyIteratorIn2,
ValueIteratorIn2,
KeyIteratorOut,
ValueIteratorOut,
int,
offset_t,
CompareOp>::dispatch(d_temp_storage,
temp_storage_bytes,
keys_in1,
Expand Down
226 changes: 70 additions & 156 deletions cub/test/catch2_test_device_merge.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,103 +20,8 @@
// Declare test-side wrappers `merge_pairs` and `merge_keys` for the public cub::DeviceMerge entry points.
// NOTE(review): the exact behavior of DECLARE_LAUNCH_WRAPPER is defined outside this view — confirm there.
DECLARE_LAUNCH_WRAPPER(cub::DeviceMerge::MergePairs, merge_pairs);
DECLARE_LAUNCH_WRAPPER(cub::DeviceMerge::MergeKeys, merge_keys);

// TODO(bgruber): replace the following by the CUB device API directly, once we have figured out how to handle different
// offset types
namespace detail
{
// Test-side variant of the keys-only merge entry point whose offset type is chosen by the caller
// (template parameter `Offset`) instead of being fixed by the public API.
//
// All arguments are forwarded unchanged to cub::detail::merge::dispatch_t<...>::dispatch; the three
// value-iterator slots of the dispatcher are typed cub::NullType* and receive nullptr.
// Returns whatever cudaError_t the dispatcher returns.
template <typename KeyIteratorIn1,
          typename KeyIteratorIn2,
          typename KeyIteratorOut,
          typename Offset,
          typename CompareOp = ::cuda::std::less<>>
CUB_RUNTIME_FUNCTION static cudaError_t merge_keys_custom_offset_type(
  void* d_temp_storage,
  std::size_t& temp_storage_bytes,
  KeyIteratorIn1 keys_in1,
  Offset num_keys1,
  KeyIteratorIn2 keys_in2,
  Offset num_keys2,
  KeyIteratorOut keys_out,
  CompareOp compare_op = {},
  cudaStream_t stream  = 0)
{
  CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergeKeys");

  // Name the fully specialized dispatcher once so the forwarding call below stays readable.
  using keys_only_dispatch_t = cub::detail::merge::dispatch_t<
    KeyIteratorIn1,
    cub::NullType*,
    KeyIteratorIn2,
    cub::NullType*,
    KeyIteratorOut,
    cub::NullType*,
    Offset,
    CompareOp>;

  return keys_only_dispatch_t::dispatch(
    d_temp_storage,
    temp_storage_bytes,
    keys_in1,
    nullptr, // no values for input 1
    num_keys1,
    keys_in2,
    nullptr, // no values for input 2
    num_keys2,
    keys_out,
    nullptr, // no values output
    compare_op,
    stream);
}

// Test-side variant of the key/value merge entry point whose offset type is chosen by the caller
// (template parameter `Offset`) instead of being fixed by the public API.
//
// All arguments are forwarded unchanged, in the same order, to
// cub::detail::merge::dispatch_t<...>::dispatch. Returns whatever cudaError_t the dispatcher returns.
template <typename KeyIteratorIn1,
          typename ValueIteratorIn1,
          typename KeyIteratorIn2,
          typename ValueIteratorIn2,
          typename KeyIteratorOut,
          typename ValueIteratorOut,
          typename Offset,
          typename CompareOp = ::cuda::std::less<>>
CUB_RUNTIME_FUNCTION static cudaError_t merge_pairs_custom_offset_type(
  void* d_temp_storage,
  std::size_t& temp_storage_bytes,
  KeyIteratorIn1 keys_in1,
  ValueIteratorIn1 values_in1,
  Offset num_pairs1,
  KeyIteratorIn2 keys_in2,
  ValueIteratorIn2 values_in2,
  Offset num_pairs2,
  KeyIteratorOut keys_out,
  ValueIteratorOut values_out,
  CompareOp compare_op = {},
  cudaStream_t stream  = 0)
{
  CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergePairs");

  // Name the fully specialized dispatcher once so the forwarding call below stays readable.
  using pairs_dispatch_t = cub::detail::merge::dispatch_t<
    KeyIteratorIn1,
    ValueIteratorIn1,
    KeyIteratorIn2,
    ValueIteratorIn2,
    KeyIteratorOut,
    ValueIteratorOut,
    Offset,
    CompareOp>;

  return pairs_dispatch_t::dispatch(
    d_temp_storage,
    temp_storage_bytes,
    keys_in1,
    values_in1,
    num_pairs1,
    keys_in2,
    values_in2,
    num_pairs2,
    keys_out,
    values_out,
    compare_op,
    stream);
}
} // namespace detail

// Declare wrappers for the custom-offset-type helpers in namespace detail; the second argument is the
// name the "offset types" tests pass to test_keys / test_pairs.
// NOTE(review): the exact behavior of DECLARE_LAUNCH_WRAPPER is defined outside this view — confirm there.
DECLARE_LAUNCH_WRAPPER(detail::merge_keys_custom_offset_type, merge_keys_custom_offset_type);
DECLARE_LAUNCH_WRAPPER(detail::merge_pairs_custom_offset_type, merge_pairs_custom_offset_type);

using types = c2h::type_list<std::uint8_t, std::int16_t, std::uint32_t, double>;

// gevtushenko: there is no code path in CUB and Thrust that leads to unsigned offsets, so let's save some compile time
using offset_types = c2h::type_list<std::int32_t, std::int64_t>;

template <typename Key,
typename Offset,
typename CompareOp = ::cuda::std::less<Key>,
Expand Down Expand Up @@ -172,62 +77,79 @@ struct fallback_test_policy_hub
};

// TODO(bgruber): This test alone increases compile time from 1m16s to 8m43s. What's going on?
C2H_TEST("DeviceMerge::MergeKeys large key types", "[merge][device]", c2h::type_list<large_type_vsmem, large_type_fallb>)
// C2H_TEST("DeviceMerge::MergeKeys large key types", "[merge][device]", c2h::type_list<large_type_vsmem,
// large_type_fallb>)
// {
elstehle marked this conversation as resolved.
Show resolved Hide resolved
// using key_t = c2h::get<0, TestType>;
// using offset_t = int;

// constexpr auto agent_sm = sizeof(key_t) * 128 * 7;
// constexpr auto fallback_sm =
// sizeof(key_t) * cub::detail::merge::fallback_BLOCK_THREADS * cub::detail::merge::fallback_ITEMS_PER_THREAD;
// static_assert(agent_sm > cub::detail::max_smem_per_block,
// "key_t is not big enough to exceed SM and trigger fallback policy");
// static_assert(
// ::cuda::std::is_same<key_t, large_type_fallb>::value == (fallback_sm <= cub::detail::max_smem_per_block),
// "SM consumption by fallback policy should fit into max_smem_per_block");

// test_keys<key_t, offset_t>(
// 3623,
// 6346,
// ::cuda::std::less<key_t>{},
// [](const key_t* k1, offset_t s1, const key_t* k2, offset_t s2, key_t* r, ::cuda::std::less<key_t> co) {
// using dispatch_t = cub::detail::merge::dispatch_t<
// const key_t*,
// const cub::NullType*,
// const key_t*,
// const cub::NullType*,
// key_t*,
// cub::NullType*,
// offset_t,
// ::cuda::std::less<key_t>,
// fallback_test_policy_hub>; // use a fixed policy for this test so the needed shared memory is deterministic

// std::size_t temp_storage_bytes = 0;
// dispatch_t::dispatch(
// nullptr, temp_storage_bytes, k1, nullptr, s1, k2, nullptr, s2, r, nullptr, co, cudaStream_t{0});

// c2h::device_vector<char> temp_storage(temp_storage_bytes);
// dispatch_t::dispatch(
// thrust::raw_pointer_cast(temp_storage.data()),
// temp_storage_bytes,
// k1,
// nullptr,
// s1,
// k2,
// nullptr,
// s2,
// r,
// nullptr,
// co,
// cudaStream_t{0});
// });
// }

C2H_TEST("DeviceMerge::MergeKeys offset types", "[merge][device]")
elstehle marked this conversation as resolved.
Show resolved Hide resolved

try
{
using key_t = c2h::get<0, TestType>;
using offset_t = int;
using key_t = char;
using offset_t = int64_t;

constexpr auto agent_sm = sizeof(key_t) * 128 * 7;
constexpr auto fallback_sm =
sizeof(key_t) * cub::detail::merge::fallback_BLOCK_THREADS * cub::detail::merge::fallback_ITEMS_PER_THREAD;
static_assert(agent_sm > cub::detail::max_smem_per_block,
"key_t is not big enough to exceed SM and trigger fallback policy");
static_assert(
::cuda::std::is_same<key_t, large_type_fallb>::value == (fallback_sm <= cub::detail::max_smem_per_block),
"SM consumption by fallback policy should fit into max_smem_per_block");

test_keys<key_t, offset_t>(
3623,
6346,
::cuda::std::less<key_t>{},
[](const key_t* k1, offset_t s1, const key_t* k2, offset_t s2, key_t* r, ::cuda::std::less<key_t> co) {
using dispatch_t = cub::detail::merge::dispatch_t<
const key_t*,
const cub::NullType*,
const key_t*,
const cub::NullType*,
key_t*,
cub::NullType*,
offset_t,
::cuda::std::less<key_t>,
fallback_test_policy_hub>; // use a fixed policy for this test so the needed shared memory is deterministic

std::size_t temp_storage_bytes = 0;
dispatch_t::dispatch(
nullptr, temp_storage_bytes, k1, nullptr, s1, k2, nullptr, s2, r, nullptr, co, cudaStream_t{0});

c2h::device_vector<char> temp_storage(temp_storage_bytes);
dispatch_t::dispatch(
thrust::raw_pointer_cast(temp_storage.data()),
temp_storage_bytes,
k1,
nullptr,
s1,
k2,
nullptr,
s2,
r,
nullptr,
co,
cudaStream_t{0});
});
}
// Clamp 64-bit offset type problem sizes to just slightly larger than 2^31 items (i.e., just past INT32_MAX)
const auto num_items_int_max = static_cast<offset_t>(::cuda::std::numeric_limits<std::int32_t>::max());

C2H_TEST("DeviceMerge::MergeKeys offset types", "[merge][device]", offset_types)
// Generate the input sizes to test for
const offset_t num_items_lhs =
GENERATE_COPY(values({num_items_int_max + offset_t{1000000}, num_items_int_max - 1, offset_t{3}}));
const offset_t num_items_rhs =
GENERATE_COPY(values({num_items_int_max + offset_t{1000000}, num_items_int_max, offset_t{3}}));

test_keys<key_t, offset_t>(num_items_lhs, num_items_rhs, ::cuda::std::less<>{});
}
catch (const std::bad_alloc&)
{
using key_t = int;
using offset_t = c2h::get<0, TestType>;
test_keys<key_t, offset_t>(3623, 6346, ::cuda::std::less<>{}, merge_keys_custom_offset_type);
// allocation failure is not a test failure, so we can run tests on smaller GPUs
}

C2H_TEST("DeviceMerge::MergeKeys input sizes", "[merge][device]")
Expand Down Expand Up @@ -385,14 +307,6 @@ C2H_TEST("DeviceMerge::MergePairs value types", "[merge][device]", types)
test_pairs<key_t, value_t, offset_t>();
}

// Runs the pairs merge once per entry of `offset_types`, with int keys and int values, going through
// merge_pairs_custom_offset_type so that the offset type under test is the one handed to the dispatcher.
C2H_TEST("DeviceMerge::MergePairs offset types", "[merge][device]", offset_types)
{
using key_t = int;
using value_t = int;
// The offset type is the single type parameter supplied by the test's type list.
using offset_t = c2h::get<0, TestType>;
// NOTE(review): 3623 and 6346 are presumably the sizes of the two inputs — confirm against test_pairs.
test_pairs<key_t, value_t, offset_t>(3623, 6346, ::cuda::std::less<>{}, merge_pairs_custom_offset_type);
}

C2H_TEST("DeviceMerge::MergePairs input sizes", "[merge][device]")
{
using key_t = int;
Expand All @@ -410,7 +324,7 @@ try
using key_t = char;
using value_t = char;
const auto size = std::int64_t{1} << GENERATE(30, 31, 32, 33);
test_pairs<key_t, value_t>(size, size, ::cuda::std::less<>{}, merge_pairs_custom_offset_type);
test_pairs<key_t, value_t>(size, size, ::cuda::std::less<>{});
}
catch (const std::bad_alloc&)
{
Expand Down
Loading