Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport to 2.8: B200 tunings for histogram #3728

Merged
merged 2 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions cub/cub/device/dispatch/dispatch_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

#include <cub/config.cuh>

#include <cuda/std/__type_traits/is_void.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
Expand Down Expand Up @@ -554,8 +556,7 @@ template <int NUM_CHANNELS,
typename CounterT,
typename LevelT,
typename OffsetT,
typename PolicyHub =
detail::histogram::policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS>>
typename PolicyHub = void> // if user passes a custom Policy this should not be void
struct DispatchHistogram
{
static_assert(NUM_CHANNELS <= 4, "Histograms only support up to 4 channels");
Expand Down Expand Up @@ -920,8 +921,14 @@ public:
cudaStream_t stream,
::cuda::std::false_type /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
// Should we call DispatchHistogram<....., PolicyHub=void> in DeviceHistogram?
static constexpr bool isEven = 0;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1124,8 +1131,13 @@ public:
cudaStream_t stream,
::cuda::std::true_type /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 0;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1292,8 +1304,13 @@ public:
cudaStream_t stream,
::cuda::std::false_type /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 1;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down Expand Up @@ -1513,8 +1530,13 @@ public:
cudaStream_t stream,
::cuda::std::true_type /*is_byte_sample*/)
{
using MaxPolicyT = typename PolicyHub::MaxPolicy;
cudaError error = cudaSuccess;
static constexpr bool isEven = 1;
using fallback_policy_hub = detail::histogram::
policy_hub<detail::value_t<SampleIteratorT>, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>;

using MaxPolicyT =
typename cuda::std::_If<cuda::std::is_void<PolicyHub>::value, fallback_policy_hub, PolicyHub>::MaxPolicy;
cudaError error = cudaSuccess;

do
{
Expand Down
74 changes: 72 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ enum class sample_size
{
_1,
_2,
_4,
_8,
unknown
};

Expand Down Expand Up @@ -125,7 +127,52 @@ struct sm90_tuning<SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sampl
static constexpr bool work_stealing = false;
};

template <class SampleT, class CounterT, int NumChannels, int NumActiveChannels>
template <bool IsEven,
class SampleT,
int NumChannels,
int NumActiveChannels,
counter_size CounterSize,
primitive_sample PrimitiveSample = is_primitive_sample<SampleT>(),
sample_size SampleSize = classify_sample_size<SampleT>()>
struct sm100_tuning;

// even
template <class SampleT>
struct sm100_tuning<true, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
// ipt_12.tpb_928.rle_0.ws_0.mem_1.ld_2.laid_0.vec_2 1.033332 0.940517 1.031835 1.195876
static constexpr int items = 12;
static constexpr int threads = 928;
static constexpr bool rle_compress = false;
static constexpr bool work_stealing = false;
static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
static constexpr CacheLoadModifier load_modifier = LOAD_CA;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr int vec_size = 1 << 2;
};

// sample_size 2/4/8 showed no benefit over SM90 during verification benchmarks

// range
template <class SampleT>
struct sm100_tuning<false, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1>
{
// ipt_12.tpb_448.rle_0.ws_0.mem_1.ld_1.laid_0.vec_2 1.078987 0.985542 1.085118 1.175637
static constexpr int items = 12;
static constexpr int threads = 448;
static constexpr bool rle_compress = false;
static constexpr bool work_stealing = false;
static constexpr BlockHistogramMemoryPreference mem_preference = SMEM;
static constexpr CacheLoadModifier load_modifier = LOAD_LDG;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr int vec_size = 1 << 2;
};

// sample_size 2/4/8 showed no benefit over SM90 during verification benchmarks

// multi.even and multi.range: none of the found tunings surpassed the SM90 tuning during verification benchmarks

template <class SampleT, class CounterT, int NumChannels, int NumActiveChannels, bool IsEven>
struct policy_hub
{
// TODO(bgruber): move inside t_scale in C++14
Expand Down Expand Up @@ -173,7 +220,30 @@ struct policy_hub
sm90_tuning<SampleT, NumChannels, NumActiveChannels, histogram::classify_counter_size<CounterT>()>>(0));
};

using MaxPolicy = Policy900;
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
{
// Use values from tuning if a specialization exists, otherwise pick Policy900
template <typename Tuning>
static auto select_agent_policy(int)
-> AgentHistogramPolicy<Tuning::threads,
Tuning::items,
Tuning::load_algorithm,
Tuning::load_modifier,
Tuning::rle_compress,
Tuning::mem_preference,
Tuning::work_stealing,
Tuning::vec_size>;

template <typename Tuning>
static auto select_agent_policy(long) -> typename Policy900::AgentHistogramPolicyT;

using AgentHistogramPolicyT =
decltype(select_agent_policy<
sm100_tuning<IsEven, SampleT, NumChannels, NumActiveChannels, histogram::classify_counter_size<CounterT>()>>(
0));
};

using MaxPolicy = Policy1000;
};
} // namespace histogram
} // namespace detail
Expand Down