Skip to content

Commit

Permalink
Internalize cub::KernelConfig (#3683) (#3688)
Browse files: browse the repository at this point in the history
(cherry picked from commit 32d05c7)

Co-authored-by: Federico Busato <[email protected]>
  • Loading branch information
github-actions[bot] and fbusato authored Feb 5, 2025
1 parent d0d196c commit 95919f9
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 18 deletions.
8 changes: 4 additions & 4 deletions cub/cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1184,11 +1184,11 @@ struct DispatchRadixSort
struct PassConfig
{
UpsweepKernelT upsweep_kernel;
KernelConfig upsweep_config;
detail::KernelConfig upsweep_config;
ScanKernelT scan_kernel;
KernelConfig scan_config;
detail::KernelConfig scan_config;
DownsweepKernelT downsweep_kernel;
KernelConfig downsweep_config;
detail::KernelConfig downsweep_config;
int radix_bits;
int radix_digits;
int max_downsweep_grid_size;
Expand Down Expand Up @@ -2135,7 +2135,7 @@ struct DispatchSegmentedRadixSort
struct PassConfig
{
SegmentedKernelT segmented_kernel;
KernelConfig segmented_config;
detail::KernelConfig segmented_config;
int radix_bits;
int radix_digits;

Expand Down
6 changes: 4 additions & 2 deletions cub/cub/device/dispatch/dispatch_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,8 @@ struct DispatchReduce
}

// Init regular kernel configuration
KernelConfig reduce_config;
detail::KernelConfig reduce_config;
(void) reduce_config;
error = CubDebug(reduce_config.Init(reduce_kernel, active_policy.Reduce(), launcher_factory));
if (cudaSuccess != error)
{
Expand Down Expand Up @@ -949,7 +950,8 @@ struct DispatchSegmentedReduce
}

// Init kernel configuration
KernelConfig segmented_reduce_config;
detail::KernelConfig segmented_reduce_config;
(void) &segmented_reduce_config;
error =
CubDebug(segmented_reduce_config.Init<typename ActivePolicyT::SegmentedReducePolicy>(segmented_reduce_kernel));
if (cudaSuccess != error)
Expand Down
21 changes: 9 additions & 12 deletions cub/cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -656,25 +656,18 @@ CUB_RUNTIME_FUNCTION detail::PolicyWrapper<PolicyT> MakePolicyWrapper(PolicyT po

namespace detail
{

struct TripleChevronFactory;
}

/**
* Kernel dispatch configuration
*/
struct KernelConfig
{
int block_threads;
int items_per_thread;
int tile_size;
int sm_occupancy;

CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE KernelConfig()
: block_threads(0)
, items_per_thread(0)
, tile_size(0)
, sm_occupancy(0)
{}
int block_threads{0};
int items_per_thread{0};
int tile_size{0};
int sm_occupancy{0};

template <typename AgentPolicyT, typename KernelPtrT, typename LauncherFactory = detail::TripleChevronFactory>
CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t
Expand All @@ -687,6 +680,10 @@ struct KernelConfig
}
};

} // namespace detail

using KernelConfig CCCL_DEPRECATED_BECAUSE("This class is considered an implementation detail") = detail::KernelConfig;

/// Helper for dispatching into a policy chain
template <int PolicyPtxVersion, typename PolicyT, typename PrevPolicyT>
struct ChainedPolicy
Expand Down

0 comments on commit 95919f9

Please sign in to comment.