Skip to content

Commit

Permalink
Add b200 tunings for histogram
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis authored and bernhardmgruber committed Jan 30, 2025
1 parent a00de21 commit fe5d12e
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 12 deletions.
42 changes: 32 additions & 10 deletions cub/cub/device/dispatch/dispatch_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

#include <cub/config.cuh>

#include <cuda/std/__type_traits/is_void.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
Expand Down Expand Up @@ -554,8 +556,7 @@ template <int NUM_CHANNELS,
typename CounterT,
typename LevelT,
typename OffsetT,
typename PolicyHub =
detail::histogram::policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS>>
typename PolicyHub = void> // if user passes a custom Policy this should not be void
struct DispatchHistogram
{
static_assert(NUM_CHANNELS <= 4, "Histograms only support up to 4 channels");
Expand Down Expand Up @@ -920,8 +921,14 @@ public:
cudaStream_t stream,
Int2Type<false> /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
// Should we call DispatchHistogram<....., PolicyHub=void> in DeviceHistogram?
static constexpr bool isEven = 0;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1091,8 +1098,13 @@ public:
cudaStream_t stream,
Int2Type<true> /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 0;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1226,8 +1238,13 @@ public:
cudaStream_t stream,
Int2Type<false> /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 1;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1412,8 +1429,13 @@ public:
cudaStream_t stream,
Int2Type<true> /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 1;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down
186 changes: 184 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ enum class sample_size
{
_1,
_2,
_4,
_8,
unknown
};

Expand Down Expand Up @@ -125,7 +127,164 @@ struct sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sampl
static constexpr bool work_stealing = false;
};

template <class SampleT, class CounterT, int NumChannels, int NumActiveChannels>
// Primary template for SM100 histogram tunings. It is only declared here;
// concrete values come from the partial specializations that follow.
//
// IsEven            - selects between the "even" (true) and "range" (false) specializations
// SampleT           - sample type; also feeds the two defaulted classifiers below
// NumChannels       - number of channels
// NumActiveChannels - number of active channels
// CounterSize       - counter size class
// PrimitiveSample   - defaults to is_primitive_sample<SampleT>()
// SampleSize        - defaults to classify_sample_size<SampleT>()
template <bool IsEven,
class SampleT,
int NumChannels,
int NumActiveChannels,
counter_size CounterSize,
primitive_sample PrimitiveSample = is_primitive_sample<SampleT>(),
sample_size SampleSize = classify_sample_size<SampleT>()>
struct sm100_tuning;

// "even" variant (IsEven == true): 1 channel / 1 active channel,
// counter_size::_4, primitive samples of sample_size::_1.
template <class SampleT>
struct sm100_tuning<true, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
  // Tuning variant label, kept verbatim (trailing numbers are presumably
  // benchmark statistics -- confirm against the tuning run):
  // ipt_12.tpb_928.rle_0.ws_0.mem_1.ld_2.laid_0.vec_2 1.033332 0.940517 1.031835 1.195876
  static constexpr int threads                                   = 928;
  static constexpr int items                                     = 12;
  static constexpr BlockLoadAlgorithm load_algorithm             = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier               = LOAD_CA;
  static constexpr bool rle_compress                             = false;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr bool work_stealing                            = false;
  static constexpr int tune_vec_size                             = 1 << 2; // label: vec_2
};

// "even" variant, 1 channel / 1 active channel, sample_size::_2:
// declares no values of its own; all members are inherited from sm90_tuning.
template <class SampleT>
struct sm100_tuning<true, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
{};

// "even" variant, 1 channel / 1 active channel, sample_size::_4:
// declares no values of its own; all members are inherited from sm90_tuning.
// NOTE(review): this needs a defined sm90_tuning specialization for
// sample_size::_4, which is not visible in this section -- confirm it exists.
template <class SampleT>
struct sm100_tuning<true, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4>
{};

// "even" variant, 1 channel / 1 active channel, sample_size::_8:
// declares no values of its own; all members are inherited from sm90_tuning.
// NOTE(review): this needs a defined sm90_tuning specialization for
// sample_size::_8, which is not visible in this section -- confirm it exists.
template <class SampleT>
struct sm100_tuning<true, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8>
{};

// "range" variant (IsEven == false): 1 channel / 1 active channel,
// counter_size::_4, primitive samples of sample_size::_1.
template <class SampleT>
struct sm100_tuning<false, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
  // Tuning variant label, kept verbatim (trailing numbers are presumably
  // benchmark statistics -- confirm against the tuning run):
  // ipt_12.tpb_448.rle_0.ws_0.mem_1.ld_1.laid_0.vec_2 1.078987 0.985542 1.085118 1.175637
  static constexpr int threads                                   = 448;
  static constexpr int items                                     = 12;
  static constexpr BlockLoadAlgorithm load_algorithm             = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier               = LOAD_LDG;
  static constexpr bool rle_compress                             = false;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr bool work_stealing                            = false;
  static constexpr int tune_vec_size                             = 1 << 2; // label: vec_2
};

// "range" variant, 1 channel / 1 active channel, sample_size::_2:
// declares no values of its own; all members are inherited from sm90_tuning.
template <class SampleT>
struct sm100_tuning<false, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
{};

// "range" variant (IsEven == false): 1 channel / 1 active channel,
// counter_size::_4, primitive samples of sample_size::_4.
template <class SampleT>
struct sm100_tuning<false, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4>
{
  // Tuning variant label, kept verbatim (trailing numbers are presumably
  // benchmark statistics -- confirm against the tuning run):
  // ipt_9.tpb_1024.rle_1.ws_0.mem_1.ld_0.laid_1.vec_0 1.358537 1.001009 1.373329 2.614104
  static constexpr int threads                                   = 1024;
  static constexpr int items                                     = 9;
  static constexpr BlockLoadAlgorithm load_algorithm             = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier               = LOAD_DEFAULT;
  static constexpr bool rle_compress                             = true;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr bool work_stealing                            = false;
  static constexpr int tune_vec_size                             = 1 << 0; // label: vec_0
};

// "range" variant (IsEven == false): 1 channel / 1 active channel,
// counter_size::_4, primitive samples of sample_size::_8.
template <class SampleT>
struct sm100_tuning<false, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8>
{
  // Tuning variant label, kept verbatim (trailing numbers are presumably
  // benchmark statistics -- confirm against the tuning run):
  // ipt_7.tpb_544.rle_1.ws_0.mem_1.ld_1.laid_0.vec_0 1.105331 0.934888 1.108557 1.391657
  static constexpr int threads                                   = 544;
  static constexpr int items                                     = 7;
  static constexpr BlockLoadAlgorithm load_algorithm             = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier               = LOAD_LDG;
  static constexpr bool rle_compress                             = true;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr bool work_stealing                            = false;
  static constexpr int tune_vec_size                             = 1 << 0; // label: vec_0
};

// Multi-channel "even" variant (IsEven == true): 4 channels / 3 active,
// counter_size::_4, primitive samples of sample_size::_1.
template <class SampleT>
struct sm100_tuning<true, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
  // Tuning variant label, kept verbatim (trailing numbers are presumably
  // benchmark statistics -- confirm against the tuning run):
  // ipt_9.tpb_1024.rle_0.ws_0.mem_1.ld_1.laid_1.vec_0 1.629591 0.997416 1.570900 2.772504
  static constexpr int threads                                   = 1024;
  static constexpr int items                                     = 9;
  static constexpr BlockLoadAlgorithm load_algorithm             = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier               = LOAD_LDG;
  static constexpr bool rle_compress                             = false;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr bool work_stealing                            = false;
  static constexpr int tune_vec_size                             = 1 << 0; // label: vec_0
};

// Multi-channel "even" variant (4 channels / 3 active), sample_size::_2:
// declares no values of its own; all members are inherited from sm90_tuning.
// NOTE(review): the base is instantiated with 1 channel / 1 active channel
// although this specialization is for 4 / 3 -- confirm this is intentional.
template <class SampleT>
struct sm100_tuning<true, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_2>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
{};

// Multi-channel "even" variant (4 channels / 3 active), sample_size::_4:
// declares no values of its own; all members are inherited from sm90_tuning.
// NOTE(review): the base is instantiated with 1 channel / 1 active channel and
// sample_size::_4; no such sm90_tuning specialization is visible in this
// section -- confirm it exists and that the 1 / 1 choice is intentional.
template <class SampleT>
struct sm100_tuning<true, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_4>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4>
{};

// Multi-channel "even" variant (4 channels / 3 active), sample_size::_8:
// declares no values of its own; all members are inherited from sm90_tuning.
// NOTE(review): the base is instantiated with 1 channel / 1 active channel and
// sample_size::_8; no such sm90_tuning specialization is visible in this
// section -- confirm it exists and that the 1 / 1 choice is intentional.
template <class SampleT>
struct sm100_tuning<true, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_8>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8>
{};

// Multi-channel "range" variant (IsEven == false): 4 channels / 3 active,
// counter_size::_4, primitive samples of sample_size::_1.
template <class SampleT>
struct sm100_tuning<false, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
  // Tuning variant label, kept verbatim (trailing numbers are presumably
  // benchmark statistics -- confirm against the tuning run):
  // ipt_7.tpb_160.rle_0.ws_0.mem_1.ld_1.laid_1.vec_1 1.210837 0.99556 1.189049 1.939584
  static constexpr int threads                                   = 160;
  static constexpr int items                                     = 7;
  static constexpr BlockLoadAlgorithm load_algorithm             = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier               = LOAD_LDG;
  static constexpr bool rle_compress                             = false;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr bool work_stealing                            = false;
  static constexpr int tune_vec_size                             = 1 << 1; // label: vec_1
};

// Multi-channel "range" variant (4 channels / 3 active), sample_size::_2:
// declares no values of its own; all members are inherited from sm90_tuning.
// NOTE(review): the base is instantiated with 1 channel / 1 active channel
// although this specialization is for 4 / 3 -- confirm this is intentional.
template <class SampleT>
struct sm100_tuning<false, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_2>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
{};

// Multi-channel "range" variant (4 channels / 3 active), sample_size::_4:
// declares no values of its own; all members are inherited from sm90_tuning.
// NOTE(review): the base is instantiated with 1 channel / 1 active channel and
// sample_size::_4; no such sm90_tuning specialization is visible in this
// section -- confirm it exists and that the 1 / 1 choice is intentional.
template <class SampleT>
struct sm100_tuning<false, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_4>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4>
{};

// Multi-channel "range" variant (4 channels / 3 active), sample_size::_8:
// declares no values of its own; all members are inherited from sm90_tuning.
// NOTE(review): the base is instantiated with 1 channel / 1 active channel and
// sample_size::_8; no such sm90_tuning specialization is visible in this
// section -- confirm it exists and that the 1 / 1 choice is intentional.
template <class SampleT>
struct sm100_tuning<false, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_8>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8>
{};

template <class SampleT, class CounterT, int NumChannels, int NumActiveChannels, bool IsEven>
struct policy_hub
{
// TODO(bgruber): move inside t_scale in C++14
Expand Down Expand Up @@ -166,7 +325,30 @@ struct policy_hub
sm90_tuning<SampleT, NumChannels, NumActiveChannels, histogram::classify_counter_size<CounterT>()>>(0));
};

using MaxPolicy = Policy900;
// Policy for the 1000 entry of the policy chain; chains to Policy900.
//
// The agent policy is chosen by overload resolution on the literal `0`:
//  * select_agent_policy(int) is the better match and is used whenever the
//    selected sm100_tuning specialization provides every member named in its
//    return type;
//  * otherwise substitution fails for that overload and select_agent_policy(long)
//    yields Policy900's agent policy unchanged.
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
{
  // Use values from tuning if a specialization exists, otherwise pick Policy900.
  // The vector-size member of the sm100 tunings is named `tune_vec_size`; reading
  // any other name here makes this overload non-viable for every tuning, which
  // would silently route all configurations to the Policy900 fallback.
  template <typename Tuning>
  static auto select_agent_policy(int)
    -> AgentHistogramPolicy<Tuning::threads,
                            Tuning::items,
                            Tuning::load_algorithm,
                            Tuning::load_modifier,
                            Tuning::rle_compress,
                            Tuning::mem_preference,
                            Tuning::work_stealing,
                            Tuning::tune_vec_size>;

  // Fallback: reuse the Policy900 agent policy when the tuning lacks any of the
  // members required above (e.g. tunings that only inherit from sm90_tuning and
  // therefore carry no tune_vec_size -- presumably the intent of the
  // "same as base" specializations; confirm against sm90_tuning's members).
  template <typename Tuning>
  static auto select_agent_policy(long) -> typename Policy900::AgentHistogramPolicyT;

  using AgentHistogramPolicyT =
    decltype(select_agent_policy<
             sm100_tuning<IsEven, SampleT, NumChannels, NumActiveChannels, histogram::classify_counter_size<CounterT>()>>(
      0));
};

// Head of the policy chain exposed by this hub: Policy1000, which chains to
// Policy900 through ChainedPolicy.
using MaxPolicy = Policy1000;
};
} // namespace histogram
} // namespace detail
Expand Down

0 comments on commit fe5d12e

Please sign in to comment.