Patch summary: SM100 tunings for reduce-by-key
(cub/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh).

Hunk 1 declares `sm100_tuning` and adds sixteen definitions of it, four under each of the
"8-bit key", "16-bit key", "32-bit key" and "64-bit key" comments. Each definition sets
`items`, `threads`, `load_algorithm`, `delay_constructor` and `load_modifier`, and is preceded
by the tuning string it mirrors:
  ipt_N          -> items = N
  tpb_N          -> threads = N
  trp_0 / trp_1  -> BLOCK_LOAD_DIRECT / BLOCK_LOAD_WARP_TRANSPOSE
  ld_0 / ld_1    -> LOAD_DEFAULT / LOAD_CA
  ns_A, l2w_B    -> delay_constructor = ..._constructor_t<A, B>
  dcid_1/2/4/5/7 -> fixed_delay / exponential_backoff / exponential_backoff_jitter_window /
                    exponential_backon_jitter_window / exponential_backon
The four trailing numbers of each tuning string are carried over unchanged.

Hunk 2 adds `Policy1000`, chained after `Policy900`, and makes it `MaxPolicy`.

Review notes:
  * 8-bit key, second tuning: the delay argument was 224 although the tuning string says
    ns_244; it is now 244, in line with every other entry.
  * 16-bit key, first tuning: the lone `detail::` qualifier on the delay constructor was
    dropped so that all entries spell it the same way.
  * The commented-out 128-bit-key placeholder named `sm90_tuning`; it now names `sm100_tuning`.
  * NOTE(review): angle-bracket lists that start with an identifier (template parameter lists,
    specialization arguments, the AgentReduceByKeyPolicy arguments) appear to have been lost
    from this copy of the patch. They are left exactly as found - restore them from the
    original commit before applying. Hunk line counts are unchanged.

diff --git a/cub/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh b/cub/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh
index a5ad19df8cc..db56906e314 100644
--- a/cub/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh
+++ b/cub/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh
@@ -607,6 +607,252 @@ struct sm90_tuning;
 };
+template (),
+          primitive_accum PrimitiveAccum = is_primitive_accum(),
+          key_size KeySize = classify_key_size(),
+          accum_size AccumSize = classify_accum_size()>
+struct sm100_tuning;
+
+// 8-bit key
+template
+struct sm100_tuning
+{
+  // ipt_13.tpb_576.trp_0.ld_1.ns_2044.dcid_5.l2w_240 1.161888 0.848558 1.134941 1.299109
+  static constexpr int items = 13;
+  static constexpr int threads = 576;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+  using delay_constructor = exponential_backon_jitter_window_constructor_t<2044, 240>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_CA;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_10.tpb_224.trp_0.ld_0.ns_244.dcid_4.l2w_390 1.313932 1.260540 1.319588 1.427374
+  static constexpr int items = 10;
+  static constexpr int threads = 224;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+  using delay_constructor = exponential_backoff_jitter_window_constructor_t<244, 390>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_14.tpb_128.trp_0.ld_0.ns_248.dcid_2.l2w_285 1.118109 1.051534 1.134336 1.326788
+  static constexpr int items = 14;
+  static constexpr int threads = 128;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+  using delay_constructor = exponential_backoff_constructor_t<248, 285>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_19.tpb_128.trp_1.ld_0.ns_132.dcid_1.l2w_540 1.113820 1.002404 1.105014 1.202296
+  static constexpr int items = 19;
+  static constexpr int threads = 128;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+  using delay_constructor = fixed_delay_constructor_t<132, 540>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+// todo(gonidelis): Add tunings for I128.
+// template
+// struct sm100_tuning
+// {
+//   static constexpr int threads = 128;
+//   static constexpr int items = 11;
+//   static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+//   using delay_constructor = detail::no_delay_constructor_t<1100>;
+// };
+
+// 16-bit key
+template
+struct sm100_tuning
+{
+  // ipt_14.tpb_128.trp_1.ld_0.ns_164.dcid_2.l2w_290 1.239579 1.119705 1.239111 1.313112
+  static constexpr int items = 14;
+  static constexpr int threads = 128;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+  using delay_constructor = exponential_backoff_constructor_t<164, 290>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_14.tpb_256.trp_1.ld_0.ns_180.dcid_2.l2w_975 1.145635 1.012658 1.139956 1.251546
+  static constexpr int items = 14;
+  static constexpr int threads = 256;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+  using delay_constructor = exponential_backoff_constructor_t<180, 975>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_11.tpb_256.trp_0.ld_0.ns_224.dcid_2.l2w_550 1.066293 1.000109 1.073092 1.181818
+  static constexpr int items = 11;
+  static constexpr int threads = 256;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+  using delay_constructor = exponential_backoff_constructor_t<224, 550>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_10.tpb_160.trp_1.ld_0.ns_156.dcid_1.l2w_725 1.045007 1.002105 1.049690 1.141827
+  static constexpr int items = 10;
+  static constexpr int threads = 160;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+  using delay_constructor = fixed_delay_constructor_t<156, 725>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+// todo(gonidelis): Add tunings for I128.
+// template
+// struct sm100_tuning
+// {
+//   static constexpr int threads = 128;
+//   static constexpr int items = 11;
+//   static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+//   using delay_constructor = detail::no_delay_constructor_t<1100>;
+// };
+
+// 32-bit key
+template
+struct sm100_tuning
+{
+  // ipt_10.tpb_224.trp_0.ld_0.ns_324.dcid_2.l2w_285 1.157217 1.073724 1.166510 1.356940
+  static constexpr int items = 10;
+  static constexpr int threads = 224;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+  using delay_constructor = exponential_backoff_constructor_t<324, 285>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_11.tpb_256.trp_0.ld_0.ns_1984.dcid_5.l2w_115 1.214155 1.128842 1.214093 1.364476
+  static constexpr int items = 11;
+  static constexpr int threads = 256;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+  using delay_constructor = exponential_backon_jitter_window_constructor_t<1984, 115>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_14.tpb_224.trp_1.ld_0.ns_476.dcid_5.l2w_1005 1.187378 1.119705 1.185397 1.258420
+
+  static constexpr int items = 14;
+  static constexpr int threads = 224;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+  using delay_constructor = exponential_backon_jitter_window_constructor_t<476, 1005>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_10.tpb_256.trp_1.ld_0.ns_1868.dcid_7.l2w_145 1.142915 1.020581 1.137459 1.237913
+  static constexpr int items = 10;
+  static constexpr int threads = 256;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+  using delay_constructor = exponential_backon_constructor_t<1868, 145>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+// todo(gonidelis): Add tunings for I128.
+// template
+// struct sm100_tuning
+// {
+//   static constexpr int threads = 128;
+//   static constexpr int items = 11;
+//   static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+//   using delay_constructor = detail::no_delay_constructor_t<1100>;
+// };
+
+// 64-bit key
+template
+struct sm100_tuning
+{
+  // ipt_9.tpb_224.trp_1.ld_0.ns_1940.dcid_5.l2w_460 1.157294 1.075650 1.153566 1.250729
+  static constexpr int items = 9;
+  static constexpr int threads = 224;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+  using delay_constructor = exponential_backon_jitter_window_constructor_t<1940, 460>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_11.tpb_224.trp_1.ld_1.ns_392.dcid_2.l2w_550 1.104034 1.007212 1.099543 1.220401
+  static constexpr int items = 11;
+  static constexpr int threads = 224;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+  using delay_constructor = exponential_backoff_constructor_t<392, 550>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_CA;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_9.tpb_224.trp_1.ld_0.ns_244.dcid_2.l2w_475 1.130098 1.000000 1.130661 1.215722
+  static constexpr int items = 9;
+  static constexpr int threads = 224;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+  using delay_constructor = exponential_backoff_constructor_t<244, 475>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+template
+struct sm100_tuning
+{
+  // ipt_9.tpb_224.trp_1.ld_0.ns_196.dcid_2.l2w_340 1.272056 1.142857 1.262499 1.352941
+  static constexpr int items = 9;
+  static constexpr int threads = 224;
+  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+  using delay_constructor = exponential_backoff_constructor_t<196, 340>;
+  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
+};
+
+// todo(gonidelis): Add tunings for I128.
+// template
+// struct sm100_tuning
+// {
+//   static constexpr int threads = 128;
+//   static constexpr int items = 11;
+//   static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+//   using delay_constructor = detail::no_delay_constructor_t<1125>;
+// };
+
+// todo(gonidelis): Add tunings for 128-bit key.
+// 128-bit key
+// template
+// struct sm100_tuning
+// {
+//   static constexpr int threads = 128;
+//   static constexpr int items = 11;
+//   static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+//   using delay_constructor = detail::no_delay_constructor_t<1125>;
+// };
+
 template
 struct policy_hub
 {
@@ -668,7 +914,25 @@ struct policy_hub
       decltype(select_agent_policy()>>(0));
   };
-  using MaxPolicy = Policy900;
+  struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
+  {
+    // Use values from tuning if a specialization exists, otherwise fall back to Policy900::ReduceByKeyPolicyT
+    template
+    static auto select_agent_policy(int)
+      -> AgentReduceByKeyPolicy;
+
+    template
+    static auto select_agent_policy(long) -> typename Policy900::ReduceByKeyPolicyT;
+
+    using ReduceByKeyPolicyT =
+      decltype(select_agent_policy()>>(0));
+  };
+  using MaxPolicy = Policy1000;
 };
 } // namespace reduce_by_key
 } // namespace detail