Skip to content

Commit

Permalink
Adds support for large number of items and large number of segments t…
Browse files Browse the repository at this point in the history
…o `DeviceSegmentedSort` (NVIDIA#3308)

* fixes segment offset generation

* switches to analytical verification

* switches to analytical verification for pairs

* addresses review comments

* introduces segment offset type

* adds tests for large number of segments

* adds support for large number of segments

* drops segment offset type

* fixes thrust namespace

* removes about-to-be-deprecated cub iterators

* no exec specifier on defaulted ctor

* fixes gcc7 linker error

* uses local_segment_index_t throughout

* determine offset type based on type returned by segment iterator begin/end iterators

* minor style improvements
  • Loading branch information
elstehle authored and davebayer committed Jan 22, 2025
1 parent f08d6a6 commit 5a0094c
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 45 deletions.
75 changes: 38 additions & 37 deletions cub/cub/device/dispatch/dispatch_segmented_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@
#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

#include <cuda/cmath>
#include <cuda/std/__algorithm/max.h>
#include <cuda/std/__algorithm/min.h>
#include <cuda/std/type_traits>

#include <type_traits>
Expand All @@ -70,7 +68,10 @@

CUB_NAMESPACE_BEGIN

namespace detail::segmented_sort
namespace detail
{

namespace segmented_sort
{
// Type used to index within segments within a single invocation
using local_segment_index_t = ::cuda::std::uint32_t;
Expand Down Expand Up @@ -107,6 +108,8 @@ _CCCL_HOST_DEVICE OffsetIteratorT<Iterator, OffsetItT> make_offset_iterator(cons
{
return OffsetIteratorT<Iterator, OffsetItT>{it, offset_it};
}
} // namespace segmented_sort
} // namespace detail

/**
* @brief Fallback kernel, in case there's not enough segments to
Expand Down Expand Up @@ -170,7 +173,7 @@ __launch_bounds__(ChainedPolicyT::ActivePolicy::LargeSegmentPolicy::BLOCK_THREAD
using LargeSegmentPolicyT = typename ActivePolicyT::LargeSegmentPolicy;
using MediumPolicyT = typename ActivePolicyT::SmallAndMediumSegmentedSortPolicyT::MediumPolicyT;

const auto segment_id = static_cast<local_segment_index_t>(blockIdx.x);
const auto segment_id = static_cast<detail::segmented_sort::local_segment_index_t>(blockIdx.x);
OffsetT segment_begin = d_begin_offsets[segment_id];
OffsetT segment_end = d_end_offsets[segment_id];
OffsetT num_items = segment_end - segment_begin;
Expand Down Expand Up @@ -334,19 +337,19 @@ template <bool IS_DESCENDING,
typename OffsetT>
__launch_bounds__(ChainedPolicyT::ActivePolicy::SmallAndMediumSegmentedSortPolicyT::BLOCK_THREADS)
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortKernelSmall(
local_segment_index_t small_segments,
local_segment_index_t medium_segments,
local_segment_index_t medium_blocks,
const local_segment_index_t* d_small_segments_indices,
const local_segment_index_t* d_medium_segments_indices,
detail::segmented_sort::local_segment_index_t small_segments,
detail::segmented_sort::local_segment_index_t medium_segments,
detail::segmented_sort::local_segment_index_t medium_blocks,
const detail::segmented_sort::local_segment_index_t* d_small_segments_indices,
const detail::segmented_sort::local_segment_index_t* d_medium_segments_indices,
const KeyT* d_keys_in,
KeyT* d_keys_out,
const ValueT* d_values_in,
ValueT* d_values_out,
BeginOffsetIteratorT d_begin_offsets,
EndOffsetIteratorT d_end_offsets)
{
using local_segment_index_t = local_segment_index_t;
using local_segment_index_t = detail::segmented_sort::local_segment_index_t;

const local_segment_index_t tid = threadIdx.x;
const local_segment_index_t bid = blockIdx.x;
Expand Down Expand Up @@ -458,7 +461,7 @@ template <bool IS_DESCENDING,
typename OffsetT>
__launch_bounds__(ChainedPolicyT::ActivePolicy::LargeSegmentPolicy::BLOCK_THREADS)
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortKernelLarge(
const local_segment_index_t* d_segments_indices,
const detail::segmented_sort::local_segment_index_t* d_segments_indices,
const KeyT* d_keys_in_orig,
KeyT* d_keys_out_orig,
device_double_buffer<KeyT> d_keys_double_buffer,
Expand All @@ -470,7 +473,7 @@ __launch_bounds__(ChainedPolicyT::ActivePolicy::LargeSegmentPolicy::BLOCK_THREAD
{
using ActivePolicyT = typename ChainedPolicyT::ActivePolicy;
using LargeSegmentPolicyT = typename ActivePolicyT::LargeSegmentPolicy;
using local_segment_index_t = local_segment_index_t;
using local_segment_index_t = detail::segmented_sort::local_segment_index_t;

constexpr int small_tile_size = LargeSegmentPolicyT::BLOCK_THREADS * LargeSegmentPolicyT::ITEMS_PER_THREAD;

Expand Down Expand Up @@ -577,12 +580,12 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
device_double_buffer<ValueT> d_values_double_buffer,
BeginOffsetIteratorT d_begin_offsets,
EndOffsetIteratorT d_end_offsets,
local_segment_index_t* group_sizes,
local_segment_index_t* large_and_medium_segments_indices,
local_segment_index_t* small_segments_indices,
detail::segmented_sort::local_segment_index_t* group_sizes,
detail::segmented_sort::local_segment_index_t* large_and_medium_segments_indices,
detail::segmented_sort::local_segment_index_t* small_segments_indices,
cudaStream_t stream)
{
using local_segment_index_t = local_segment_index_t;
using local_segment_index_t = detail::segmented_sort::local_segment_index_t;

cudaError error = cudaSuccess;

Expand Down Expand Up @@ -699,7 +702,7 @@ template <typename ChainedPolicyT,
__launch_bounds__(1) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortContinuationKernel(
LargeKernelT large_kernel,
SmallKernelT small_kernel,
local_segment_index_t num_segments,
detail::segmented_sort::local_segment_index_t num_segments,
KeyT* d_current_keys,
KeyT* d_final_keys,
device_double_buffer<KeyT> d_keys_double_buffer,
Expand All @@ -708,9 +711,9 @@ __launch_bounds__(1) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortContin
device_double_buffer<ValueT> d_values_double_buffer,
BeginOffsetIteratorT d_begin_offsets,
EndOffsetIteratorT d_end_offsets,
local_segment_index_t* group_sizes,
local_segment_index_t* large_and_medium_segments_indices,
local_segment_index_t* small_segments_indices)
detail::segmented_sort::local_segment_index_t* group_sizes,
detail::segmented_sort::local_segment_index_t* large_and_medium_segments_indices,
detail::segmented_sort::local_segment_index_t* small_segments_indices)
{
using ActivePolicyT = typename ChainedPolicyT::ActivePolicy;
using LargeSegmentPolicyT = typename ActivePolicyT::LargeSegmentPolicy;
Expand Down Expand Up @@ -961,7 +964,7 @@ struct DispatchSegmentedSort
constexpr auto num_segments_per_invocation_limit =
static_cast<global_segment_offset_t>(::cuda::std::numeric_limits<int>::max());
auto const max_num_segments_per_invocation = static_cast<global_segment_offset_t>(
(::cuda::std::min)(static_cast<global_segment_offset_t>(num_segments), num_segments_per_invocation_limit));
::cuda::std::min(static_cast<global_segment_offset_t>(num_segments), num_segments_per_invocation_limit));

large_and_medium_segments_indices.grow(max_num_segments_per_invocation);
small_segments_indices.grow(max_num_segments_per_invocation);
Expand Down Expand Up @@ -1078,22 +1081,20 @@ struct DispatchSegmentedSort
// Partition input segments into size groups and assign specialized
// kernels for each of them.
error = SortWithPartitioning<LargeSegmentPolicyT, SmallAndMediumPolicyT>(
detail::segmented_sort::DeviceSegmentedSortKernelLarge<
IS_DESCENDING,
MaxPolicyT,
KeyT,
ValueT,
StreamingBeginOffsetIteratorT,
StreamingEndOffsetIteratorT,
OffsetT>,
detail::segmented_sort::DeviceSegmentedSortKernelSmall<
IS_DESCENDING,
MaxPolicyT,
KeyT,
ValueT,
StreamingBeginOffsetIteratorT,
StreamingEndOffsetIteratorT,
OffsetT>,
DeviceSegmentedSortKernelLarge<IS_DESCENDING,
MaxPolicyT,
KeyT,
ValueT,
StreamingBeginOffsetIteratorT,
StreamingEndOffsetIteratorT,
OffsetT>,
DeviceSegmentedSortKernelSmall<IS_DESCENDING,
MaxPolicyT,
KeyT,
ValueT,
StreamingBeginOffsetIteratorT,
StreamingEndOffsetIteratorT,
OffsetT>,
three_way_partition_temp_storage_bytes,
d_keys_double_buffer,
d_values_double_buffer,
Expand Down
7 changes: 3 additions & 4 deletions cub/test/catch2_segmented_sort_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct segment_index_to_offset_op
OffsetT segment_size;
OffsetT num_items;

_CCCL_HOST_DEVICE __forceinline__ OffsetT operator()(SegmentIndexT i)
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE OffsetT operator()(SegmentIndexT i)
{
if (i < num_empty_segments)
{
Expand All @@ -103,16 +103,15 @@ struct mod_n
std::size_t mod;

template <typename IndexT>
_CCCL_HOST_DEVICE __forceinline__ T operator()(IndexT x)
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE T operator()(IndexT x)
{
return static_cast<T>(x % mod);
}
};

template <typename KeyT>
class short_key_verification_helper
struct short_key_verification_helper
{
private:
using key_t = KeyT;
// The histogram size of the keys being sorted for later verification
const std::int64_t max_histo_size = std::int64_t{1} << ::cuda::std::numeric_limits<key_t>::digits;
Expand Down
2 changes: 0 additions & 2 deletions cub/test/catch2_test_device_segmented_sort_keys.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,6 @@ C2H_TEST("DeviceSegmentedSortKeys: Unspecified segments, random keys", "[keys][s
test_unspecified_segments_random<KeyT>(C2H_SEED(4));
}

#if defined(CCCL_TEST_ENABLE_LARGE_SEGMENTED_SORT)

C2H_TEST("DeviceSegmentedSortKeys: very large number of segments", "[keys][segmented][sort][device]", all_offset_types)
try
{
Expand Down
2 changes: 0 additions & 2 deletions cub/test/catch2_test_device_segmented_sort_pairs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,6 @@ C2H_TEST("DeviceSegmentedSortPairs: Unspecified segments, random key/values",
test_unspecified_segments_random<KeyT, ValueT>(C2H_SEED(4));
}

#if defined(CCCL_TEST_ENABLE_LARGE_SEGMENTED_SORT)

C2H_TEST("DeviceSegmentedSortPairs: very large num. items and num. segments",
"[pairs][segmented][sort][device]",
all_offset_types)
Expand Down

0 comments on commit 5a0094c

Please sign in to comment.