From de1c34096276955aab36e8c93f23f8d9dae1510a Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Fri, 31 Jan 2025 07:35:47 +0100 Subject: [PATCH] Add b200 policies for reduce (#3612) * Add b200 policies for cub.device.reduce.sum * Add b200 policies for reduce.min --------- Co-authored-by: Giannis Gonidelis --- .../device/dispatch/tuning/tuning_reduce.cuh | 321 +++++++++++++++++- 1 file changed, 320 insertions(+), 1 deletion(-) diff --git a/cub/cub/device/dispatch/tuning/tuning_reduce.cuh b/cub/cub/device/dispatch/tuning/tuning_reduce.cuh index d4719820752..94c1b7127d6 100644 --- a/cub/cub/device/dispatch/tuning/tuning_reduce.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_reduce.cuh @@ -75,6 +75,299 @@ CUB_RUNTIME_FUNCTION ReducePolicyWrapper MakeReducePolicyWrapper(Policy { return ReducePolicyWrapper{policy}; } +enum class offset_size /* size class of an offset type; produced by classify_offset_size() from sizeof(OffsetT) */ +{ + _4, + _8, + unknown +}; +enum class op_type /* reduction operator category; produced by classify_op() from the is_plus / is_min_or_max traits */ +{ + plus, + min_or_max, + unknown +}; +enum class accum_size /* size class of an accumulator type; produced by classify_accum_size() from sizeof(AccumT) */ +{ + _1, + _2, + _4, + _8, + _16, + unknown +}; +template +constexpr accum_size classify_accum_size() /* maps sizeof(AccumT) of 1, 2, 4, 8 or 16 to the matching enumerator; any other size yields unknown */ +{ + return sizeof(AccumT) == 1 ? accum_size::_1 + : sizeof(AccumT) == 2 ? accum_size::_2 + : sizeof(AccumT) == 4 ? accum_size::_4 + : sizeof(AccumT) == 8 ? accum_size::_8 + : sizeof(AccumT) == 16 + ? accum_size::_16 + : accum_size::unknown; +} +template +constexpr offset_size classify_offset_size() /* maps sizeof(OffsetT) of 4 or 8 to the matching enumerator; any other size yields unknown */ +{ + return sizeof(OffsetT) == 4 ? offset_size::_4 : sizeof(OffsetT) == 8 ? offset_size::_8 : offset_size::unknown; +} + +template +struct is_plus /* trait: value is false here and true in the ::cuda::std::plus specialization that follows */ +{ + static constexpr bool value = false; +}; + +template +struct is_plus<::cuda::std::plus> +{ + static constexpr bool value = true; +}; +template +struct is_min_or_max /* trait: value is false here and true in the ::cuda::minimum / ::cuda::maximum specializations that follow */ +{ + static constexpr bool value = false; +}; +template +struct is_min_or_max<::cuda::minimum> +{ + static constexpr bool value = true; +}; +template +struct is_min_or_max<::cuda::maximum> +{ + static constexpr bool value = true; +}; + +template +constexpr op_type classify_op() /* plus when is_plus holds, otherwise min_or_max when is_min_or_max holds, otherwise unknown */ +{ + return is_plus::value + ? 
op_type::plus + : (is_min_or_max::value ? op_type::min_or_max : op_type::unknown); +} + +template (), + offset_size OffsetSize = classify_offset_size(), + accum_size AccumSize = classify_accum_size()> +struct sm100_tuning; /* declared only; the specializations below supply items, threads and items_per_vec_load. Their leading comment reads ipt_<items>.tpb_<threads>.ipv_<items_per_vec_load>, followed by numbers that are presumably benchmark speedup statistics (TODO confirm) */ + +// sum +template +struct sm100_tuning +{ + // todo(gonidelis): Very low performance, we need more runs. + // ipt_16.tpb_256.ipv_2 1.001174 1.0 1.001044 1.004175 + static constexpr int items = 16; + static constexpr int threads = 256; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_18.tpb_288.ipv_2 1.032068 0.997167 1.028244 1.115809 + static constexpr int items = 18; + static constexpr int threads = 288; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_15.tpb_960.ipv_1 1.040241 0.988042 1.038795 1.167139 + static constexpr int items = 15; + static constexpr int threads = 960; + static constexpr int items_per_vec_load = 1; +}; + +template +struct sm100_tuning +{ + // ipt_15.tpb_512.ipv_2 1.019887 1.0 1.017636 1.058036 + static constexpr int items = 15; + static constexpr int threads = 512; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_14.tpb_288.ipv_2 1.036897 1.000000 1.032813 1.13125 + static constexpr int items = 14; + static constexpr int threads = 288; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_12.tpb_224.ipv_2 1.032496 1.000000 1.028899 1.115596 + static constexpr int items = 12; + static constexpr int threads = 224; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_14.tpb_288.ipv_1 1.050725 1.000000 1.048286 1.181818 + static constexpr int items = 14; + static constexpr int threads = 288; + static constexpr int items_per_vec_load = 1; +}; + +template +struct sm100_tuning +{ + // ipt_15.tpb_512.ipv_1 1.019414 1.000000 1.017218 1.057143 + static constexpr int items = 15; + static constexpr int 
threads = 512; + static constexpr int items_per_vec_load = 1; +}; + +template +struct sm100_tuning +{ + // ipt_16.tpb_512.ipv_2 1.061295 1.000000 1.065478 1.167139 + static constexpr int items = 16; + static constexpr int threads = 512; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_16.tpb_640.ipv_1 1.017834 1.000000 1.015835 1.057092 + static constexpr int items = 16; + static constexpr int threads = 640; + static constexpr int items_per_vec_load = 1; +}; + +// min or max +template +struct sm100_tuning +{ + // ipt_16.tpb_128.ipv_2 1.021369 0.998557 1.019009 1.077479 + static constexpr int items = 16; + static constexpr int threads = 128; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_16.tpb_256.ipv_2 1.038750 1.0 1.034382 1.117647 + static constexpr int items = 16; + static constexpr int threads = 256; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_12.tpb_448.ipv_1 1.037834 1.000000 1.036212 1.144847 + static constexpr int items = 12; + static constexpr int threads = 448; + static constexpr int items_per_vec_load = 1; +}; + +template +struct sm100_tuning +{ + // ipt_15.tpb_512.ipv_2 1.020165 1.0 1.018162 1.058036 + static constexpr int items = 15; + static constexpr int threads = 512; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_16.tpb_320.ipv_2 1.009217 1.0 1.008197 1.032787 + static constexpr int items = 16; + static constexpr int threads = 320; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_18.tpb_448.ipv_2 1.032745 0.966480 1.032123 1.162011 + static constexpr int items = 18; + static constexpr int threads = 448; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_15.tpb_512.ipv_2 1.019901 1.0 1.017648 1.058036 + static constexpr int items = 15; + static constexpr int 
threads = 512; + static constexpr int items_per_vec_load = 2; +}; + +// same as base, so fall back to Policy600 +// template +// struct sm100_tuning {}; + +template +struct sm100_tuning +{ + // ipt_16.tpb_224.ipv_2 1.031922 0.997989 1.028396 1.115596 + static constexpr int items = 16; + static constexpr int threads = 224; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_14.tpb_416.ipv_1 1.047490 1.000000 1.045455 1.181818 + static constexpr int items = 14; + static constexpr int threads = 416; + static constexpr int items_per_vec_load = 1; +}; + +template +struct sm100_tuning +{ + // ipt_21.tpb_384.ipv_2 1.021487 1.0 1.019033 1.057143 + static constexpr int items = 21; + static constexpr int threads = 384; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_17.tpb_512.ipv_2 1.003412 0.980713 1.003111 1.031730 + static constexpr int items = 17; + static constexpr int threads = 512; + static constexpr int items_per_vec_load = 2; +}; + +template +struct sm100_tuning +{ + // ipt_18.tpb_448.ipv_1 1.023427 1.000000 1.022287 1.083333 + static constexpr int items = 18; + static constexpr int threads = 448; + static constexpr int items_per_vec_load = 1; +}; + +template +struct sm100_tuning +{ + // ipt_16.tpb_320.ipv_2 1.018602 1.0 1.016518 1.059821 + static constexpr int items = 16; + static constexpr int threads = 320; + static constexpr int items_per_vec_load = 1; /* NOTE(review): the label above says ipv_2, yet this is 1; every other tuning in this patch matches its label -- confirm which value is intended */ +}; template struct policy_hub @@ -117,7 +410,33 @@ struct policy_hub using SegmentedReducePolicy = ReducePolicy; }; - using MaxPolicy = Policy600; + struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy600> + { + // Use values from tuning if a specialization exists, otherwise pick Policy600 + template + static auto select_agent_policy(int) + -> AgentReducePolicy; + // use Policy600 as DefaultPolicy + template + static auto select_agent_policy(long) -> typename Policy600::ReducePolicy; + + using ReducePolicy = + 
decltype(select_agent_policy(), + classify_offset_size(), + classify_accum_size()>>(0)); /* 0 is an int literal, so overload resolution prefers select_agent_policy(int) and falls back to the long overload only if the int overload drops out */ + + using SingleTilePolicy = ReducePolicy; + using SegmentedReducePolicy = ReducePolicy; + }; + + using MaxPolicy = Policy1000; }; } // namespace reduce } // namespace detail