diff --git a/cub/cub/device/dispatch/tuning/tuning_run_length_encode.cuh b/cub/cub/device/dispatch/tuning/tuning_run_length_encode.cuh index a2d62e6ab2b..7c086b13d0d 100644 --- a/cub/cub/device/dispatch/tuning/tuning_run_length_encode.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_run_length_encode.cuh @@ -232,6 +232,75 @@ struct sm90_tuning(), + primitive_key PrimitiveKey = is_primitive_key(), + length_size LengthSize = classify_length_size(), + key_size KeySize = classify_key_size()> +struct sm100_tuning; + +template +struct sm100_tuning +{ + // ipt_14.tpb_256.trp_0.ld_1.ns_468.dcid_7.l2w_300 1.202228 1.126160 1.197973 1.307692 + static constexpr int threads = 256; + static constexpr int items = 14; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_CA; + using delay_constructor = detail::exponential_backon_constructor_t<468, 300>; +}; + +template +struct sm100_tuning +{ + // ipt_14.tpb_224.trp_0.ld_0.ns_376.dcid_7.l2w_420 1.123754 1.002404 1.113839 1.274882 + static constexpr int threads = 224; + static constexpr int items = 14; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = detail::exponential_backon_constructor_t<376, 420>; +}; + +template +struct sm100_tuning +{ + // ipt_14.tpb_256.trp_0.ld_1.ns_956.dcid_7.l2w_70 1.134395 1.071951 1.137008 1.169419 + static constexpr int threads = 256; + static constexpr int items = 14; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_CA; + using delay_constructor = detail::exponential_backon_constructor_t<956, 70>; +}; + +template +struct sm100_tuning +{ + // ipt_9.tpb_224.trp_1.ld_0.ns_188.dcid_2.l2w_765 1.100140 1.020069 1.116462 1.345506 + static constexpr int threads = 224; + static constexpr int items = 9; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = detail::exponential_backoff_constructor_t<188, 765>; +}; + +// TODO(gonidelis): Tune for I128. +#if CUB_IS_INT128_ENABLED +// template +// struct sm100_tuning +// { +// static constexpr int threads = 128; +// static constexpr int items = 11; +// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; +// using delay_constructor = detail::fixed_delay_constructor_t<428, 930>; +// }; + +// template +// struct sm100_tuning +// : sm100_tuning +// {}; +#endif + // this policy is passed to DispatchReduceByKey template struct policy_hub @@ -258,7 +327,6 @@ struct policy_hub default_reduce_by_key_delay_constructor_t>; }; - // SM35 struct Policy350 : DefaultPolicy , ChainedPolicy<350, Policy350, Policy350> @@ -276,25 +344,39 @@ struct policy_hub template static auto select_agent_policy(long) -> typename DefaultPolicy::ReduceByKeyPolicyT; - // SM80 struct Policy800 : ChainedPolicy<800, Policy800, Policy350> { - using ReduceByKeyPolicyT = decltype(select_agent_policy>(0)); + using ReduceByKeyPolicyT = decltype(select_agent_policy>(0)); }; - // SM86 struct Policy860 : DefaultPolicy , ChainedPolicy<860, Policy860, Policy800> {}; - // SM90 struct Policy900 : ChainedPolicy<900, Policy900, Policy860> { - using ReduceByKeyPolicyT = decltype(select_agent_policy>(0)); + using ReduceByKeyPolicyT = decltype(select_agent_policy>(0)); + }; + + struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900> + { + // Use values from tuning if a specialization exists, otherwise pick Policy900 + template + static auto select_agent_policy100(int) + -> AgentReduceByKeyPolicy; + template + static auto select_agent_policy100(long) -> typename Policy900::ReduceByKeyPolicyT; + + using ReduceByKeyPolicyT = decltype(select_agent_policy100>(0)); }; - using MaxPolicy = Policy900; + using MaxPolicy = Policy1000; }; } // namespace encode @@ -431,6 +513,86 @@ struct sm90_tuning(), + primitive_key PrimitiveKey = is_primitive_key(), + length_size LengthSize = classify_length_size(), + key_size KeySize = classify_key_size()> +struct sm100_tuning; + +template +struct sm100_tuning +{ + // ipt_20.tpb_224.trp_1.ts_0.ld_1.ns_64.dcid_2.l2w_315 1.119878 1.003690 1.130067 1.338983 + static constexpr int threads = 224; + static constexpr int items = 20; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr bool store_with_time_slicing = false; + static constexpr CacheLoadModifier load_modifier = LOAD_CA; + using delay_constructor = detail::exponential_backoff_constructor_t<64, 315>; +}; + +template +struct sm100_tuning +{ + // ipt_20.tpb_224.trp_1.ts_0.ld_0.ns_116.dcid_7.l2w_340 1.146528 1.072769 1.152390 1.333333 + static constexpr int threads = 224; + static constexpr int items = 20; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr bool store_with_time_slicing = false; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = detail::exponential_backon_constructor_t<116, 340>; +}; + +template +struct sm100_tuning +{ + // ipt_13.tpb_224.trp_0.ts_0.ld_0.ns_252.dcid_2.l2w_470 1.113202 1.003690 1.133114 1.349296 + static constexpr int threads = 224; + static constexpr int items = 13; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr bool store_with_time_slicing = false; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = detail::exponential_backoff_constructor_t<252, 470>; +}; + +template +struct sm100_tuning +{ + // ipt_15.tpb_256.trp_1.ts_0.ld_0.ns_28.dcid_2.l2w_520 1.114944 1.033189 1.122360 1.252083 + static constexpr int threads = 256; + static constexpr int items = 15; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr bool store_with_time_slicing = false; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = detail::exponential_backoff_constructor_t<28, 520>; +}; +// Fall back to Policy900 for double, because that one performs better than the above tuning (same key_size) +// TODO(bgruber): in C++20 put a requires(!std::is_same_v) onto the above tuning and delete this one +template +struct sm100_tuning + : sm90_tuning +{}; + +// TODO(gonidelis): Tune for I128. +#if CUB_IS_INT128_ENABLED +// template +// struct sm100_tuning +// { +// static constexpr int threads = 288; +// static constexpr int items = 9; +// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; +// static constexpr bool store_with_time_slicing = false; +// using delay_constructor = detail::fixed_delay_constructor_t<484, 1150>; +// }; + +// template +// struct sm100_tuning +// : sm100_tuning +// {}; +#endif + template struct policy_hub { @@ -451,7 +613,6 @@ struct policy_hub default_reduce_by_key_delay_constructor_t>; }; - // SM35 struct Policy350 : DefaultPolicy // TODO(bgruber): I think we want `LengthT` instead of `int` , ChainedPolicy<350, Policy350, Policy350> @@ -471,25 +632,40 @@ struct policy_hub static auto select_agent_policy(long) -> typename DefaultPolicy::RleSweepPolicyT; - // SM80 struct Policy800 : ChainedPolicy<800, Policy800, Policy350> { - using RleSweepPolicyT = decltype(select_agent_policy>(0)); + using RleSweepPolicyT = decltype(select_agent_policy>(0)); }; - // SM86 struct Policy860 : DefaultPolicy // TODO(bgruber): I think we want `LengthT` instead of `int` , ChainedPolicy<860, Policy860, Policy800> {}; - // SM90 struct Policy900 : ChainedPolicy<900, Policy900, Policy860> { - using RleSweepPolicyT = decltype(select_agent_policy>(0)); + using RleSweepPolicyT = decltype(select_agent_policy>(0)); + }; + + struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900> + { + // Use values from tuning if a specialization exists, otherwise pick Policy900 + template + static auto select_agent_policy100(int) + -> AgentRlePolicy; + template + static auto select_agent_policy100(long) -> typename Policy900::RleSweepPolicyT; + + using RleSweepPolicyT = decltype(select_agent_policy100>(0)); }; - using MaxPolicy = Policy900; + using MaxPolicy = Policy1000; }; } // namespace non_trivial_runs } // namespace rle