Skip to content

Commit

Permalink
Refactor thrust::[stable_]partition[_copy] to use `cub::DeviceParti…
Browse files Browse the repository at this point in the history
…tion` (#1435)

* extends agentif to implement partitioning to two iterators

* refactors thrust::partition to use cub algorithms

* adds tests for large number of items to thrust::partition

* fixes negative input ranges
  • Loading branch information
elstehle authored Feb 28, 2024
1 parent d1c8a50 commit 1d78f0d
Show file tree
Hide file tree
Showing 4 changed files with 498 additions and 1,020 deletions.
223 changes: 145 additions & 78 deletions cub/cub/agent/agent_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include <cub/block/block_store.cuh>
#include <cub/grid/grid_queue.cuh>
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/util_type.cuh>

#include <cuda/std/type_traits>

Expand Down Expand Up @@ -121,6 +122,19 @@ struct AgentSelectIfPolicy
* Thread block abstractions
******************************************************************************/

namespace detail
{
/**
 * @brief Bundles two output iterators so that the two halves of a partition can be written to
 * distinct destinations instead of a single output range.
 *
 * An instance of this template may be passed wherever the agent expects its output iterator
 * (`OutputIteratorWrapperT`). The `ScatterPartitionsToGlobal` overload that takes a
 * `partition_distinct_output_t` writes selected items through `selected_it` and rejected items
 * through `rejected_it`; both are indexed with `operator[]`.
 *
 * @tparam SelectedOutputItT
 *   Type of the iterator receiving selected items
 *
 * @tparam RejectedOutputItT
 *   Type of the iterator receiving rejected items
 */
template <typename SelectedOutputItT, typename RejectedOutputItT>
struct partition_distinct_output_t
{
// Aliases exposing the two wrapped iterator types
using selected_iterator_t = SelectedOutputItT;
using rejected_iterator_t = RejectedOutputItT;

// Destination for selected items
selected_iterator_t selected_it;
// Destination for rejected items
rejected_iterator_t rejected_it;
};
} // namespace detail

/**
* @brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in
* device-wide selection
Expand All @@ -139,8 +153,8 @@ struct AgentSelectIfPolicy
* Random-access input iterator type for selections (NullType* if a selection functor or
* discontinuity flagging is to be used for selection)
*
* @tparam SelectedOutputIteratorT
* Random-access output iterator type for selection_flags items
* @tparam OutputIteratorWrapperT
* Either a random-access iterator or an instance of the `partition_distinct_output_t` template.
*
* @tparam SelectOpT
* Selection operator type (NullType if selections or discontinuity flagging is to be used for
Expand All @@ -159,7 +173,7 @@ struct AgentSelectIfPolicy
template <typename AgentSelectIfPolicyT,
typename InputIteratorT,
typename FlagsInputIteratorT,
typename SelectedOutputIteratorT,
typename OutputIteratorWrapperT,
typename SelectOpT,
typename EqualityOpT,
typename OffsetT,
Expand Down Expand Up @@ -279,13 +293,13 @@ struct AgentSelectIf
// Per-thread fields
//---------------------------------------------------------------------

_TempStorage &temp_storage; ///< Reference to temp_storage
WrappedInputIteratorT d_in; ///< Input items
SelectedOutputIteratorT d_selected_out; ///< Unique output items
WrappedFlagsInputIteratorT d_flags_in; ///< Input selection flags (if applicable)
_TempStorage& temp_storage; ///< Reference to temp_storage
WrappedInputIteratorT d_in; ///< Input items
OutputIteratorWrapperT d_selected_out; ///< Unique output items
WrappedFlagsInputIteratorT d_flags_in; ///< Input selection flags (if applicable)
InequalityWrapper<EqualityOpT> inequality_op; ///< T inequality operator
SelectOpT select_op; ///< Selection operator
OffsetT num_items; ///< Total number of input items
SelectOpT select_op; ///< Selection operator
OffsetT num_items; ///< Total number of input items

//---------------------------------------------------------------------
// Constructor
Expand Down Expand Up @@ -316,7 +330,7 @@ struct AgentSelectIf
_CCCL_DEVICE _CCCL_FORCEINLINE AgentSelectIf(TempStorage &temp_storage,
InputIteratorT d_in,
FlagsInputIteratorT d_flags_in,
SelectedOutputIteratorT d_selected_out,
OutputIteratorWrapperT d_selected_out,
SelectOpT select_op,
EqualityOpT equality_op,
OffsetT num_items)
Expand Down Expand Up @@ -477,10 +491,10 @@ struct AgentSelectIf
//---------------------------------------------------------------------

/**
* Scatter flagged items to output offsets (specialized for direct scattering)
* Scatter flagged items to output offsets (specialized for direct scattering).
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterDirect(
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterSelectedDirect(
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
Expand Down Expand Up @@ -519,19 +533,17 @@ struct AgentSelectIf
* Marker type indicating whether to keep rejected items in the second partition
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterTwoPhase(InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int /*num_tile_items*/,
int num_tile_selections,
OffsetT num_selections_prefix,
OffsetT /*num_rejected_prefix*/,
Int2Type<false> /*is_keep_rejects*/)
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterSelectedTwoPhase(
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_selections,
OffsetT num_selections_prefix)
{
CTA_SYNC();

// Compact and scatter items
#pragma unroll
// Compact and scatter items
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix;
Expand All @@ -550,7 +562,52 @@ struct AgentSelectIf
}

/**
 * @brief Scatter flagged items. Overload for the selection algorithm, which simply discards
 * rejected items (chosen via the `Int2Type<false>` tag).
 *
 * Dispatches to `ScatterSelectedTwoPhase` when two-phase scattering is enabled and the tile holds
 * more selections than there are threads in the block; otherwise dispatches to
 * `ScatterSelectedDirect`.
 *
 * @param items
 *   Per-thread items of this tile (forwarded to the chosen scatter routine)
 *
 * @param selection_flags
 *   Per-thread selection flags (forwarded to the chosen scatter routine)
 *
 * @param selection_indices
 *   Per-thread selection indices (forwarded to the chosen scatter routine)
 *
 * @param num_tile_items
 *   Number of valid items in this tile (not used by this overload)
 *
 * @param num_tile_selections
 *   Number of selections in this tile
 *
 * @param num_selections_prefix
 *   Total number of selections prior to this tile
 *
 * @param num_rejected_prefix
 *   Total number of rejections prior to this tile (not used by this overload)
 *
 * @param num_selections
 *   Total number of selections including this tile
 */
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
_CCCL_DEVICE _CCCL_FORCEINLINE void Scatter(
  InputT (&items)[ITEMS_PER_THREAD],
  OffsetT (&selection_flags)[ITEMS_PER_THREAD],
  OffsetT (&selection_indices)[ITEMS_PER_THREAD],
  int num_tile_items,
  int num_tile_selections,
  OffsetT num_selections_prefix,
  OffsetT num_rejected_prefix,
  OffsetT num_selections,
  Int2Type<false> /*is_keep_rejects*/)
{
  // Two-phase scattering only pays off when it is enabled and, on average, every thread of the
  // block contributes more than one selected item.
  const bool use_two_phase = TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS);

  if (!use_two_phase)
  {
    // Each thread writes its own selected items straight to the output
    ScatterSelectedDirect<IS_LAST_TILE, IS_FIRST_TILE>(
      items, selection_flags, selection_indices, num_selections);
    return;
  }

  // Compact the tile's selected items first, then write them out
  ScatterSelectedTwoPhase<IS_LAST_TILE, IS_FIRST_TILE>(
    items, selection_flags, selection_indices, num_tile_selections, num_selections_prefix);
}

/**
* @brief Scatter flagged items. Specialized for partitioning algorithm that writes rejected items to a second
* partition.
*
* @param num_tile_items
* Number of valid items in this tile
Expand All @@ -568,13 +625,14 @@ struct AgentSelectIf
* Marker type indicating whether to keep rejected items in the second partition
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterTwoPhase(InputT (&items)[ITEMS_PER_THREAD],
_CCCL_DEVICE _CCCL_FORCEINLINE void Scatter(InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_items,
int num_tile_selections,
OffsetT num_selections_prefix,
OffsetT num_rejected_prefix,
OffsetT num_selections,
Int2Type<true> /*is_keep_rejects*/)
{
CTA_SYNC();
Expand All @@ -595,76 +653,83 @@ struct AgentSelectIf
temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM];
}

// Ensure all threads finished scattering to shared memory
CTA_SYNC();

// Gather items from shared memory and scatter to global
ScatterPartitionsToGlobal<IS_LAST_TILE, IS_FIRST_TILE>(
num_tile_items, tile_num_rejections, num_selections_prefix, num_rejected_prefix, d_selected_out);
}

/**
 * @brief Second phase of scattering partitioned items to global memory. Specialized for
 * partitioning to two distinct partitions.
 *
 * Reads this tile's items back from the shared exchange buffer (`temp_storage.raw_exchange`).
 * Slots `[0, tile_num_rejections)` of that buffer are treated as rejected items and the remaining
 * slots as selected items. The caller is expected to have issued a block-wide barrier after
 * filling the buffer.
 *
 * Unlike the single-iterator overload, rejected items are written front-to-back, starting at
 * `num_rejected_prefix` within the rejected output.
 *
 * @param num_tile_items
 *   Number of valid items in this tile (only consulted for the last tile)
 *
 * @param tile_num_rejections
 *   Number of rejected items in this tile
 *
 * @param num_selections_prefix
 *   Total number of selections prior to this tile
 *
 * @param num_rejected_prefix
 *   Total number of rejections prior to this tile
 *
 * @param partitioned_out_it_wrapper
 *   Wrapper holding the output iterator for selected items (`selected_it`) and the output iterator
 *   for rejected items (`rejected_it`)
 */
template <bool IS_LAST_TILE, bool IS_FIRST_TILE, typename SelectedItT, typename RejectedItT>
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterPartitionsToGlobal(
  int num_tile_items,
  int tile_num_rejections,
  OffsetT num_selections_prefix,
  OffsetT num_rejected_prefix,
  detail::partition_distinct_output_t<SelectedItT, RejectedItT> partitioned_out_it_wrapper)
{
#pragma unroll
  for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
  {
    // Slot of the exchange buffer handled by this thread in this iteration
    int item_idx      = (ITEM * BLOCK_THREADS) + threadIdx.x;
    int rejection_idx = item_idx;
    int selection_idx = item_idx - tile_num_rejections;

    // Both partitions are written in ascending order, each into its own output range
    OffsetT scatter_offset =
      (item_idx < tile_num_rejections)
        ? num_rejected_prefix + rejection_idx
        : num_selections_prefix + selection_idx;

    InputT item = temp_storage.raw_exchange.Alias()[item_idx];

    // Only the last tile can be partially filled, so only it needs the bounds check
    if (!IS_LAST_TILE || (item_idx < num_tile_items))
    {
      if (item_idx >= tile_num_rejections)
      {
        partitioned_out_it_wrapper.selected_it[scatter_offset] = item;
      }
      else
      {
        partitioned_out_it_wrapper.rejected_it[scatter_offset] = item;
      }
    }
  }
}

/**
 * @brief Second phase of scattering partitioned items to global memory. Specialized for
 * partitioning to a single iterator, where selected items are written in order from the beginning
 * of the iterator and rejected items are written from the iterator's end backwards.
 *
 * Reads this tile's items back from the shared exchange buffer (`temp_storage.raw_exchange`).
 * Slots `[0, tile_num_rejections)` of that buffer are treated as rejected items and the remaining
 * slots as selected items.
 *
 * @param num_tile_items
 *   Number of valid items in this tile (only consulted for the last tile)
 *
 * @param tile_num_rejections
 *   Number of rejected items in this tile
 *
 * @param num_selections_prefix
 *   Total number of selections prior to this tile
 *
 * @param num_rejected_prefix
 *   Total number of rejections prior to this tile
 *
 * @param partitioned_out_it
 *   Output iterator that receives both partitions
 */
template <bool IS_LAST_TILE, bool IS_FIRST_TILE, typename PartitionedOutputItT>
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterPartitionsToGlobal(
  int num_tile_items,
  int tile_num_rejections,
  OffsetT num_selections_prefix,
  OffsetT num_rejected_prefix,
  PartitionedOutputItT partitioned_out_it)
{
#pragma unroll
  for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
  {
    // Slot of the exchange buffer handled by this thread in this iteration
    int item_idx      = (ITEM * BLOCK_THREADS) + threadIdx.x;
    int rejection_idx = item_idx;
    int selection_idx = item_idx - tile_num_rejections;

    // Rejected items fill the output from index (num_items - 1) downwards; selected items fill it
    // from the front upwards.
    OffsetT scatter_offset =
      (item_idx < tile_num_rejections)
        ? num_items - num_rejected_prefix - rejection_idx - 1
        : num_selections_prefix + selection_idx;

    InputT item = temp_storage.raw_exchange.Alias()[item_idx];

    // Only the last tile can be partially filled, so only it needs the bounds check
    if (!IS_LAST_TILE || (item_idx < num_tile_items))
    {
      partitioned_out_it[scatter_offset] = item;
    }
  }
}

Expand Down Expand Up @@ -736,7 +801,8 @@ struct AgentSelectIf
num_tile_selections,
0,
0,
num_tile_selections);
num_tile_selections,
cub::Int2Type<KEEP_REJECTS>{});

return num_tile_selections;
}
Expand Down Expand Up @@ -791,7 +857,7 @@ struct AgentSelectIf
OffsetT num_tile_selections = prefix_op.GetBlockAggregate();
OffsetT num_selections = prefix_op.GetInclusivePrefix();
OffsetT num_selections_prefix = prefix_op.GetExclusivePrefix();
OffsetT num_rejected_prefix = (tile_idx * TILE_ITEMS) - num_selections_prefix;
OffsetT num_rejected_prefix = tile_offset - num_selections_prefix;

// Discount any out-of-bounds selections
if (IS_LAST_TILE)
Expand All @@ -810,7 +876,8 @@ struct AgentSelectIf
num_tile_selections,
num_selections_prefix,
num_rejected_prefix,
num_selections);
num_selections,
cub::Int2Type<KEEP_REJECTS>{});

return num_selections;
}
Expand Down
Loading

0 comments on commit 1d78f0d

Please sign in to comment.