Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for large number of items to DeviceScan::*ByKey family of algorithms #2477

Merged
merged 9 commits into from
Oct 8, 2024
33 changes: 16 additions & 17 deletions cub/cub/agent/agent_scan_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,10 @@ struct AgentScanByKey

using KeyT = cub::detail::value_t<KeysInputIteratorT>;
using InputT = cub::detail::value_t<ValuesInputIteratorT>;
using SizeValuePairT = KeyValuePair<OffsetT, AccumT>;
using KeyValuePairT = KeyValuePair<KeyT, AccumT>;
using ReduceBySegmentOpT = ReduceBySegmentOp<ScanOpT>;
using FlagValuePairT = KeyValuePair<int, AccumT>;
using ReduceBySegmentOpT = ScanBySegmentOp<ScanOpT>;

using ScanTileStateT = ReduceByKeyScanTileState<AccumT, OffsetT>;
using ScanTileStateT = ReduceByKeyScanTileState<AccumT, int>;
Copy link
Collaborator Author

@elstehle elstehle Sep 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had tried making this bool, but performance dropped by 20% for some workloads. Using int with the bitwise OR operator, |, preserved both the semantics and the performance, and made the algorithm's performance almost agnostic to the offset type.


// Constants
// Inclusive scan if no init_value type is provided
Expand Down Expand Up @@ -175,9 +174,9 @@ struct AgentScanByKey

using DelayConstructorT = typename AgentScanByKeyPolicyT::detail::delay_constructor_t;
using TilePrefixCallbackT =
TilePrefixCallbackOp<SizeValuePairT, ReduceBySegmentOpT, ScanTileStateT, 0, DelayConstructorT>;
TilePrefixCallbackOp<FlagValuePairT, ReduceBySegmentOpT, ScanTileStateT, 0, DelayConstructorT>;

using BlockScanT = BlockScan<SizeValuePairT, BLOCK_THREADS, AgentScanByKeyPolicyT::SCAN_ALGORITHM, 1, 1>;
using BlockScanT = BlockScan<FlagValuePairT, BLOCK_THREADS, AgentScanByKeyPolicyT::SCAN_ALGORITHM, 1, 1>;

union TempStorage_
{
Expand Down Expand Up @@ -216,14 +215,14 @@ struct AgentScanByKey

// Exclusive scan specialization
_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTile(
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], SizeValuePairT& tile_aggregate, Int2Type<false> /* is_inclusive */)
FlagValuePairT (&scan_items)[ITEMS_PER_THREAD], FlagValuePairT& tile_aggregate, Int2Type<false> /* is_inclusive */)
{
BlockScanT(storage.scan_storage.scan).ExclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate);
}

// Inclusive scan specialization
_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTile(
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], SizeValuePairT& tile_aggregate, Int2Type<true> /* is_inclusive */)
FlagValuePairT (&scan_items)[ITEMS_PER_THREAD], FlagValuePairT& tile_aggregate, Int2Type<true> /* is_inclusive */)
{
BlockScanT(storage.scan_storage.scan).InclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate);
}
Expand All @@ -234,8 +233,8 @@ struct AgentScanByKey

// Exclusive scan specialization (with prefix from predecessors)
_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTile(
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
SizeValuePairT& tile_aggregate,
FlagValuePairT (&scan_items)[ITEMS_PER_THREAD],
FlagValuePairT& tile_aggregate,
TilePrefixCallbackT& prefix_op,
Int2Type<false> /* is_inclusive */)
{
Expand All @@ -245,8 +244,8 @@ struct AgentScanByKey

// Inclusive scan specialization (with prefix from predecessors)
_CCCL_DEVICE _CCCL_FORCEINLINE void ScanTile(
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
SizeValuePairT& tile_aggregate,
FlagValuePairT (&scan_items)[ITEMS_PER_THREAD],
FlagValuePairT& tile_aggregate,
TilePrefixCallbackT& prefix_op,
Int2Type<true> /* is_inclusive */)
{
Expand All @@ -263,7 +262,7 @@ struct AgentScanByKey
OffsetT num_remaining,
AccumT (&values)[ITEMS_PER_THREAD],
OffsetT (&segment_flags)[ITEMS_PER_THREAD],
SizeValuePairT (&scan_items)[ITEMS_PER_THREAD])
FlagValuePairT (&scan_items)[ITEMS_PER_THREAD])
{
// Zip values and segment_flags
#pragma unroll
Expand All @@ -281,7 +280,7 @@ struct AgentScanByKey
}

_CCCL_DEVICE _CCCL_FORCEINLINE void
UnzipValues(AccumT (&values)[ITEMS_PER_THREAD], SizeValuePairT (&scan_items)[ITEMS_PER_THREAD])
UnzipValues(AccumT (&values)[ITEMS_PER_THREAD], FlagValuePairT (&scan_items)[ITEMS_PER_THREAD])
{
// Unzip values and segment_flags
#pragma unroll
Expand Down Expand Up @@ -321,7 +320,7 @@ struct AgentScanByKey
KeyT keys[ITEMS_PER_THREAD];
AccumT values[ITEMS_PER_THREAD];
OffsetT segment_flags[ITEMS_PER_THREAD];
SizeValuePairT scan_items[ITEMS_PER_THREAD];
FlagValuePairT scan_items[ITEMS_PER_THREAD];

if (IS_LAST_TILE)
{
Expand Down Expand Up @@ -359,7 +358,7 @@ struct AgentScanByKey
ZipValuesAndFlags<IS_LAST_TILE>(num_remaining, values, segment_flags, scan_items);

// Exclusive scan of values and segment_flags
SizeValuePairT tile_aggregate;
FlagValuePairT tile_aggregate;
ScanTile(scan_items, tile_aggregate, Int2Type<IS_INCLUSIVE>());

if (threadIdx.x == 0)
Expand All @@ -382,7 +381,7 @@ struct AgentScanByKey
// Zip values and segment_flags
ZipValuesAndFlags<IS_LAST_TILE>(num_remaining, values, segment_flags, scan_items);

SizeValuePairT tile_aggregate;
FlagValuePairT tile_aggregate;
TilePrefixCallbackT prefix_op(tile_state, storage.scan_storage.prefix, pair_scan_op, tile_idx);
ScanTile(scan_items, tile_aggregate, prefix_op, Int2Type<IS_INCLUSIVE>());
}
Expand Down
68 changes: 44 additions & 24 deletions cub/cub/device/device_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,9 @@ struct DeviceScan
//! **[inferred]** Functor type having member
//! `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys
//!
//! @tparam NumItemsT
//! **[inferred]** An integral type representing the number of input elements
//!
//! @param[in] d_temp_storage
//! Device-accessible allocation of temporary storage. When `nullptr`, the
//! required allocation size is written to `temp_storage_bytes` and no work is done.
Expand Down Expand Up @@ -1558,21 +1561,22 @@ struct DeviceScan
template <typename KeysInputIteratorT,
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_RUNTIME_FUNCTION static cudaError_t ExclusiveSumByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op = EqualityOpT(),
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceScan::ExclusiveSumByKey");

// Signed integer type for global offsets
using OffsetT = int;
// Unsigned integer type for global offsets
using OffsetT = detail::choose_offset_t<NumItemsT>;
using InitT = cub::detail::value_t<ValuesInputIteratorT>;

// Initial value
Expand Down Expand Up @@ -1601,14 +1605,15 @@ struct DeviceScan
template <typename KeysInputIteratorT,
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION static cudaError_t ExclusiveSumByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op,
cudaStream_t stream,
bool debug_synchronous)
Expand Down Expand Up @@ -1721,6 +1726,9 @@ struct DeviceScan
//! **[inferred]** Functor type having member
//! `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys
//!
//! @tparam NumItemsT
//! **[inferred]** An integral type representing the number of input elements
//!
//! @param[in] d_temp_storage
//! Device-accessible allocation of temporary storage. When `nullptr`, the
//! required allocation size is written to `temp_storage_bytes` and no work is done.
Expand Down Expand Up @@ -1761,7 +1769,8 @@ struct DeviceScan
typename ValuesOutputIteratorT,
typename ScanOpT,
typename InitValueT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_RUNTIME_FUNCTION static cudaError_t ExclusiveScanByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
Expand All @@ -1770,14 +1779,14 @@ struct DeviceScan
ValuesOutputIteratorT d_values_out,
ScanOpT scan_op,
InitValueT init_value,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op = EqualityOpT(),
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceScan::ExclusiveScanByKey");

// Signed integer type for global offsets
using OffsetT = int;
// Unsigned integer type for global offsets
using OffsetT = detail::choose_offset_t<NumItemsT>;
elstehle marked this conversation as resolved.
Show resolved Hide resolved

return DispatchScanByKey<
KeysInputIteratorT,
Expand All @@ -1804,7 +1813,8 @@ struct DeviceScan
typename ValuesOutputIteratorT,
typename ScanOpT,
typename InitValueT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION static cudaError_t ExclusiveScanByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
Expand All @@ -1813,7 +1823,7 @@ struct DeviceScan
ValuesOutputIteratorT d_values_out,
ScanOpT scan_op,
InitValueT init_value,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op,
cudaStream_t stream,
bool debug_synchronous)
Expand Down Expand Up @@ -1904,6 +1914,9 @@ struct DeviceScan
//! **[inferred]** Functor type having member
//! `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys
//!
//! @tparam NumItemsT
//! **[inferred]** An integral type representing the number of input elements
//!
//! @param[in] d_temp_storage
//! Device-accessible allocation of temporary storage.
//! When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done.
Expand Down Expand Up @@ -1934,21 +1947,22 @@ struct DeviceScan
template <typename KeysInputIteratorT,
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_RUNTIME_FUNCTION static cudaError_t InclusiveSumByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op = EqualityOpT(),
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceScan::InclusiveSumByKey");

// Signed integer type for global offsets
using OffsetT = int;
// Unsigned integer type for global offsets
using OffsetT = detail::choose_offset_t<NumItemsT>;

return DispatchScanByKey<
KeysInputIteratorT,
Expand All @@ -1973,14 +1987,15 @@ struct DeviceScan
template <typename KeysInputIteratorT,
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION static cudaError_t InclusiveSumByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op,
cudaStream_t stream,
bool debug_synchronous)
Expand Down Expand Up @@ -2084,6 +2099,9 @@ struct DeviceScan
//! **[inferred]** Functor type having member
//! `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys
//!
//! @tparam NumItemsT
//! **[inferred]** An integral type representing the number of input elements
//!
//! @param[in] d_temp_storage
//! Device-accessible allocation of temporary storage.
//! When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done.
Expand Down Expand Up @@ -2118,22 +2136,23 @@ struct DeviceScan
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename ScanOpT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_RUNTIME_FUNCTION static cudaError_t InclusiveScanByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
ScanOpT scan_op,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op = EqualityOpT(),
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceScan::InclusiveScanByKey");

// Signed integer type for global offsets
using OffsetT = int;
// Unsigned integer type for global offsets
using OffsetT = detail::choose_offset_t<NumItemsT>;

return DispatchScanByKey<
KeysInputIteratorT,
Expand All @@ -2159,15 +2178,16 @@ struct DeviceScan
typename ValuesInputIteratorT,
typename ValuesOutputIteratorT,
typename ScanOpT,
typename EqualityOpT = Equality>
typename EqualityOpT = Equality,
typename NumItemsT = std::uint32_t>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION static cudaError_t InclusiveScanByKey(
void* d_temp_storage,
size_t& temp_storage_bytes,
KeysInputIteratorT d_keys_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
ScanOpT scan_op,
int num_items,
NumItemsT num_items,
EqualityOpT equality_op,
cudaStream_t stream,
bool debug_synchronous)
Expand Down
Loading
Loading