Skip to content

Commit

Permalink
Adds support for large number of items to DeviceScan::*ByKey family…
Browse files Browse the repository at this point in the history
… of algorithms (#2477)

* experimenting with bool tile state

* fixes perf regression from different tile state

* fixes support for large offset types

* adapts interface for scanbykey

* adds tests for large number of items for scanbykey

* fixes naming

* makes thrust scan_by_key use unsigned offset types

* moves scan_by_key_op to detail ns
  • Loading branch information
elstehle authored Oct 8, 2024
1 parent 16f9a1a commit 951c822
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 73 deletions.
33 changes: 16 additions & 17 deletions cub/cub/agent/agent_scan_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,10 @@ struct AgentScanByKey

using KeyT = cub::detail::value_t<KeysInputIteratorT>;
using InputT = cub::detail::value_t<ValuesInputIteratorT>;
using SizeValuePairT = KeyValuePair<OffsetT, AccumT>;
using KeyValuePairT = KeyValuePair<KeyT, AccumT>;
using ReduceBySegmentOpT = ReduceBySegmentOp<ScanOpT>;
using FlagValuePairT = KeyValuePair<int, AccumT>;
using ReduceBySegmentOpT = detail::ScanBySegmentOp<ScanOpT>;

using ScanTileStateT = ReduceByKeyScanTileState<AccumT, OffsetT>;
using ScanTileStateT = ReduceByKeyScanTileState<AccumT, int>;

// Constants
// Inclusive scan if no init_value type is provided
Expand Down Expand Up @@ -175,9 +174,9 @@ struct AgentScanByKey

using DelayConstructorT = typename AgentScanByKeyPolicyT::detail::delay_constructor_t;
using TilePrefixCallbackT =
TilePrefixCallbackOp<SizeValuePairT, ReduceBySegmentOpT, ScanTileStateT, 0, DelayConstructorT>;
TilePrefixCallbackOp<FlagValuePairT, ReduceBySegmentOpT, ScanTileStateT, 0, DelayConstructorT>;

using BlockScanT = BlockScan<SizeValuePairT, BLOCK_THREADS, AgentScanByKeyPolicyT::SCAN_ALGORITHM, 1, 1>;
using BlockScanT = BlockScan<FlagValuePairT, BLOCK_THREADS, AgentScanByKeyPolicyT::SCAN_ALGORITHM, 1, 1>;

union TempStorage_
{
Expand Down Expand Up @@ -216,14 +215,14 @@ struct AgentScanByKey

// Exclusive scan specialization
_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTile(
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], SizeValuePairT& tile_aggregate, Int2Type<false> /* is_inclusive */)
FlagValuePairT (&scan_items)[ITEMS_PER_THREAD], FlagValuePairT& tile_aggregate, Int2Type<false> /* is_inclusive */)
{
BlockScanT(storage.scan_storage.scan).ExclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate);
}

// Inclusive scan specialization
_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTile(
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], SizeValuePairT& tile_aggregate, Int2Type<true> /* is_inclusive */)
FlagValuePairT (&scan_items)[ITEMS_PER_THREAD], FlagValuePairT& tile_aggregate, Int2Type<true> /* is_inclusive */)
{
BlockScanT(storage.scan_storage.scan).InclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate);
}
Expand All @@ -234,8 +233,8 @@ struct AgentScanByKey

// Exclusive scan specialization (with prefix from predecessors)
_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTile(
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
SizeValuePairT& tile_aggregate,
FlagValuePairT (&scan_items)[ITEMS_PER_THREAD],
FlagValuePairT& tile_aggregate,
TilePrefixCallbackT& prefix_op,
Int2Type<false> /* is_inclusive */)
{
Expand All @@ -245,8 +244,8 @@ struct AgentScanByKey

// Inclusive scan specialization (with prefix from predecessors)
_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTile(
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
SizeValuePairT& tile_aggregate,
FlagValuePairT (&scan_items)[ITEMS_PER_THREAD],
FlagValuePairT& tile_aggregate,
TilePrefixCallbackT& prefix_op,
Int2Type<true> /* is_inclusive */)
{
Expand All @@ -263,7 +262,7 @@ struct AgentScanByKey
OffsetT num_remaining,
AccumT (&values)[ITEMS_PER_THREAD],
OffsetT (&segment_flags)[ITEMS_PER_THREAD],
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD])
FlagValuePairT (&scan_items)[ITEMS_PER_THREAD])
{
// Zip values and segment_flags
#pragma unroll
Expand All @@ -281,7 +280,7 @@ struct AgentScanByKey
}

_CCCL_DEVICE _CCCL_FORCEINLINE void
UnzipValues(AccumT (&values)[ITEMS_PER_THREAD], SizeValuePairT (&scan_items)[ITEMS_PER_THREAD])
UnzipValues(AccumT (&values)[ITEMS_PER_THREAD], FlagValuePairT (&scan_items)[ITEMS_PER_THREAD])
{
// Unzip values and segment_flags
#pragma unroll
Expand Down Expand Up @@ -321,7 +320,7 @@ struct AgentScanByKey
KeyT keys[ITEMS_PER_THREAD];
AccumT values[ITEMS_PER_THREAD];
OffsetT segment_flags[ITEMS_PER_THREAD];
SizeValuePairT scan_items[ITEMS_PER_THREAD];
FlagValuePairT scan_items[ITEMS_PER_THREAD];

if (IS_LAST_TILE)
{
Expand Down Expand Up @@ -359,7 +358,7 @@ struct AgentScanByKey
ZipValuesAndFlags<IS_LAST_TILE>(num_remaining, values, segment_flags, scan_items);

// Exclusive scan of values and segment_flags
SizeValuePairT tile_aggregate;
FlagValuePairT tile_aggregate;
ScanTile(scan_items, tile_aggregate, Int2Type<IS_INCLUSIVE>());

if (threadIdx.x == 0)
Expand All @@ -382,7 +381,7 @@ struct AgentScanByKey
// Zip values and segment_flags
ZipValuesAndFlags<IS_LAST_TILE>(num_remaining, values, segment_flags, scan_items);

SizeValuePairT tile_aggregate;
FlagValuePairT tile_aggregate;
TilePrefixCallbackT prefix_op(tile_state, storage.scan_storage.prefix, pair_scan_op, tile_idx);
ScanTile(scan_items, tile_aggregate, prefix_op, Int2Type<IS_INCLUSIVE>());
}
Expand Down
68 changes: 44 additions & 24 deletions cub/cub/device/device_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,9 @@ struct DeviceScan
//! **[inferred]** Functor type having member
//! `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys
//!
//! @tparam NumItemsT
//! **[inferred]** An integral type representing the number of input elements
//!
//! @param[in] d_temp_storage
//! Device-accessible allocation of temporary storage. When `nullptr`, the
//! required allocation size is written to `temp_storage_bytes` and no work is done.
Expand Down Expand Up @@ -1558,21 +1561,22 @@ struct DeviceScan
template <typename KeysInputIteratorT,
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_RUNTIME_FUNCTION static cudaError_t ExclusiveSumByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op = EqualityOpT(),
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceScan::ExclusiveSumByKey");

// Signed integer type for global offsets
using OffsetT = int;
// Unsigned integer type for global offsets
using OffsetT = detail::choose_offset_t<NumItemsT>;
using InitT = cub::detail::value_t<ValuesInputIteratorT>;

// Initial value
Expand Down Expand Up @@ -1601,14 +1605,15 @@ struct DeviceScan
template <typename KeysInputIteratorT,
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION static cudaError_t ExclusiveSumByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op,
cudaStream_t stream,
bool debug_synchronous)
Expand Down Expand Up @@ -1721,6 +1726,9 @@ struct DeviceScan
//! **[inferred]** Functor type having member
//! `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys
//!
//! @tparam NumItemsT
//! **[inferred]** An integral type representing the number of input elements
//!
//! @param[in] d_temp_storage
//! Device-accessible allocation of temporary storage. When `nullptr`, the
//! required allocation size is written to `temp_storage_bytes` and no work is done.
Expand Down Expand Up @@ -1761,7 +1769,8 @@ struct DeviceScan
typename ValuesOutputIteratorT,
typename ScanOpT,
typename InitValueT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_RUNTIME_FUNCTION static cudaError_t ExclusiveScanByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
Expand All @@ -1770,14 +1779,14 @@ struct DeviceScan
ValuesOutputIteratorT d_values_out,
ScanOpT scan_op,
InitValueT init_value,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op = EqualityOpT(),
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceScan::ExclusiveScanByKey");

// Signed integer type for global offsets
using OffsetT = int;
// Unsigned integer type for global offsets
using OffsetT = detail::choose_offset_t<NumItemsT>;

return DispatchScanByKey<
KeysInputIteratorT,
Expand All @@ -1804,7 +1813,8 @@ struct DeviceScan
typename ValuesOutputIteratorT,
typename ScanOpT,
typename InitValueT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION static cudaError_t ExclusiveScanByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
Expand All @@ -1813,7 +1823,7 @@ struct DeviceScan
ValuesOutputIteratorT d_values_out,
ScanOpT scan_op,
InitValueT init_value,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op,
cudaStream_t stream,
bool debug_synchronous)
Expand Down Expand Up @@ -1904,6 +1914,9 @@ struct DeviceScan
//! **[inferred]** Functor type having member
//! `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys
//!
//! @tparam NumItemsT
//! **[inferred]** An integral type representing the number of input elements
//!
//! @param[in] d_temp_storage
//! Device-accessible allocation of temporary storage.
//! When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done.
Expand Down Expand Up @@ -1934,21 +1947,22 @@ struct DeviceScan
template <typename KeysInputIteratorT,
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_RUNTIME_FUNCTION static cudaError_t InclusiveSumByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op = EqualityOpT(),
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceScan::InclusiveSumByKey");

// Signed integer type for global offsets
using OffsetT = int;
// Unsigned integer type for global offsets
using OffsetT = detail::choose_offset_t<NumItemsT>;

return DispatchScanByKey<
KeysInputIteratorT,
Expand All @@ -1973,14 +1987,15 @@ struct DeviceScan
template <typename KeysInputIteratorT,
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION static cudaError_t InclusiveSumByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op,
cudaStream_t stream,
bool debug_synchronous)
Expand Down Expand Up @@ -2084,6 +2099,9 @@ struct DeviceScan
//! **[inferred]** Functor type having member
//! `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys
//!
//! @tparam NumItemsT
//! **[inferred]** An integral type representing the number of input elements
//!
//! @param[in] d_temp_storage
//! Device-accessible allocation of temporary storage.
//! When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done.
Expand Down Expand Up @@ -2118,22 +2136,23 @@ struct DeviceScan
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename ScanOpT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_RUNTIME_FUNCTION static cudaError_t InclusiveScanByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
ScanOpT scan_op,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op = EqualityOpT(),
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceScan::InclusiveScanByKey");

// Signed integer type for global offsets
using OffsetT = int;
// Unsigned integer type for global offsets
using OffsetT = detail::choose_offset_t<NumItemsT>;

return DispatchScanByKey<
KeysInputIteratorT,
Expand All @@ -2159,15 +2178,16 @@ struct DeviceScan
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename ScanOpT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION static cudaError_t InclusiveScanByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
ScanOpT scan_op,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op,
cudaStream_t stream,
bool debug_synchronous)
Expand Down
Loading

0 comments on commit 951c822

Please sign in to comment.