Skip to content

Commit

Permalink
Backport to 2.8: B200 tunings for histogram (#3728)
Browse files Browse the repository at this point in the history
* Add b200 tunings for histogram (#3616)

Co-authored-by: Giannis Gonidelis <[email protected]>

* Fix SM100 histogram tunings (#3691)

The tuning data member names did not match the ones used when selecting
tunings, so all SM100 tunings were SFINAE-ed out.

Also drop tunings with no benefit.

---------

Co-authored-by: Giannis Gonidelis <[email protected]>
  • Loading branch information
bernhardmgruber and gonidelis authored Feb 7, 2025
1 parent c8bda1a commit 5571cd8
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 12 deletions.
42 changes: 32 additions & 10 deletions cub/cub/device/dispatch/dispatch_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

#include <cub/config.cuh>

#include <cuda/std/__type_traits/is_void.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
Expand Down Expand Up @@ -554,8 +556,7 @@ template <int NUM_CHANNELS,
typename CounterT,
typename LevelT,
typename OffsetT,
typename PolicyHub =
detail::histogram::policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS>>
typename PolicyHub = void> // if user passes a custom Policy this should not be void
struct DispatchHistogram
{
static_assert(NUM_CHANNELS <= 4, "Histograms only support up to 4 channels");
Expand Down Expand Up @@ -920,8 +921,14 @@ public:
cudaStream_t stream,
::cuda::std::false_type /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
// Should we call DispatchHistogram<....., PolicyHub=void> in DeviceHistogram?
static constexpr bool isEven = 0;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1124,8 +1131,13 @@ public:
cudaStream_t stream,
::cuda::std::true_type /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 0;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1292,8 +1304,13 @@ public:
cudaStream_t stream,
::cuda::std::false_type /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 1;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1513,8 +1530,13 @@ public:
cudaStream_t stream,
::cuda::std::true_type /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 1;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down
74 changes: 72 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ enum class sample_size
{
_1,
_2,
_4,
_8,
unknown
};

Expand Down Expand Up @@ -125,7 +127,52 @@ struct sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sampl
static constexpr bool work_stealing = false;
};

template <class SampleT, class CounterT, int NumChannels, int NumActiveChannels>
// Primary template for the SM100 histogram tunings. It is only declared here (no definition in this
// section); concrete tuning values are supplied by partial specializations.
//
// Template parameters:
//   IsEven            - selects between the "even" (true) and "range" (false) tuning sets
//   SampleT           - sample type; also feeds the two defaulted classification parameters below
//   NumChannels       - number of channels
//   NumActiveChannels - number of active channels
//   CounterSize       - counter size classification
//   PrimitiveSample   - defaults to is_primitive_sample<SampleT>()
//   SampleSize        - defaults to classify_sample_size<SampleT>()
//
// The two defaulted parameters let users name a tuning with five arguments while specializations
// still match on the derived classification of SampleT.
template <bool IsEven,
class SampleT,
int NumChannels,
int NumActiveChannels,
counter_size CounterSize,
primitive_sample PrimitiveSample = is_primitive_sample<SampleT>(),
sample_size SampleSize = classify_sample_size<SampleT>()>
struct sm100_tuning;

// SM100 tuning for the "even" variant (IsEven == true): one channel, one active channel,
// counter_size::_4, primitive samples classified as sample_size::_1.
template <class SampleT>
struct sm100_tuning<true, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
  // Encoded tuning this specialization was taken from (kept verbatim):
  // ipt_12.tpb_928.rle_0.ws_0.mem_1.ld_2.laid_0.vec_2 1.033332 0.940517 1.031835 1.195876

  // Block shape
  static constexpr int threads = 928;
  static constexpr int items   = 12;

  // Load configuration
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_CA;
  static constexpr int vec_size                      = 4; // same value as 1 << 2 (the "vec_2" field)

  // Histogram strategy
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr bool rle_compress                             = false;
  static constexpr bool work_stealing                            = false;
};

// sample_size 2/4/8 showed no benefit over SM90 during verification benchmarks

// SM100 tuning for the "range" variant (IsEven == false): one channel, one active channel,
// counter_size::_4, primitive samples classified as sample_size::_1.
template <class SampleT>
struct sm100_tuning<false, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
  // Encoded tuning this specialization was taken from (kept verbatim):
  // ipt_12.tpb_448.rle_0.ws_0.mem_1.ld_1.laid_0.vec_2 1.078987 0.985542 1.085118 1.175637

  // Block shape
  static constexpr int threads = 448;
  static constexpr int items   = 12;

  // Load configuration
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_LDG;
  static constexpr int vec_size                      = 4; // same value as 1 << 2 (the "vec_2" field)

  // Histogram strategy
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr bool rle_compress                             = false;
  static constexpr bool work_stealing                            = false;
};

// sample_size 2/4/8 showed no benefit over SM90 during verification benchmarks

// multi.even and multi.range: none of the found tunings surpassed the SM90 tuning during verification benchmarks

template <class SampleT, class CounterT, int NumChannels, int NumActiveChannels, bool IsEven>
struct policy_hub
{
// TODO(bgruber): move inside t_scale in C++14
Expand Down Expand Up @@ -173,7 +220,30 @@ struct policy_hub
sm90_tuning<SampleT, NumChannels, NumActiveChannels, histogram::classify_counter_size<CounterT>()>>(0));
};

using MaxPolicy = Policy900;
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
{
  // Agent policy selection for SM100: take the values from the matching sm100_tuning specialization
  // when one exists, otherwise reuse the agent policy of Policy900.

  // Fallback overload. Calling with the literal `0` requires an int -> long conversion here, so this
  // overload is only chosen when the preferred overload below drops out during substitution.
  template <typename TuningT>
  static typename Policy900::AgentHistogramPolicyT select_agent_policy(long);

  // Preferred overload (exact match for `0`). Substitution fails, removing it from the overload set,
  // if TuningT does not provide every member named in the return type.
  template <typename TuningT>
  static AgentHistogramPolicy<TuningT::threads,
                              TuningT::items,
                              TuningT::load_algorithm,
                              TuningT::load_modifier,
                              TuningT::rle_compress,
                              TuningT::mem_preference,
                              TuningT::work_stealing,
                              TuningT::vec_size>
  select_agent_policy(int);

  // Unevaluated call: only the return type of the selected overload matters.
  using AgentHistogramPolicyT = decltype(select_agent_policy<
                                         sm100_tuning<IsEven,
                                                      SampleT,
                                                      NumChannels,
                                                      NumActiveChannels,
                                                      histogram::classify_counter_size<CounterT>()>>(0));
};

using MaxPolicy = Policy1000;
};
} // namespace histogram
} // namespace detail
Expand Down

0 comments on commit 5571cd8

Please sign in to comment.