Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor sm100 select tunings #1

Open
wants to merge 1 commit into
base: tune_select_if_flag_unique
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions cub/cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -421,11 +421,13 @@ template <typename InputIteratorT,
typename OffsetT,
bool KeepRejects,
bool MayAlias = false,
typename PolicyHub = detail::select::policy_hub<cub::detail::value_t<InputIteratorT>,
cub::detail::value_t<FlagsInputIteratorT>,
detail::select::per_partition_offset_t,
MayAlias,
KeepRejects>>
typename PolicyHub = detail::select::policy_hub<
cub::detail::value_t<InputIteratorT>,
cub::detail::value_t<FlagsInputIteratorT>,
detail::select::per_partition_offset_t,
MayAlias,
KeepRejects,
/* IsUnique */ ::cuda::std::is_same<EqualityOpT, ::cuda::std::equal_to<>>::value>>
struct DispatchSelectIf
{
/******************************************************************************
Expand Down
80 changes: 42 additions & 38 deletions cub/cub/device/dispatch/tuning/tuning_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ enum class input_size
template <class InputT, flagged, keep_rejects, offset_size OffsetSize, primitive, input_size InputSize>
struct sm80_tuning;

// select::if
// select::if / select::unique
template <class Input>
struct sm80_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1>
{
Expand Down Expand Up @@ -145,7 +145,7 @@ struct sm80_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4,
{};
#endif

// select::flagged
// select::flagged[If]
template <class Input>
struct sm80_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1>
{
Expand Down Expand Up @@ -307,7 +307,7 @@ struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4
template <class InputT, flagged, keep_rejects, offset_size OffsetSize, primitive, input_size InputSize>
struct sm90_tuning;

// select::if
// select::if / select::unique
template <class Input>
struct sm90_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1>
{
Expand Down Expand Up @@ -360,7 +360,7 @@ struct sm90_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4,
{};
#endif

// select::flagged
// select::flagged[If]
template <class Input>
struct sm90_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1>
{
Expand Down Expand Up @@ -519,12 +519,12 @@ struct sm90_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4
{};
#endif

template <class InputT, flagged, keep_rejects, offset_size OffsetSize, primitive, input_size InputSize, may_alias>
struct sm100_tuning;

// select::if
template <class InputT, primitive, input_size InputSize, may_alias>
struct sm100_select_if_tuning;

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1, may_alias::no>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_1, may_alias::no>
{
// trp_0.ld_0.ipt_22.tpb_384.ns_0.dcid_2.l2w_915 1.099232 0.980183 1.096778 1.545455
static constexpr int threads = 384;
Expand All @@ -535,7 +535,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1, may_alias::yes>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_1, may_alias::yes>
{
// trp_1.ld_0.ipt_20.tpb_448.ns_596.dcid_6.l2w_295 1.214635 1.001421 1.207023 1.307692
static constexpr int threads = 448;
Expand All @@ -546,7 +546,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_2, may_alias::no>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_2, may_alias::no>
{
// trp_1.ld_0.ipt_20.tpb_256.ns_516.dcid_7.l2w_685 1.065598 0.937984 1.067343 1.452153
static constexpr int threads = 256;
Expand All @@ -557,7 +557,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_2, may_alias::yes>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_2, may_alias::yes>
{
// trp_1.ld_0.ipt_20.tpb_384.ns_1060.dcid_5.l2w_375 1.109871 0.973142 1.105415 1.459135
static constexpr int threads = 384;
Expand All @@ -568,7 +568,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_4, may_alias::no>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_4, may_alias::no>
{
// trp_1.ld_0.ipt_15.tpb_384.ns_1508.dcid_5.l2w_585 1.201993 0.920103 1.185134 1.441805
static constexpr int threads = 384;
Expand All @@ -579,7 +579,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_4, may_alias::yes>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_4, may_alias::yes>
{
// trp_1.ld_0.ipt_19.tpb_512.ns_928.dcid_7.l2w_770 1.258815 1.000000 1.235251 1.444884
static constexpr int threads = 512;
Expand All @@ -591,12 +591,11 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi

// baseline remained the fastest, so fall back to previous tuning
// template <class Input>
// struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_8,
// may_alias::no>
// struct sm100_select_if_tuning<Input, primitive::yes, input_size::_8, may_alias::no>
// {};

template <class Input>
struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_8, may_alias::yes>
struct sm100_select_if_tuning<Input, primitive::yes, input_size::_8, may_alias::yes>
{
// trp_1.ld_0.ipt_23.tpb_384.ns_1140.dcid_7.l2w_520 1.081506 0.955298 1.088848 1.248971
static constexpr int threads = 384;
Expand All @@ -609,7 +608,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
// TODO(gonidelis): Tune for I128.
#if CUB_IS_INT128_ENABLED
// template <>
// struct sm100_tuning<__int128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
// struct sm100_select_if_tuning<__int128_t, primitive::no, input_size::_16>
// {
// // static constexpr int threads = 512;
// // static constexpr int nominal_4b_items = 5;
Expand All @@ -620,7 +619,7 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
// };

// template <>
// struct sm100_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
// struct sm100_select_if_tuning<__uint128_t, primitive::no, input_size::_16>
// {
// // static constexpr int threads = 512;
// // static constexpr int nominal_4b_items = 5;
Expand All @@ -631,9 +630,12 @@ struct sm100_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primi
// };
#endif

// select::flagged
// select::flagged[If]
template <class InputT, primitive, input_size, may_alias>
struct sm100_select_flagged_tuning;

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1, may_alias::no>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_1, may_alias::no>
{
// trp_0.ld_0.ipt_20.tpb_896.ns_84.dcid_7.l2w_480 1.254262 0.846154 1.222437 1.462665
static constexpr int threads = 896;
Expand All @@ -644,7 +646,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1, may_alias::yes>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_1, may_alias::yes>
{
// trp_0.ld_0.ipt_20.tpb_1024.ns_360.dcid_6.l2w_380 1.274174 0.748441 1.227123 1.610039
static constexpr int threads = 1024;
Expand All @@ -655,7 +657,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_2, may_alias::no>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_2, may_alias::no>
{
// trp_0.ld_0.ipt_22.tpb_256.ns_1292.dcid_5.l2w_750 1.283400 1.002841 1.267822 1.445913
static constexpr int threads = 256;
Expand All @@ -666,7 +668,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_2, may_alias::yes>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_2, may_alias::yes>
{
// trp_1.ld_0.ipt_20.tpb_448.ns_136.dcid_2.l2w_760 1.318819 0.994090 1.289173 1.551415
static constexpr int threads = 448;
Expand All @@ -677,7 +679,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_4, may_alias::no>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_4, may_alias::no>
{
// trp_0.ld_0.ipt_14.tpb_512.ns_844.dcid_6.l2w_675 1.207911 1.068001 1.208890 1.455636
static constexpr int threads = 512;
Expand All @@ -688,7 +690,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_4, may_alias::yes>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_4, may_alias::yes>
{
// trp_1.ld_0.ipt_14.tpb_384.ns_524.dcid_7.l2w_635 1.256212 1.004808 1.241086 1.373337
static constexpr int threads = 384;
Expand All @@ -699,7 +701,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_8, may_alias::no>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_8, may_alias::no>
{
// trp_0.ld_1.ipt_22.tpb_320.ns_660.dcid_7.l2w_1030 1.162087 0.997167 1.154955 1.395010
static constexpr int threads = 320;
Expand All @@ -710,7 +712,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
};

template <class Input>
struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_8, may_alias::yes>
struct sm100_select_flagged_tuning<Input, primitive::yes, input_size::_8, may_alias::yes>
{
// trp_1.ld_1.ipt_21.tpb_384.ns_1316.dcid_5.l2w_990 1.221365 1.019231 1.213141 1.372951
static constexpr int threads = 384;
Expand All @@ -723,7 +725,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
// TODO(gonidelis): Tune for I128.
#if CUB_IS_INT128_ENABLED
// template <>
// struct sm100_tuning<__int128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
// struct sm100_select_flagged_tuning<__int128_t, primitive::no, input_size::_16>
// {
// static constexpr int threads = 512;
// static constexpr int nominal_4b_items = 3;
Expand All @@ -734,7 +736,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
// };

// template <>
// struct sm100_tuning<__uint128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
// struct sm100_select_flagged_tuning<__uint128_t, primitive::no, input_size::_16>
// {
// static constexpr int threads = 512;
// static constexpr int nominal_4b_items = 3;
Expand Down Expand Up @@ -787,7 +789,7 @@ constexpr may_alias should_alias()
return Alias ? may_alias::yes : may_alias::no;
}

template <class InputT, class FlagT, class OffsetT, bool MayAlias, bool KeepRejects>
template <class InputT, class FlagT, class OffsetT, bool MayAlias, bool KeepRejects, bool IsUnique>
struct policy_hub
{
template <CacheLoadModifier LoadModifier>
Expand Down Expand Up @@ -864,14 +866,16 @@ struct policy_hub
template <typename Tuning>
static auto select_agent_policy100(long) -> typename Policy900::SelectIfPolicyT;

using SelectIfPolicyT =
decltype(select_agent_policy100<sm100_tuning<InputT,
is_flagged<FlagT>(),
are_rejects_kept<KeepRejects>(),
classify_offset_size<OffsetT>(),
is_primitive<InputT>(),
classify_input_size<InputT>(),
should_alias<MayAlias>()>>(0));
using tuning = ::cuda::std::_If<
is_flagged<FlagT>() == flagged::yes,
sm100_select_flagged_tuning<InputT, is_primitive<InputT>(), classify_input_size<InputT>(), should_alias<MayAlias>()>,
//::cuda::std::_If<
// IsUnique,
// sm100_select_unique_tuning<>,
sm100_select_if_tuning<InputT, is_primitive<InputT>(), classify_input_size<InputT>(), should_alias<MayAlias>()>
//>
>;
using SelectIfPolicyT = decltype(select_agent_policy100<tuning>(0));
};

using MaxPolicy = Policy1000;
Expand Down
Loading