Skip to content

Commit

Permalink
Unify policy hub handling and update documentation (#3142)
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber authored Dec 13, 2024
1 parent d5ee178 commit f299056
Show file tree
Hide file tree
Showing 19 changed files with 153 additions and 213 deletions.
2 changes: 1 addition & 1 deletion cub/cub/device/device_partition.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ private:
typename OffsetT,
typename BeginOffsetIteratorT,
typename EndOffsetIteratorT,
typename SelectedPolicy>
typename PolicyHub>
friend class DispatchSegmentedSort;

// Internal version without NVTX range
Expand Down
12 changes: 4 additions & 8 deletions cub/cub/device/dispatch/dispatch_adjacent_difference.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ template <typename InputIteratorT,
typename OffsetT,
bool MayAlias,
bool ReadLeft,
typename SelectedPolicy = detail::adjacent_difference::policy_hub<InputIteratorT, MayAlias>>
struct DispatchAdjacentDifference : public SelectedPolicy
typename PolicyHub = detail::adjacent_difference::policy_hub<InputIteratorT, MayAlias>>
struct DispatchAdjacentDifference
{
using InputT = typename std::iterator_traits<InputIteratorT>::value_type;

Expand Down Expand Up @@ -167,8 +167,6 @@ struct DispatchAdjacentDifference : public SelectedPolicy
{
using AdjacentDifferencePolicyT = typename ActivePolicyT::AdjacentDifferencePolicy;

using MaxPolicyT = typename DispatchAdjacentDifference::MaxPolicy;

cudaError error = cudaSuccess;

do
Expand Down Expand Up @@ -256,7 +254,7 @@ struct DispatchAdjacentDifference : public SelectedPolicy
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
num_tiles, AdjacentDifferencePolicyT::BLOCK_THREADS, 0, stream)
.doit(DeviceAdjacentDifferenceDifferenceKernel<
MaxPolicyT,
typename PolicyHub::MaxPolicy,
InputIteratorT,
OutputIteratorT,
DifferenceOpT,
Expand Down Expand Up @@ -297,8 +295,6 @@ struct DispatchAdjacentDifference : public SelectedPolicy
DifferenceOpT difference_op,
cudaStream_t stream)
{
using MaxPolicyT = typename DispatchAdjacentDifference::MaxPolicy;

cudaError error = cudaSuccess;
do
{
Expand All @@ -315,7 +311,7 @@ struct DispatchAdjacentDifference : public SelectedPolicy
d_temp_storage, temp_storage_bytes, d_input, d_output, num_items, difference_op, stream);

// Dispatch to chained policy
error = CubDebug(MaxPolicyT::Invoke(ptx_version, dispatch));
error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch));
if (cudaSuccess != error)
{
break;
Expand Down
16 changes: 6 additions & 10 deletions cub/cub/device/dispatch/dispatch_batch_memcpy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@ template <typename InputBufferIt,
typename BufferSizeIteratorT,
typename BufferOffsetT,
typename BlockOffsetT,
typename SelectedPolicy = batch_memcpy::policy_hub<BufferOffsetT, BlockOffsetT>,
bool IsMemcpy = true>
struct DispatchBatchMemcpy : SelectedPolicy
typename PolicyHub = batch_memcpy::policy_hub<BufferOffsetT, BlockOffsetT>,
bool IsMemcpy = true>
struct DispatchBatchMemcpy
{
//------------------------------------------------------------------------------
// TYPE ALIASES
Expand Down Expand Up @@ -345,8 +345,6 @@ struct DispatchBatchMemcpy : SelectedPolicy
template <typename ActivePolicyT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke()
{
using MaxPolicyT = typename DispatchBatchMemcpy::MaxPolicy;

// Single-pass prefix scan tile states for the prefix-sum over the number of block-level buffers
using BLevBufferOffsetTileState = cub::ScanTileState<BufferOffsetT>;

Expand Down Expand Up @@ -466,7 +464,7 @@ struct DispatchBatchMemcpy : SelectedPolicy
auto init_scan_states_kernel =
InitTileStateKernel<BLevBufferOffsetTileState, BLevBlockOffsetTileState, BlockOffsetT>;
auto batch_memcpy_non_blev_kernel = BatchMemcpyKernel<
MaxPolicyT,
typename PolicyHub::MaxPolicy,
InputBufferIt,
OutputBufferIt,
BufferSizeIteratorT,
Expand All @@ -481,7 +479,7 @@ struct DispatchBatchMemcpy : SelectedPolicy
IsMemcpy>;

auto multi_block_memcpy_kernel = MultiBlockBatchMemcpyKernel<
MaxPolicyT,
typename PolicyHub::MaxPolicy,
BufferOffsetT,
BlevBufferSrcsOutItT,
BlevBufferDstsOutItT,
Expand Down Expand Up @@ -651,8 +649,6 @@ struct DispatchBatchMemcpy : SelectedPolicy
BufferOffsetT num_buffers,
cudaStream_t stream)
{
using MaxPolicyT = typename DispatchBatchMemcpy::MaxPolicy;

cudaError_t error = cudaSuccess;

// Get PTX version
Expand All @@ -668,7 +664,7 @@ struct DispatchBatchMemcpy : SelectedPolicy
d_temp_storage, temp_storage_bytes, input_buffer_it, output_buffer_it, buffer_sizes, num_buffers, stream);

// Dispatch to chained policy
error = CubDebug(MaxPolicyT::Invoke(ptx_version, dispatch));
error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch));
if (cudaSuccess != error)
{
return error;
Expand Down
12 changes: 4 additions & 8 deletions cub/cub/device/dispatch/dispatch_for.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace for_each

// The dispatch layer is in the detail namespace until we figure out the tuning API
template <class OffsetT, class OpT, class PolicyHubT = policy_hub_t>
struct dispatch_t : PolicyHubT
struct dispatch_t
{
OffsetT num_items;
OpT op;
Expand All @@ -75,7 +75,7 @@ struct dispatch_t : PolicyHubT
CUB_RUNTIME_FUNCTION
_CCCL_FORCEINLINE cudaError_t Invoke(::cuda::std::false_type /* block size is not known at compile time */)
{
using max_policy_t = typename dispatch_t::MaxPolicy;
using max_policy_t = typename PolicyHubT::MaxPolicy;

if (num_items == 0)
{
Expand Down Expand Up @@ -132,8 +132,6 @@ struct dispatch_t : PolicyHubT
CUB_RUNTIME_FUNCTION
_CCCL_FORCEINLINE cudaError_t Invoke(::cuda::std::true_type /* block size is known at compile time */)
{
using max_policy_t = typename dispatch_t::MaxPolicy;

if (num_items == 0)
{
return cudaSuccess;
Expand All @@ -157,7 +155,7 @@ struct dispatch_t : PolicyHubT

error = THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
static_cast<unsigned int>(num_tiles), static_cast<unsigned int>(block_threads), 0, stream)
.doit(detail::for_each::static_kernel<max_policy_t, OffsetT, OpT>, num_items, op);
.doit(detail::for_each::static_kernel<typename PolicyHubT::MaxPolicy, OffsetT, OpT>, num_items, op);
error = CubDebug(error);
if (cudaSuccess != error)
{
Expand All @@ -182,8 +180,6 @@ struct dispatch_t : PolicyHubT

CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t dispatch(OffsetT num_items, OpT op, cudaStream_t stream)
{
using max_policy_t = typename dispatch_t::MaxPolicy;

int ptx_version = 0;
cudaError_t error = CubDebug(PtxVersion(ptx_version));
if (cudaSuccess != error)
Expand All @@ -193,7 +189,7 @@ struct dispatch_t : PolicyHubT

dispatch_t dispatch(num_items, op, stream);

error = CubDebug(max_policy_t::Invoke(ptx_version, dispatch));
error = CubDebug(PolicyHubT::MaxPolicy::Invoke(ptx_version, dispatch));

return error;
}
Expand Down
7 changes: 3 additions & 4 deletions cub/cub/device/dispatch/dispatch_for_each_in_extents.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ namespace for_each_in_extents

// The dispatch layer is in the detail namespace until we figure out the tuning API
template <typename ExtentsType, typename OpType, typename PolicyHubT = cub::detail::for_each::policy_hub_t>
class dispatch_t : PolicyHubT
class dispatch_t
{
using index_type = typename ExtentsType::index_type;
using unsigned_index_type = ::cuda::std::make_unsigned_t<index_type>;
using max_policy_t = typename dispatch_t::MaxPolicy;
using max_policy_t = typename PolicyHubT::MaxPolicy;
// workaround for nvcc 11.1 bug related to deduction guides, vvv
using array_type = ::cuda::std::array<fast_div_mod<index_type>, ExtentsType::rank()>;

Expand Down Expand Up @@ -190,8 +190,7 @@ public:
_CCCL_NODISCARD CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t
dispatch(const ExtentsType& ext, const OpType& op, cudaStream_t stream)
{
using max_policy_t = typename dispatch_t::MaxPolicy;
int ptx_version = 0;
int ptx_version = 0;
_CUB_RETURN_IF_ERROR(CubDebug(PtxVersion(ptx_version)))
dispatch_t dispatch(ext, op, stream);
return CubDebug(max_policy_t::Invoke(ptx_version, dispatch));
Expand Down
14 changes: 7 additions & 7 deletions cub/cub/device/dispatch/dispatch_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ struct dispatch_histogram
* @tparam OffsetT
* Signed integer type for global offsets
*
* @tparam SelectedPolicy
* @tparam PolicyHub
* Implementation detail; do not specify directly. Requirements on the
* content of this type are subject to breaking changes.
*/
Expand All @@ -555,9 +555,9 @@ template <int NUM_CHANNELS,
typename CounterT,
typename LevelT,
typename OffsetT,
typename SelectedPolicy =
typename PolicyHub =
detail::histogram::policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS>>
struct DispatchHistogram : SelectedPolicy
struct DispatchHistogram
{
static_assert(NUM_CHANNELS <= 4, "Histograms only support up to 4 channels");
static_assert(NUM_ACTIVE_CHANNELS <= NUM_CHANNELS,
Expand Down Expand Up @@ -921,7 +921,7 @@ public:
cudaStream_t stream,
Int2Type<false> /*is_byte_sample*/)
{
using MaxPolicyT = typename SelectedPolicy::MaxPolicy;
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;

do
Expand Down Expand Up @@ -1125,7 +1125,7 @@ public:
cudaStream_t stream,
Int2Type<true> /*is_byte_sample*/)
{
using MaxPolicyT = typename SelectedPolicy::MaxPolicy;
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;

do
Expand Down Expand Up @@ -1293,7 +1293,7 @@ public:
cudaStream_t stream,
Int2Type<false> /*is_byte_sample*/)
{
using MaxPolicyT = typename SelectedPolicy::MaxPolicy;
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;

do
Expand Down Expand Up @@ -1514,7 +1514,7 @@ public:
cudaStream_t stream,
Int2Type<true> /*is_byte_sample*/)
{
using MaxPolicyT = typename SelectedPolicy::MaxPolicy;
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;

do
Expand Down
14 changes: 5 additions & 9 deletions cub/cub/device/dispatch/dispatch_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,8 @@ template <typename KeyInputIteratorT,
typename ValueIteratorT,
typename OffsetT,
typename CompareOpT,
typename SelectedPolicy = detail::merge_sort::policy_hub<KeyIteratorT>>
struct DispatchMergeSort : SelectedPolicy
typename PolicyHub = detail::merge_sort::policy_hub<KeyIteratorT>>
struct DispatchMergeSort
{
using KeyT = cub::detail::value_t<KeyIteratorT>;
using ValueT = cub::detail::value_t<ValueIteratorT>;
Expand Down Expand Up @@ -447,8 +447,6 @@ struct DispatchMergeSort : SelectedPolicy
using BlockSortVSmemHelperT = cub::detail::vsmem_helper_impl<typename merge_sort_helper_t::block_sort_agent_t>;
using MergeAgentVSmemHelperT = cub::detail::vsmem_helper_impl<typename merge_sort_helper_t::merge_agent_t>;

using MaxPolicyT = typename DispatchMergeSort::MaxPolicy;

cudaError error = cudaSuccess;

if (num_items == 0)
Expand Down Expand Up @@ -517,7 +515,7 @@ struct DispatchMergeSort : SelectedPolicy
static_cast<int>(num_tiles), merge_sort_helper_t::policy_t::BLOCK_THREADS, 0, stream)
.doit(
DeviceMergeSortBlockSortKernel<
MaxPolicyT,
typename PolicyHub::MaxPolicy,
KeyInputIteratorT,
ValueInputIteratorT,
KeyIteratorT,
Expand Down Expand Up @@ -602,7 +600,7 @@ struct DispatchMergeSort : SelectedPolicy
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
static_cast<int>(num_tiles), static_cast<int>(merge_sort_helper_t::policy_t::BLOCK_THREADS), 0, stream, true)
.doit(
DeviceMergeSortMergeKernel<MaxPolicyT,
DeviceMergeSortMergeKernel<typename PolicyHub::MaxPolicy,
KeyInputIteratorT,
ValueInputIteratorT,
KeyIteratorT,
Expand Down Expand Up @@ -651,8 +649,6 @@ struct DispatchMergeSort : SelectedPolicy
CompareOpT compare_op,
cudaStream_t stream)
{
using MaxPolicyT = typename DispatchMergeSort::MaxPolicy;

cudaError error = cudaSuccess;
do
{
Expand All @@ -678,7 +674,7 @@ struct DispatchMergeSort : SelectedPolicy
ptx_version);

// Dispatch to chained policy
error = CubDebug(MaxPolicyT::Invoke(ptx_version, dispatch));
error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch));
if (cudaSuccess != error)
{
break;
Expand Down
Loading

0 comments on commit f299056

Please sign in to comment.