From fe5d12ef2e112fa17cf2986959fd503946be578c Mon Sep 17 00:00:00 2001 From: Giannis Gonidelis Date: Fri, 24 Jan 2025 01:34:22 -0800 Subject: [PATCH] Add b200 tunings for histogram --- .../device/dispatch/dispatch_histogram.cuh | 42 +++- .../dispatch/tuning/tuning_histogram.cuh | 186 +++++++++++++++++- 2 files changed, 216 insertions(+), 12 deletions(-) diff --git a/cub/cub/device/dispatch/dispatch_histogram.cuh b/cub/cub/device/dispatch/dispatch_histogram.cuh index 2c2d0a2a9ca..43944dfc0b5 100644 --- a/cub/cub/device/dispatch/dispatch_histogram.cuh +++ b/cub/cub/device/dispatch/dispatch_histogram.cuh @@ -36,6 +36,8 @@ #include +#include + #if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) # pragma GCC system_header #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) @@ -554,8 +556,7 @@ template , CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS>> + typename PolicyHub = void> // if user passes a custom Policy this should not be void struct DispatchHistogram { static_assert(NUM_CHANNELS <= 4, "Histograms only support up to 4 channels"); @@ -920,8 +921,14 @@ public: cudaStream_t stream, Int2Type /*is_byte_sample*/) { - using MaxPolicyT = typename PolicyHub::MaxPolicy; - cudaError error = cudaSuccess; + // Should we call DispatchHistogram<....., PolicyHub=void> in DeviceHistogram? + static constexpr bool isEven = 0; + using fallback_policy_hub = detail::histogram:: + policy_hub, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>; + + using MaxPolicyT = + typename cuda::std::_If::value, fallback_policy_hub, PolicyHub>::MaxPolicy; + cudaError error = cudaSuccess; do { @@ -1091,8 +1098,13 @@ public: cudaStream_t stream, Int2Type /*is_byte_sample*/) { - using MaxPolicyT = typename PolicyHub::MaxPolicy; - cudaError error = cudaSuccess; + static constexpr bool isEven = 0; + using fallback_policy_hub = detail::histogram:: + policy_hub, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>; + + using MaxPolicyT = + typename cuda::std::_If::value, fallback_policy_hub, PolicyHub>::MaxPolicy; + cudaError error = cudaSuccess; do { @@ -1226,8 +1238,13 @@ public: cudaStream_t stream, Int2Type /*is_byte_sample*/) { - using MaxPolicyT = typename PolicyHub::MaxPolicy; - cudaError error = cudaSuccess; + static constexpr bool isEven = 1; + using fallback_policy_hub = detail::histogram:: + policy_hub, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>; + + using MaxPolicyT = + typename cuda::std::_If::value, fallback_policy_hub, PolicyHub>::MaxPolicy; + cudaError error = cudaSuccess; do { @@ -1412,8 +1429,13 @@ public: cudaStream_t stream, Int2Type /*is_byte_sample*/) { - using MaxPolicyT = typename PolicyHub::MaxPolicy; - cudaError error = cudaSuccess; + static constexpr bool isEven = 1; + using fallback_policy_hub = detail::histogram:: + policy_hub, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>; + + using MaxPolicyT = + typename cuda::std::_If::value, fallback_policy_hub, PolicyHub>::MaxPolicy; + cudaError error = cudaSuccess; do { diff --git a/cub/cub/device/dispatch/tuning/tuning_histogram.cuh b/cub/cub/device/dispatch/tuning/tuning_histogram.cuh index bd19489971e..3ae3f7fc58a 100644 --- a/cub/cub/device/dispatch/tuning/tuning_histogram.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_histogram.cuh @@ -60,6 +60,8 @@ enum class sample_size { _1, _2, + _4, + _8, unknown }; @@ -125,7 +127,164 @@ struct sm90_tuning +template (), + sample_size SampleSize = classify_sample_size()> +struct sm100_tuning; + +// even +template +struct sm100_tuning<1, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1> +{ + // ipt_12.tpb_928.rle_0.ws_0.mem_1.ld_2.laid_0.vec_2 1.033332 0.940517 1.031835 1.195876 + static constexpr int items = 12; + static constexpr int threads = 928; + static constexpr bool rle_compress = false; + static constexpr bool work_stealing = false; + static constexpr BlockHistogramMemoryPreference mem_preference = SMEM; + static constexpr CacheLoadModifier load_modifier = LOAD_CA; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr int tune_vec_size = 1 << 2; +}; + +// same as base +template +struct sm100_tuning<1, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2> + : sm90_tuning +{}; + +// same as base +template +struct sm100_tuning<1, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4> + : sm90_tuning +{}; + +// same as base +template +struct sm100_tuning<1, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8> + : sm90_tuning +{}; + +// range +template +struct sm100_tuning<0, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_1> +{ + // ipt_12.tpb_448.rle_0.ws_0.mem_1.ld_1.laid_0.vec_2 1.078987 0.985542 1.085118 1.175637 + static constexpr int items = 12; + static constexpr int threads = 448; + static constexpr bool rle_compress = false; + static constexpr bool work_stealing = false; + static constexpr BlockHistogramMemoryPreference mem_preference = SMEM; + static constexpr CacheLoadModifier load_modifier = LOAD_LDG; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr int tune_vec_size = 1 << 2; +}; + +// same as base +template +struct sm100_tuning<0, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_2> + : sm90_tuning +{}; + +template +struct sm100_tuning<0, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_4> +{ + // ipt_9.tpb_1024.rle_1.ws_0.mem_1.ld_0.laid_1.vec_0 1.358537 1.001009 1.373329 2.614104 + static constexpr int items = 9; + static constexpr int threads = 1024; + static constexpr bool rle_compress = true; + static constexpr bool work_stealing = false; + static constexpr BlockHistogramMemoryPreference mem_preference = SMEM; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr int tune_vec_size = 1 << 0; +}; + +template +struct sm100_tuning<0, SampleT, 1, 1, counter_size::_4, primitive_sample::yes, sample_size::_8> +{ + // ipt_7.tpb_544.rle_1.ws_0.mem_1.ld_1.laid_0.vec_0 1.105331 0.934888 1.108557 1.391657 + static constexpr int items = 7; + static constexpr int threads = 544; + static constexpr bool rle_compress = true; + static constexpr bool work_stealing = false; + static constexpr BlockHistogramMemoryPreference mem_preference = SMEM; + static constexpr CacheLoadModifier load_modifier = LOAD_LDG; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr int tune_vec_size = 1 << 0; +}; + +// multi.even +template +struct sm100_tuning<1, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_1> +{ + // ipt_9.tpb_1024.rle_0.ws_0.mem_1.ld_1.laid_1.vec_0 1.629591 0.997416 1.570900 2.772504 + static constexpr int items = 9; + static constexpr int threads = 1024; + static constexpr bool rle_compress = false; + static constexpr bool work_stealing = false; + static constexpr BlockHistogramMemoryPreference mem_preference = SMEM; + static constexpr CacheLoadModifier load_modifier = LOAD_LDG; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr int tune_vec_size = 1 << 0; +}; + +// same as base +template +struct sm100_tuning<1, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_2> + : sm90_tuning +{}; + +// same as base +template +struct sm100_tuning<1, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_4> + : sm90_tuning +{}; + +// same as base +template +struct sm100_tuning<1, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_8> + : sm90_tuning +{}; + +// multi.range +template +struct sm100_tuning<0, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_1> +{ + // ipt_7.tpb_160.rle_0.ws_0.mem_1.ld_1.laid_1.vec_1 1.210837 0.99556 1.189049 1.939584 + static constexpr int items = 7; + static constexpr int threads = 160; + static constexpr bool rle_compress = false; + static constexpr bool work_stealing = false; + static constexpr BlockHistogramMemoryPreference mem_preference = SMEM; + static constexpr CacheLoadModifier load_modifier = LOAD_LDG; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr int tune_vec_size = 1 << 1; +}; + +// same as base +template +struct sm100_tuning<0, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_2> + : sm90_tuning +{}; + +// same as base +template +struct sm100_tuning<0, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_4> + : sm90_tuning +{}; + +// same as base +template +struct sm100_tuning<0, SampleT, 4, 3, counter_size::_4, primitive_sample::yes, sample_size::_8> + : sm90_tuning +{}; + +template struct policy_hub { // TODO(bgruber): move inside t_scale in C++14 @@ -166,7 +325,30 @@ struct policy_hub sm90_tuning()>>(0)); }; - using MaxPolicy = Policy900; + struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900> + { + // Use values from tuning if a specialization exists, otherwise pick Policy900 + template + static auto select_agent_policy(int) + -> AgentHistogramPolicy; + + template + static auto select_agent_policy(long) -> typename Policy900::AgentHistogramPolicyT; + + using AgentHistogramPolicyT = + decltype(select_agent_policy< + sm100_tuning()>>( + 0)); + }; + + using MaxPolicy = Policy1000; }; } // namespace histogram } // namespace detail