diff --git a/cub/cub/device/device_merge.cuh b/cub/cub/device/device_merge.cuh index 7135546a0e6..814bad75248 100644 --- a/cub/cub/device/device_merge.cuh +++ b/cub/cub/device/device_merge.cuh @@ -76,16 +76,19 @@ struct DeviceMerge void* d_temp_storage, std::size_t& temp_storage_bytes, KeyIteratorIn1 keys_in1, - int num_keys1, + ::cuda::std::int64_t num_keys1, KeyIteratorIn2 keys_in2, - int num_keys2, + ::cuda::std::int64_t num_keys2, KeyIteratorOut keys_out, CompareOp compare_op = {}, cudaStream_t stream = nullptr) { CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergeKeys"); + + using offset_t = ::cuda::std::int64_t; + return detail::merge:: - dispatch_t:: + dispatch_t:: dispatch( d_temp_storage, temp_storage_bytes, @@ -161,16 +164,19 @@ struct DeviceMerge std::size_t& temp_storage_bytes, KeyIteratorIn1 keys_in1, ValueIteratorIn1 values_in1, - int num_pairs1, + ::cuda::std::int64_t num_pairs1, KeyIteratorIn2 keys_in2, ValueIteratorIn2 values_in2, - int num_pairs2, + ::cuda::std::int64_t num_pairs2, KeyIteratorOut keys_out, ValueIteratorOut values_out, CompareOp compare_op = {}, cudaStream_t stream = nullptr) { CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergePairs"); + + using offset_t = ::cuda::std::int64_t; + return detail::merge::dispatch_t< KeyIteratorIn1, ValueIteratorIn1, @@ -178,7 +184,7 @@ struct DeviceMerge ValueIteratorIn2, KeyIteratorOut, ValueIteratorOut, - int, + offset_t, CompareOp>::dispatch(d_temp_storage, temp_storage_bytes, keys_in1, diff --git a/cub/test/catch2_test_device_merge.cu b/cub/test/catch2_test_device_merge.cu index ae0d3f84baa..4835f597710 100644 --- a/cub/test/catch2_test_device_merge.cu +++ b/cub/test/catch2_test_device_merge.cu @@ -20,103 +20,8 @@ DECLARE_LAUNCH_WRAPPER(cub::DeviceMerge::MergePairs, merge_pairs); DECLARE_LAUNCH_WRAPPER(cub::DeviceMerge::MergeKeys, merge_keys); -// TODO(bgruber): replace the following by the CUB device API directly, once we have figured out how to handle 
different -// offset types -namespace detail -{ -template > -CUB_RUNTIME_FUNCTION static cudaError_t merge_keys_custom_offset_type( - void* d_temp_storage, - std::size_t& temp_storage_bytes, - KeyIteratorIn1 keys_in1, - Offset num_keys1, - KeyIteratorIn2 keys_in2, - Offset num_keys2, - KeyIteratorOut keys_out, - CompareOp compare_op = {}, - cudaStream_t stream = 0) -{ - CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergeKeys"); - return cub::detail::merge::dispatch_t< - KeyIteratorIn1, - cub::NullType*, - KeyIteratorIn2, - cub::NullType*, - KeyIteratorOut, - cub::NullType*, - Offset, - CompareOp>::dispatch(d_temp_storage, - temp_storage_bytes, - keys_in1, - nullptr, - num_keys1, - keys_in2, - nullptr, - num_keys2, - keys_out, - nullptr, - compare_op, - stream); -} - -template > -CUB_RUNTIME_FUNCTION static cudaError_t merge_pairs_custom_offset_type( - void* d_temp_storage, - std::size_t& temp_storage_bytes, - KeyIteratorIn1 keys_in1, - ValueIteratorIn1 values_in1, - Offset num_pairs1, - KeyIteratorIn2 keys_in2, - ValueIteratorIn2 values_in2, - Offset num_pairs2, - KeyIteratorOut keys_out, - ValueIteratorOut values_out, - CompareOp compare_op = {}, - cudaStream_t stream = 0) -{ - CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceMerge::MergePairs"); - return cub::detail::merge::dispatch_t< - KeyIteratorIn1, - ValueIteratorIn1, - KeyIteratorIn2, - ValueIteratorIn2, - KeyIteratorOut, - ValueIteratorOut, - Offset, - CompareOp>::dispatch(d_temp_storage, - temp_storage_bytes, - keys_in1, - values_in1, - num_pairs1, - keys_in2, - values_in2, - num_pairs2, - keys_out, - values_out, - compare_op, - stream); -} -} // namespace detail - -DECLARE_LAUNCH_WRAPPER(detail::merge_keys_custom_offset_type, merge_keys_custom_offset_type); -DECLARE_LAUNCH_WRAPPER(detail::merge_pairs_custom_offset_type, merge_pairs_custom_offset_type); - using types = c2h::type_list; -// gevtushenko: there is no code path in CUB and Thrust that leads to unsigned offsets, so 
let's safe some compile time -using offset_types = c2h::type_list; - template , @@ -223,11 +128,27 @@ C2H_TEST("DeviceMerge::MergeKeys large key types", "[merge][device]", c2h::type_ }); } -C2H_TEST("DeviceMerge::MergeKeys offset types", "[merge][device]", offset_types) +C2H_TEST("DeviceMerge::MergeKeys works for large number of items", "[merge][device]") + +try +{ + using key_t = char; + using offset_t = int64_t; + + // Clamp 64-bit offset type problem sizes to just slightly larger than num_items_int_max items + const auto num_items_int_max = static_cast(::cuda::std::numeric_limits::max()); + + // Generate the input sizes to test for + const offset_t num_items_lhs = + GENERATE_COPY(values({num_items_int_max + offset_t{1000000}, num_items_int_max - 1, offset_t{3}})); + const offset_t num_items_rhs = + GENERATE_COPY(values({num_items_int_max + offset_t{1000000}, num_items_int_max, offset_t{3}})); + + test_keys(num_items_lhs, num_items_rhs, ::cuda::std::less<>{}); +} +catch (const std::bad_alloc&) { - using key_t = int; - using offset_t = c2h::get<0, TestType>; - test_keys(3623, 6346, ::cuda::std::less<>{}, merge_keys_custom_offset_type); + // allocation failure is not a test failure, so we can run tests on smaller GPUs } C2H_TEST("DeviceMerge::MergeKeys input sizes", "[merge][device]") @@ -385,14 +306,6 @@ C2H_TEST("DeviceMerge::MergePairs value types", "[merge][device]", types) test_pairs(); } -C2H_TEST("DeviceMerge::MergePairs offset types", "[merge][device]", offset_types) -{ - using key_t = int; - using value_t = int; - using offset_t = c2h::get<0, TestType>; - test_pairs(3623, 6346, ::cuda::std::less<>{}, merge_pairs_custom_offset_type); -} - C2H_TEST("DeviceMerge::MergePairs input sizes", "[merge][device]") { using key_t = int; @@ -410,7 +323,7 @@ try using key_t = char; using value_t = char; const auto size = std::int64_t{1} << GENERATE(30, 31, 32, 33); - test_pairs(size, size, ::cuda::std::less<>{}, merge_pairs_custom_offset_type); + test_pairs(size, size, 
::cuda::std::less<>{}); } catch (const std::bad_alloc&) {