Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add b200 tunings for histogram #3616

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions cub/cub/device/dispatch/dispatch_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

#include <cub/config.cuh>

#include <cuda/std/__type_traits/is_void.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
Expand Down Expand Up @@ -554,8 +556,7 @@ template <int NUM_CHANNELS,
typename CounterT,
typename LevelT,
typename OffsetT,
typename PolicyHub =
detail::histogram::policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS>>
typename PolicyHub = void> // if user passes a custom Policy this should not be void
struct DispatchHistogram
{
static_assert(NUM_CHANNELS <= 4, "Histograms only support up to 4 channels");
Expand Down Expand Up @@ -920,8 +921,14 @@ public:
cudaStream_t stream,
Int2Type<false> /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
// Should we call DispatchHistogram<....., PolicyHub=void> in DeviceHistogram?
static constexpr bool isEven = 0;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1091,8 +1098,13 @@ public:
cudaStream_t stream,
Int2Type<true> /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 0;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1226,8 +1238,13 @@ public:
cudaStream_t stream,
Int2Type<false> /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 1;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1412,8 +1429,13 @@ public:
cudaStream_t stream,
Int2Type<true> /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 1;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down
186 changes: 184 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ enum class sample_size
{
_1,
_2,
_4,
_8,
unknown
};

Expand Down Expand Up @@ -125,7 +127,164 @@ struct sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sampl
static constexpr bool work_stealing = false;
};

template <class SampleT, class CounterT, int NumChannels, int NumActiveChannels>
// SM100 (Blackwell / B200) tuning knobs for DeviceHistogram, keyed on the
// algorithm variant (IsEven: even-bin vs. custom-range histogram), sample type,
// channel configuration, counter size, and classified sample size. Only the
// specializations below are defined; configurations without a matching
// specialization fall back to the previous architecture's policy via SFINAE
// in the policy hub.
template <bool IsEven,
          class SampleT,
          int NumChannels,
          int NumActiveChannels,
          counter_size CounterSize,
          primitive_sample PrimitiveSample = is_primitive_sample<SampleT>(),
          sample_size SampleSize = classify_sample_size<SampleT>()>
struct sm100_tuning;

// ---- even-bin histogram, single channel ----

// HistogramEven, 1 channel / 1 active channel, 4-byte counters, 1-byte samples.
template <class SampleT>
struct sm100_tuning<1, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
  // Benchmark encoding: ipt = items per thread, tpb = threads per block,
  // rle = run-length compress, ws = work stealing, mem = memory preference,
  // ld = load modifier, vec = log2(vector size); the trailing numbers are
  // presumably measured speedups vs. the previous (SM90) policy -- TODO confirm.
  // ipt_12.tpb_928.rle_0.ws_0.mem_1.ld_2.laid_0.vec_2 1.033332 0.940517 1.031835 1.195876
  static constexpr int items = 12;
  static constexpr int threads = 928;
  static constexpr bool rle_compress = false;
  static constexpr bool work_stealing = false;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr CacheLoadModifier load_modifier = LOAD_CA;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  // NOTE(review): verify the SM100 policy selector queries this exact member
  // name (vec_size vs. tune_vec_size) -- a name mismatch would make the SFINAE
  // probe fail and silently disable this tuning.
  static constexpr int tune_vec_size = 1 << 2;
};

// Same as base: 2-byte samples showed no SM100-specific win, so reuse the
// corresponding SM90 tuning unchanged.
template <class SampleT>
struct sm100_tuning<1, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
{};

// Same as base: reuse the SM90 tuning for 4-byte samples.
template <class SampleT>
struct sm100_tuning<1, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4>
{};

// Same as base: reuse the SM90 tuning for 8-byte samples.
template <class SampleT>
struct sm100_tuning<1, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8>
{};

// ---- custom-range histogram, single channel ----

// HistogramRange, 1 channel / 1 active channel, 4-byte counters, 1-byte samples.
template <class SampleT>
struct sm100_tuning<0, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
  // ipt_12.tpb_448.rle_0.ws_0.mem_1.ld_1.laid_0.vec_2 1.078987 0.985542 1.085118 1.175637
  static constexpr int items = 12;
  static constexpr int threads = 448;
  static constexpr bool rle_compress = false;
  static constexpr bool work_stealing = false;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr CacheLoadModifier load_modifier = LOAD_LDG;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr int tune_vec_size = 1 << 2;
};

// Same as base: reuse the SM90 tuning for 2-byte samples.
template <class SampleT>
struct sm100_tuning<0, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
{};

// HistogramRange, 1 channel / 1 active channel, 4-byte counters, 4-byte samples.
template <class SampleT>
struct sm100_tuning<0, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4>
{
  // ipt_9.tpb_1024.rle_1.ws_0.mem_1.ld_0.laid_1.vec_0 1.358537 1.001009 1.373329 2.614104
  static constexpr int items = 9;
  static constexpr int threads = 1024;
  static constexpr bool rle_compress = true;
  static constexpr bool work_stealing = false;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr int tune_vec_size = 1 << 0;
};

// HistogramRange, 1 channel / 1 active channel, 4-byte counters, 8-byte samples.
template <class SampleT>
struct sm100_tuning<0, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8>
{
  // ipt_7.tpb_544.rle_1.ws_0.mem_1.ld_1.laid_0.vec_0 1.105331 0.934888 1.108557 1.391657
  static constexpr int items = 7;
  static constexpr int threads = 544;
  static constexpr bool rle_compress = true;
  static constexpr bool work_stealing = false;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr CacheLoadModifier load_modifier = LOAD_LDG;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr int tune_vec_size = 1 << 0;
};

// ---- even-bin histogram, multi channel (4 channels, 3 active, e.g. RGBA) ----

// HistogramEven, 4 channels / 3 active channels, 4-byte counters, 1-byte samples.
template <class SampleT>
struct sm100_tuning<1, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
  // ipt_9.tpb_1024.rle_0.ws_0.mem_1.ld_1.laid_1.vec_0 1.629591 0.997416 1.570900 2.772504
  static constexpr int items = 9;
  static constexpr int threads = 1024;
  static constexpr bool rle_compress = false;
  static constexpr bool work_stealing = false;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr CacheLoadModifier load_modifier = LOAD_LDG;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr int tune_vec_size = 1 << 0;
};

// Same as base: defer to an SM90 tuning.
// NOTE(review): this is the 4-channel/3-active specialization, yet it inherits
// the single-channel sm90_tuning<SampleT, 1, 1, ...> -- confirm the <1, 1>
// arguments are intentional and not a copy-paste from the single-channel case;
// "same as base" would otherwise suggest the 4/3 fallback path.
template <class SampleT>
struct sm100_tuning<1, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_2>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
{};

// Same as base: defer to an SM90 tuning.
// NOTE(review): 4-channel/3-active specialization inheriting the single-channel
// sm90_tuning<SampleT, 1, 1, ...> -- verify the <1, 1> base is intended.
template <class SampleT>
struct sm100_tuning<1, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_4>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4>
{};

// Same as base: defer to an SM90 tuning.
// NOTE(review): 4-channel/3-active specialization inheriting the single-channel
// sm90_tuning<SampleT, 1, 1, ...> -- verify the <1, 1> base is intended.
template <class SampleT>
struct sm100_tuning<1, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_8>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8>
{};

// ---- custom-range histogram, multi channel (4 channels, 3 active) ----

// HistogramRange, 4 channels / 3 active channels, 4-byte counters, 1-byte samples.
template <class SampleT>
struct sm100_tuning<0, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
  // ipt_7.tpb_160.rle_0.ws_0.mem_1.ld_1.laid_1.vec_1 1.210837 0.99556 1.189049 1.939584
  static constexpr int items = 7;
  static constexpr int threads = 160;
  static constexpr bool rle_compress = false;
  static constexpr bool work_stealing = false;
  static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
  static constexpr CacheLoadModifier load_modifier = LOAD_LDG;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr int tune_vec_size = 1 << 1;
};

// Same as base: defer to an SM90 tuning.
// NOTE(review): 4-channel/3-active specialization inheriting the single-channel
// sm90_tuning<SampleT, 1, 1, ...> -- verify the <1, 1> base is intended.
template <class SampleT>
struct sm100_tuning<0, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_2>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2>
{};

// Same as base: defer to an SM90 tuning.
// NOTE(review): 4-channel/3-active specialization inheriting the single-channel
// sm90_tuning<SampleT, 1, 1, ...> -- verify the <1, 1> base is intended.
template <class SampleT>
struct sm100_tuning<0, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_4>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4>
{};

// Same as base: defer to an SM90 tuning.
// NOTE(review): 4-channel/3-active specialization inheriting the single-channel
// sm90_tuning<SampleT, 1, 1, ...> -- verify the <1, 1> base is intended.
template <class SampleT>
struct sm100_tuning<0, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_8>
    : sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8>
{};

template <class SampleT, class CounterT, int NumChannels, int NumActiveChannels, bool IsEven>
struct policy_hub
{
// TODO(bgruber): move inside t_scale in C++14
Expand Down Expand Up @@ -166,7 +325,30 @@ struct policy_hub
sm90_tuning<SampleT, NumChannels, NumActiveChannels, histogram::classify_counter_size<CounterT>()>>(0));
};

using MaxPolicy = Policy900;
// SM100 (Blackwell / B200) policy: use the sm100_tuning specialization for the
// current configuration when one exists, otherwise fall back to Policy900.
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
{
  // Preferred overload, selected via SFINAE: viable only when Tuning provides
  // every queried member, i.e. a full sm100_tuning specialization exists.
  //
  // Bug fix: this previously queried Tuning::vec_size, but the sm100_tuning
  // specializations define the member as tune_vec_size. The substitution
  // therefore failed for every tuning and the Policy900 fallback below was
  // always chosen, so the SM100 tunings were dead code.
  template <typename Tuning>
  static auto select_agent_policy(int)
    -> AgentHistogramPolicy<Tuning::threads,
                            Tuning::items,
                            Tuning::load_algorithm,
                            Tuning::load_modifier,
                            Tuning::rle_compress,
                            Tuning::mem_preference,
                            Tuning::work_stealing,
                            Tuning::tune_vec_size>;

  // Fallback overload (worse match for the literal 0 argument): reuse the
  // SM90 agent policy unchanged.
  template <typename Tuning>
  static auto select_agent_policy(long) -> typename Policy900::AgentHistogramPolicyT;

  using AgentHistogramPolicyT =
    decltype(select_agent_policy<
             sm100_tuning<IsEven, SampleT, NumChannels, NumActiveChannels, histogram::classify_counter_size<CounterT>()>>(
      0));
};

using MaxPolicy = Policy1000;
};
} // namespace histogram
} // namespace detail
Expand Down
Loading