diff --git a/cub/cub/device/dispatch/dispatch_histogram.cuh b/cub/cub/device/dispatch/dispatch_histogram.cuh index 53428ec9cca..dc161a11000 100644 --- a/cub/cub/device/dispatch/dispatch_histogram.cuh +++ b/cub/cub/device/dispatch/dispatch_histogram.cuh @@ -37,6 +37,8 @@ #include +#include + #if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) # pragma GCC system_header #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) @@ -554,8 +556,7 @@ template , CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS>> + typename PolicyHub = void> // if user passes a custom Policy this should not be void struct DispatchHistogram { static_assert(NUM_CHANNELS <= 4, "Histograms only support up to 4 channels"); @@ -920,8 +921,14 @@ public: cudaStream_t stream, ::cuda::std::false_type /*is_byte_sample*/) { - using MaxPolicyT = typename PolicyHub::MaxPolicy; - cudaError error = cudaSuccess; + // Should we call DispatchHistogram<....., PolicyHub=void> in DeviceHistogram? + static constexpr bool isEven = 0; + using fallback_policy_hub = detail::histogram:: + policy_hub, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>; + + using MaxPolicyT = + typename cuda::std::_If::value, fallback_policy_hub, PolicyHub>::MaxPolicy; + cudaError error = cudaSuccess; do { @@ -1124,8 +1131,13 @@ public: cudaStream_t stream, ::cuda::std::true_type /*is_byte_sample*/) { - using MaxPolicyT = typename PolicyHub::MaxPolicy; - cudaError error = cudaSuccess; + static constexpr bool isEven = 0; + using fallback_policy_hub = detail::histogram:: + policy_hub, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>; + + using MaxPolicyT = + typename cuda::std::_If::value, fallback_policy_hub, PolicyHub>::MaxPolicy; + cudaError error = cudaSuccess; do { @@ -1292,8 +1304,13 @@ public: cudaStream_t stream, ::cuda::std::false_type /*is_byte_sample*/) { - using MaxPolicyT = typename PolicyHub::MaxPolicy; - cudaError error = cudaSuccess; + static constexpr bool isEven = 1; + using fallback_policy_hub = detail::histogram:: + policy_hub, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>; + + using MaxPolicyT = + typename cuda::std::_If::value, fallback_policy_hub, PolicyHub>::MaxPolicy; + cudaError error = cudaSuccess; do { @@ -1513,8 +1530,13 @@ public: cudaStream_t stream, ::cuda::std::true_type /*is_byte_sample*/) { - using MaxPolicyT = typename PolicyHub::MaxPolicy; - cudaError error = cudaSuccess; + static constexpr bool isEven = 1; + using fallback_policy_hub = detail::histogram:: + policy_hub, CounterT, NUM_CHANNELS, NUM_ACTIVE_CHANNELS, isEven>; + + using MaxPolicyT = + typename cuda::std::_If::value, fallback_policy_hub, PolicyHub>::MaxPolicy; + cudaError error = cudaSuccess; do { diff --git a/cub/cub/device/dispatch/tuning/tuning_histogram.cuh b/cub/cub/device/dispatch/tuning/tuning_histogram.cuh index 234be77ea71..3dad811958e 100644 --- a/cub/cub/device/dispatch/tuning/tuning_histogram.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_histogram.cuh @@ -60,6 +60,8 @@ enum class sample_size { _1, _2, + _4, + _8, unknown }; @@ -125,7 +127,52 @@ struct sm90_tuning +template (), + sample_size SampleSize = classify_sample_size()> +struct sm100_tuning; + +// even +template +struct sm100_tuning +{ + // ipt_12.tpb_928.rle_0.ws_0.mem_1.ld_2.laid_0.vec_2 1.033332 0.940517 1.031835 1.195876 + static constexpr int items = 12; + static constexpr int threads = 928; + static constexpr bool rle_compress = false; + static constexpr bool work_stealing = false; + static constexpr BlockHistogramMemoryPreference mem_preference = SMEM; + static constexpr CacheLoadModifier load_modifier = LOAD_CA; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr int vec_size = 1 << 2; +}; + +// sample_size 2/4/8 showed no benefit over SM90 during verification benchmarks + +// range +template +struct sm100_tuning +{ + // ipt_12.tpb_448.rle_0.ws_0.mem_1.ld_1.laid_0.vec_2 1.078987 0.985542 1.085118 1.175637 + static constexpr int items = 12; + static constexpr int threads = 448; + static constexpr bool rle_compress = false; + static constexpr bool work_stealing = false; + static constexpr BlockHistogramMemoryPreference mem_preference = SMEM; + static constexpr CacheLoadModifier load_modifier = LOAD_LDG; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr int vec_size = 1 << 2; +}; + +// sample_size 2/4/8 showed no benefit over SM90 during verification benchmarks + +// multi.even and multi.range: none of the found tunings surpassed the SM90 tuning during verification benchmarks + +template struct policy_hub { // TODO(bgruber): move inside t_scale in C++14 @@ -173,7 +220,30 @@ struct policy_hub sm90_tuning()>>(0)); }; - using MaxPolicy = Policy900; + struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900> + { + // Use values from tuning if a specialization exists, otherwise pick Policy900 + template + static auto select_agent_policy(int) + -> AgentHistogramPolicy; + + template + static auto select_agent_policy(long) -> typename Policy900::AgentHistogramPolicyT; + + using AgentHistogramPolicyT = + decltype(select_agent_policy< + sm100_tuning()>>( + 0)); + }; + + using MaxPolicy = Policy1000; }; } // namespace histogram } // namespace detail