diff --git a/cub/cub/device/dispatch/dispatch_select_if.cuh b/cub/cub/device/dispatch/dispatch_select_if.cuh index c6f2acc993a..41bd847172f 100644 --- a/cub/cub/device/dispatch/dispatch_select_if.cuh +++ b/cub/cub/device/dispatch/dispatch_select_if.cuh @@ -421,11 +421,13 @@ template , - cub::detail::value_t, - detail::select::per_partition_offset_t, - MayAlias, - KeepRejects>> + typename PolicyHub = detail::select::policy_hub< + cub::detail::value_t, + cub::detail::value_t, + detail::select::per_partition_offset_t, + MayAlias, + KeepRejects, + /* IsUnique */ ::cuda::std::is_same>::value>> struct DispatchSelectIf { /****************************************************************************** diff --git a/cub/cub/device/dispatch/tuning/tuning_select_if.cuh b/cub/cub/device/dispatch/tuning/tuning_select_if.cuh index d962a5f6e76..3a39b2e2954 100644 --- a/cub/cub/device/dispatch/tuning/tuning_select_if.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_select_if.cuh @@ -92,7 +92,7 @@ enum class input_size template struct sm80_tuning; -// select::if +// select::if / select::unique template struct sm80_tuning { @@ -145,7 +145,7 @@ struct sm80_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4, {}; #endif -// select::flagged +// select::flagged[If] template struct sm80_tuning { @@ -307,7 +307,7 @@ struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4 template struct sm90_tuning; -// select::if +// select::if / select::unique template struct sm90_tuning { @@ -360,7 +360,7 @@ struct sm90_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4, {}; #endif -// select::flagged +// select::flagged[If] template struct sm90_tuning { @@ -519,12 +519,12 @@ struct sm90_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4 {}; #endif -template -struct sm100_tuning; - // select::if +template +struct sm100_select_if_tuning; + template -struct sm100_tuning +struct sm100_select_if_tuning { // trp_0.ld_0.ipt_22.tpb_384.ns_0.dcid_2.l2w_915 1.099232 0.980183 1.096778 1.545455 static constexpr int threads = 384; @@ -535,7 +535,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_if_tuning { // trp_1.ld_0.ipt_20.tpb_448.ns_596.dcid_6.l2w_295 1.214635 1.001421 1.207023 1.307692 static constexpr int threads = 448; @@ -546,7 +546,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_if_tuning { // trp_1.ld_0.ipt_20.tpb_256.ns_516.dcid_7.l2w_685 1.065598 0.937984 1.067343 1.452153 static constexpr int threads = 256; @@ -557,7 +557,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_if_tuning { // trp_1.ld_0.ipt_20.tpb_384.ns_1060.dcid_5.l2w_375 1.109871 0.973142 1.105415 1.459135 static constexpr int threads = 384; @@ -568,7 +568,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_if_tuning { // trp_1.ld_0.ipt_15.tpb_384.ns_1508.dcid_5.l2w_585 1.201993 0.920103 1.185134 1.441805 static constexpr int threads = 384; @@ -579,7 +579,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_if_tuning { // trp_1.ld_0.ipt_19.tpb_512.ns_928.dcid_7.l2w_770 1.258815 1.000000 1.235251 1.444884 static constexpr int threads = 512; @@ -591,12 +591,11 @@ struct sm100_tuning -// struct sm100_tuning +// struct sm100_select_if_tuning // {}; template -struct sm100_tuning +struct sm100_select_if_tuning { // trp_1.ld_0.ipt_23.tpb_384.ns_1140.dcid_7.l2w_520 1.081506 0.955298 1.088848 1.248971 static constexpr int threads = 384; @@ -609,7 +608,7 @@ struct sm100_tuning -// struct sm100_tuning<__int128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +// struct sm100_select_if_tuning<__int128_t, primitive::no, input_size::_16> // { // // static constexpr int threads = 512; // // static constexpr int nominal_4b_items = 5; @@ -620,7 +619,7 @@ struct sm100_tuning -// struct sm100_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +// struct sm100_select_if_tuning<__uint128_t, primitive::no, input_size::_16> // { // // static constexpr int threads = 512; // // static constexpr int nominal_4b_items = 5; @@ -631,9 +630,12 @@ struct sm100_tuning +struct sm100_select_flagged_tuning; + template -struct sm100_tuning +struct sm100_select_flagged_tuning { // trp_0.ld_0.ipt_20.tpb_896.ns_84.dcid_7.l2w_480 1.254262 0.846154 1.222437 1.462665 static constexpr int threads = 896; @@ -644,7 +646,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_flagged_tuning { // trp_0.ld_0.ipt_20.tpb_1024.ns_360.dcid_6.l2w_380 1.274174 0.748441 1.227123 1.610039 static constexpr int threads = 1024; @@ -655,7 +657,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_flagged_tuning { // trp_0.ld_0.ipt_22.tpb_256.ns_1292.dcid_5.l2w_750 1.283400 1.002841 1.267822 1.445913 static constexpr int threads = 256; @@ -666,7 +668,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_flagged_tuning { // trp_1.ld_0.ipt_20.tpb_448.ns_136.dcid_2.l2w_760 1.318819 0.994090 1.289173 1.551415 static constexpr int threads = 448; @@ -677,7 +679,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_flagged_tuning { // trp_0.ld_0.ipt_14.tpb_512.ns_844.dcid_6.l2w_675 1.207911 1.068001 1.208890 1.455636 static constexpr int threads = 512; @@ -688,7 +690,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_flagged_tuning { // trp_1.ld_0.ipt_14.tpb_384.ns_524.dcid_7.l2w_635 1.256212 1.004808 1.241086 1.373337 static constexpr int threads = 384; @@ -699,7 +701,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_flagged_tuning { // trp_0.ld_1.ipt_22.tpb_320.ns_660.dcid_7.l2w_1030 1.162087 0.997167 1.154955 1.395010 static constexpr int threads = 320; @@ -710,7 +712,7 @@ struct sm100_tuning -struct sm100_tuning +struct sm100_select_flagged_tuning { // trp_1.ld_1.ipt_21.tpb_384.ns_1316.dcid_5.l2w_990 1.221365 1.019231 1.213141 1.372951 static constexpr int threads = 384; @@ -723,7 +725,7 @@ struct sm100_tuning -// struct sm100_tuning<__int128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +// struct sm100_select_flagged_tuning<__int128_t, primitive::no, input_size::_16> // { // static constexpr int threads = 512; // static constexpr int nominal_4b_items = 3; @@ -734,7 +736,7 @@ struct sm100_tuning -// struct sm100_tuning<__uint128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +// struct sm100_select_flagged_tuning<__uint128_t, primitive::no, input_size::_16> // { // static constexpr int threads = 512; // static constexpr int nominal_4b_items = 3; @@ -787,7 +789,7 @@ constexpr may_alias should_alias() return Alias ? may_alias::yes : may_alias::no; } -template +template struct policy_hub { template @@ -864,14 +866,16 @@ struct policy_hub template static auto select_agent_policy100(long) -> typename Policy900::SelectIfPolicyT; - using SelectIfPolicyT = - decltype(select_agent_policy100(), - are_rejects_kept(), - classify_offset_size(), - is_primitive(), - classify_input_size(), - should_alias()>>(0)); + using tuning = ::cuda::std::_If< + is_flagged() == flagged::yes, + sm100_select_flagged_tuning(), classify_input_size(), should_alias()>, + //::cuda::std::_If< + // IsUnique, + // sm100_select_unique_tuning<>, + sm100_select_if_tuning(), classify_input_size(), should_alias()> + //> + >; + using SelectIfPolicyT = decltype(select_agent_policy100(0)); }; using MaxPolicy = Policy1000;