Skip to content

Commit

Permalink
Refactor sm100 select tunings
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Feb 3, 2025
1 parent afeb37e commit b655860
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 43 deletions.
12 changes: 7 additions & 5 deletions cub/cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -421,11 +421,13 @@ template <typename InputIteratorT,
typename OffsetT,
bool KeepRejects,
bool MayAlias = false,
typename PolicyHub = detail::select::policy_hub<cub::detail::value_t<InputIteratorT>,
cub::detail::value_t<FlagsInputIteratorT>,
detail::select::per_partition_offset_t,
MayAlias,
KeepRejects>>
typename PolicyHub = detail::select::policy_hub<
cub::detail::value_t<InputIteratorT>,
cub::detail::value_t<FlagsInputIteratorT>,
detail::select::per_partition_offset_t,
MayAlias,
KeepRejects,
/* IsUnique */ ::cuda::std::is_same<EqualityOpT, ::cuda::std::equal_to<>>::value>>
struct DispatchSelectIf
{
/******************************************************************************
Expand Down
80 changes: 42 additions & 38 deletions cub/cub/device/dispatch/tuning/tuning_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ enum class input_size
template <class InputT, flagged, keep_rejects, offset_size OffsetSize, primitive, input_size InputSize>
struct sm80_tuning;

// select::if
// select::if / select::unique
template <class Input>
struct sm80_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1>
{
Expand Down Expand Up @@ -145,7 +145,7 @@ struct sm80_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4,
{};
#endif

// select::flagged
// select::flagged[If]
template <class Input>
struct sm80_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1>
{
Expand Down Expand Up @@ -307,7 +307,7 @@ struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4
template <class InputT, flagged, keep_rejects, offset_size OffsetSize, primitive, input_size InputSize>
struct sm90_tuning;

// select::if
// select::if / select::unique
template <class Input>
struct sm90_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1>
{
Expand Down Expand Up @@ -360,7 +360,7 @@ struct sm90_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4,
{};
#endif

// select::flagged
// select::flagged[If]
template <class Input>
struct sm90_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1>
{
Expand Down Expand Up @@ -519,12 +519,12 @@ struct sm90_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4
{};
#endif

template <class InputT, flagged, keep_rejects, offset_size OffsetSize, primitive, input_size InputSize, may_alias>
struct sm100_tuning;

// select::if
template <class InputT, primitive, input_size InputSize, may_alias>
struct sm100_select_if_tuning;

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1, may_alias::no>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_1, may_alias::no>
{
// trp_0.ld_0.ipt_22.tpb_384.ns_0.dcid_2.l2w_915 1.099232 0.980183 1.096778 1.545455
static constexpr int threads = 384;
Expand All @@ -535,7 +535,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1, may_alias::yes>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_1, may_alias::yes>
{
// trp_1.ld_0.ipt_20.tpb_448.ns_596.dcid_6.l2w_295 1.214635 1.001421 1.207023 1.307692
static constexpr int threads = 448;
Expand All @@ -546,7 +546,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_2, may_alias::no>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_2, may_alias::no>
{
// trp_1.ld_0.ipt_20.tpb_256.ns_516.dcid_7.l2w_685 1.065598 0.937984 1.067343 1.452153
static constexpr int threads = 256;
Expand All @@ -557,7 +557,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_2, may_alias::yes>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_2, may_alias::yes>
{
// trp_1.ld_0.ipt_20.tpb_384.ns_1060.dcid_5.l2w_375 1.109871 0.973142 1.105415 1.459135
static constexpr int threads = 384;
Expand All @@ -568,7 +568,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_4, may_alias::no>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_4, may_alias::no>
{
// trp_1.ld_0.ipt_15.tpb_384.ns_1508.dcid_5.l2w_585 1.201993 0.920103 1.185134 1.441805
static constexpr int threads = 384;
Expand All @@ -579,7 +579,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_4, may_alias::yes>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_4, may_alias::yes>
{
// trp_1.ld_0.ipt_19.tpb_512.ns_928.dcid_7.l2w_770 1.258815 1.000000 1.235251 1.444884
static constexpr int threads = 512;
Expand All @@ -591,12 +591,11 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi

// baseline remained the fastest, so fall back to previous tuning
// template <class Input>
// struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_8,
// may_alias::no>
// struct sm100_select_if_tuning<Input, primitive::yes, input_size::_8, may_alias::no>
// {};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_8, may_alias::yes>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_8, may_alias::yes>
{
// trp_1.ld_0.ipt_23.tpb_384.ns_1140.dcid_7.l2w_520 1.081506 0.955298 1.088848 1.248971
static constexpr int threads = 384;
Expand All @@ -609,7 +608,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
// TODO(gonidelis): Tune for I128.
#if CUB_IS_INT128_ENABLED
// template <>
// struct sm100_tuning<__int128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
// struct sm100_select_if_tuning<__int128_t, primitive::no, input_size::_16>
// {
// // static constexpr int threads = 512;
// // static constexpr int nominal_4b_items = 5;
Expand All @@ -620,7 +619,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
// };

// template <>
// struct sm100_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
// struct sm100_select_if_tuning<__uint128_t, primitive::no, input_size::_16>
// {
// // static constexpr int threads = 512;
// // static constexpr int nominal_4b_items = 5;
Expand All @@ -631,9 +630,12 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
// };
#endif

// select::flagged
// select::flagged[If]
template <class InputT, primitive, input_size, may_alias>
struct sm100_select_flagged_tuning;

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1, may_alias::no>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_1, may_alias::no>
{
// trp_0.ld_0.ipt_20.tpb_896.ns_84.dcid_7.l2w_480 1.254262 0.846154 1.222437 1.462665
static constexpr int threads = 896;
Expand All @@ -644,7 +646,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1, may_alias::yes>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_1, may_alias::yes>
{
// trp_0.ld_0.ipt_20.tpb_1024.ns_360.dcid_6.l2w_380 1.274174 0.748441 1.227123 1.610039
static constexpr int threads = 1024;
Expand All @@ -655,7 +657,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_2, may_alias::no>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_2, may_alias::no>
{
// trp_0.ld_0.ipt_22.tpb_256.ns_1292.dcid_5.l2w_750 1.283400 1.002841 1.267822 1.445913
static constexpr int threads = 256;
Expand All @@ -666,7 +668,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_2, may_alias::yes>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_2, may_alias::yes>
{
// trp_1.ld_0.ipt_20.tpb_448.ns_136.dcid_2.l2w_760 1.318819 0.994090 1.289173 1.551415
static constexpr int threads = 448;
Expand All @@ -677,7 +679,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_4, may_alias::no>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_4, may_alias::no>
{
// trp_0.ld_0.ipt_14.tpb_512.ns_844.dcid_6.l2w_675 1.207911 1.068001 1.208890 1.455636
static constexpr int threads = 512;
Expand All @@ -688,7 +690,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_4, may_alias::yes>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_4, may_alias::yes>
{
// trp_1.ld_0.ipt_14.tpb_384.ns_524.dcid_7.l2w_635 1.256212 1.004808 1.241086 1.373337
static constexpr int threads = 384;
Expand All @@ -699,7 +701,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_8, may_alias::no>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_8, may_alias::no>
{
// trp_0.ld_1.ipt_22.tpb_320.ns_660.dcid_7.l2w_1030 1.162087 0.997167 1.154955 1.395010
static constexpr int threads = 320;
Expand All @@ -710,7 +712,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_8, may_alias::yes>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_8, may_alias::yes>
{
// trp_1.ld_1.ipt_21.tpb_384.ns_1316.dcid_5.l2w_990 1.221365 1.019231 1.213141 1.372951
static constexpr int threads = 384;
Expand All @@ -723,7 +725,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
// TODO(gonidelis): Tune for I128.
#if CUB_IS_INT128_ENABLED
// template <>
// struct sm100_tuning<__int128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
// struct sm100_select_flagged_tuning<__int128_t, primitive::no, input_size::_16>
// {
// static constexpr int threads = 512;
// static constexpr int nominal_4b_items = 3;
Expand All @@ -734,7 +736,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
// };

// template <>
// struct sm100_tuning<__uint128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
// struct sm100_select_flagged_tuning<__uint128_t, primitive::no, input_size::_16>
// {
// static constexpr int threads = 512;
// static constexpr int nominal_4b_items = 3;
Expand Down Expand Up @@ -787,7 +789,7 @@ constexpr may_alias should_alias()
return Alias ? may_alias::yes : may_alias::no;
}

template <class InputT, class FlagT, class OffsetT, bool MayAlias, bool KeepRejects>
template <class InputT, class FlagT, class OffsetT, bool MayAlias, bool KeepRejects, bool IsUnique>
struct policy_hub
{
template <CacheLoadModifier LoadModifier>
Expand Down Expand Up @@ -864,14 +866,16 @@ struct policy_hub
template <typename Tuning>
static auto select_agent_policy100(long) -> typename Policy900::SelectIfPolicyT;

using SelectIfPolicyT =
decltype(select_agent_policy100<sm100_tuning<InputT,
is_flagged<FlagT>(),
are_rejects_kept<KeepRejects>(),
classify_offset_size<OffsetT>(),
is_primitive<InputT>(),
classify_input_size<InputT>(),
should_alias<MayAlias>()>>(0));
using tuning = ::cuda::std::_If<
is_flagged<FlagT>() == flagged::yes,
sm100_select_flagged_tuning<InputT, is_primitive<InputT>(), classify_input_size<InputT>(), should_alias<MayAlias>()>,
//::cuda::std::_If<
// IsUnique,
// sm100_select_unique_tuning<>,
sm100_select_if_tuning<InputT, is_primitive<InputT>(), classify_input_size<InputT>(), should_alias<MayAlias>()>
//>
>;
using SelectIfPolicyT = decltype(select_agent_policy100<tuning>(0));
};

using MaxPolicy = Policy1000;
Expand Down

0 comments on commit b655860

Please sign in to comment.