From 16386feafdb68c462565403b0f337975ee57479f Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Tue, 4 Feb 2025 18:22:04 +0100 Subject: [PATCH] Add b200 policies for cub.select.unique_by_key (#3557) Co-authored-by: Giannis Gonidelis --- .../dispatch/tuning/tuning_unique_by_key.cuh | 267 +++++++++++++++++- 1 file changed, 266 insertions(+), 1 deletion(-) diff --git a/cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh b/cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh index 0c6b717de2c..f28b3b737eb 100644 --- a/cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh @@ -522,6 +522,254 @@ struct sm90_tuning; }; +template (), + primitive_val PrimitiveAccum = is_primitive_val(), + key_size KeySize = classify_key_size(), + val_size AccumSize = classify_val_size()> +struct sm100_tuning; + +// 8-bit key +template +struct sm100_tuning +{ + // ipt_12.tpb_512.trp_0.ld_0.ns_948.dcid_5.l2w_955 1.121279 1.000000 1.114566 1.43765 + static constexpr int threads = 512; + static constexpr int items = 12; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_jitter_window_constructor_t<948, 955>; +}; + +template +struct sm100_tuning +{ + // ipt_14.tpb_512.trp_0.ld_0.ns_1228.dcid_7.l2w_320 1.151229 1.007229 1.151131 1.443520 + static constexpr int threads = 512; + static constexpr int items = 14; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_constructor_t<1228, 320>; +}; + +template +struct sm100_tuning +{ + // ipt_14.tpb_512.trp_0.ld_0.ns_2016.dcid_7.l2w_620 1.165300 1.095238 1.164478 1.266667 + static constexpr int threads = 512; + static constexpr int items = 14; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_constructor_t<2016, 620>; +}; + +template +struct sm100_tuning +{ + // ipt_10.tpb_384.trp_0.ld_0.ns_1728.dcid_5.l2w_980 1.118716 0.997167 1.116537 1.400000 + static constexpr int threads = 384; + static constexpr int items = 10; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_jitter_window_constructor_t<1728, 980>; +}; + +// TODO(gonidelis): Tune for I128. +#if CUB_IS_INT128_ENABLED +// template +// struct sm100_tuning +// { +// static constexpr int threads = 288; +// static constexpr int items = 7; +// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; +// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; +// using delay_constructor = fixed_delay_constructor_t<344, 1165>; +// }; +#endif + +// 16-bit key +template +struct sm100_tuning +{ + // ipt_14.tpb_512.trp_0.ld_0.ns_508.dcid_7.l2w_1020 1.171886 0.906530 1.157128 1.457933 + static constexpr int threads = 512; + static constexpr int items = 14; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_constructor_t<508, 1020>; +}; + +template +struct sm100_tuning +{ + // ipt_12.tpb_384.trp_0.ld_0.ns_928.dcid_7.l2w_605 1.166564 0.997579 1.154805 1.406709 + static constexpr int threads = 384; + static constexpr int items = 12; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_constructor_t<928, 605>; +}; + +template +struct sm100_tuning +{ + // ipt_11.tpb_384.trp_0.ld_1.ns_1620.dcid_7.l2w_810 1.144483 1.011085 1.152798 1.393750 + static constexpr int threads = 384; + static constexpr int items = 11; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_CA; + using delay_constructor = exponential_backon_constructor_t<1620, 810>; +}; + +template +struct sm100_tuning +{ + // ipt_10.tpb_384.trp_0.ld_0.ns_1984.dcid_5.l2w_935 1.605554 1.177083 1.564488 1.946224 + static constexpr int threads = 384; + static constexpr int items = 10; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_jitter_window_constructor_t<1984, 935>; +}; + +// TODO(gonidelis): Tune for I128. +#if CUB_IS_INT128_ENABLED +// template +// struct sm100_tuning +// { +// static constexpr int threads = 224; +// static constexpr int items = 9; +// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; +// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; +// using delay_constructor = fixed_delay_constructor_t<424, 1055>; +// }; +#endif + +// 32-bit key +template +struct sm100_tuning +{ + // ipt_14.tpb_512.trp_0.ld_0.ns_1136.dcid_7.l2w_605 1.148057 0.848558 1.133064 1.451074 + static constexpr int threads = 512; + static constexpr int items = 14; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_constructor_t<1136, 605>; +}; + +template +struct sm100_tuning +{ + // ipt_11.tpb_384.trp_0.ld_0.ns_656.dcid_7.l2w_825 1.216312 1.090485 1.211800 1.535714 + static constexpr int threads = 384; + static constexpr int items = 11; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_constructor_t<656, 825>; +}; + +// todo(gonidelis): tuning performs very well for medium input size, regresses for large input sizes. +// find better tuning. +template +struct sm100_tuning + : sm90_tuning +{ + // // ipt_14.tpb_512.trp_0.ld_0.ns_408.dcid_7.l2w_960 1.136333 0.995833 1.144371 1.448687 + // static constexpr int threads = 512; + // static constexpr int items = 14; + // static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + // static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + // using delay_constructor = exponential_backon_constructor_t<408, 960>; +}; + +template +struct sm100_tuning +{ + // ipt_10.tpb_384.trp_0.ld_0.ns_1012.dcid_5.l2w_800 1.164713 1.014819 1.174307 1.526042 + static constexpr int threads = 384; + static constexpr int items = 10; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_jitter_window_constructor_t<1012, 800>; +}; + +// TODO(gonidelis): Tune for I128. +#if CUB_IS_INT128_ENABLED +// template +// struct sm100_tuning +// { +// static constexpr int threads = 384; +// static constexpr int items = 7; +// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; +// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; +// using delay_constructor = no_delay_constructor_t<1025>; +// }; +#endif + +// 64-bit key + +// todo(gonidelis): tuning regresses for large input sizes. find better tuning. +template +struct sm100_tuning + : sm90_tuning +{ + // // ipt_9.tpb_384.trp_0.ld_0.ns_1064.dcid_7.l2w_600 1.085831 0.972452 1.080521 1.397089 + // static constexpr int threads = 384; + // static constexpr int items = 9; + // static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + // static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + // using delay_constructor = exponential_backon_constructor_t<1064, 600>; +}; + +template +struct sm100_tuning +{ + // ipt_10.tpb_384.trp_0.ld_0.ns_864.dcid_5.l2w_1130 1.124095 0.985748 1.120262 1.391304 + static constexpr int threads = 384; + static constexpr int items = 10; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_jitter_window_constructor_t<864, 1130>; +}; + +template +struct sm100_tuning +{ + // ipt_10.tpb_384.trp_0.ld_0.ns_772.dcid_5.l2w_665 1.152243 1.019816 1.166636 1.517526 + static constexpr int threads = 384; + static constexpr int items = 10; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + using delay_constructor = exponential_backon_jitter_window_constructor_t<772, 665>; +}; + +// todo(gonidelis): tuning regresses for large input sizes. find better tuning. +template +struct sm100_tuning + : sm90_tuning +{ + // // ipt_7.tpb_576.trp_0.ld_0.ns_1132.dcid_5.l2w_1115 1.120721 0.977642 1.131594 1.449407 + // static constexpr int threads = 576; + // static constexpr int items = 7; + // static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + // static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; + // using delay_constructor = exponential_backon_jitter_window_constructor_t<1132, 1115>; +}; + +// TODO(gonidelis): Tune for I128. +#if CUB_IS_INT128_ENABLED +// template +// struct sm100_tuning +// { +// static constexpr int threads = 256; +// static constexpr int items = 9; +// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; +// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT; +// using delay_constructor = no_delay_constructor_t<1155>; +// }; +#endif + template struct policy_hub { @@ -575,7 +823,24 @@ struct policy_hub using UniqueByKeyPolicyT = decltype(select_agent_policy>(0)); }; - using MaxPolicy = Policy900; + struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900> + { + // Use values from tuning if a specialization exists, otherwise pick Policy900 + template + static auto select_agent_policy100(int) + -> AgentUniqueByKeyPolicy; + template + static auto select_agent_policy100(long) -> typename Policy900::UniqueByKeyPolicyT; + + using UniqueByKeyPolicyT = decltype(select_agent_policy100>(0)); + }; + + using MaxPolicy = Policy1000; }; } // namespace unique_by_key