Skip to content

Commit

Permalink
Add b200 policies for cub.select.unique_by_key (#3557)
Browse files Browse the repository at this point in the history
Co-authored-by: Giannis Gonidelis <[email protected]>
  • Loading branch information
bernhardmgruber and gonidelis authored Feb 4, 2025
1 parent f61670e commit 16386fe
Showing 1 changed file with 266 additions and 1 deletion.
267 changes: 266 additions & 1 deletion cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,254 @@ struct sm90_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::no, key_size
using delay_constructor = no_delay_constructor_t<1155>;
};

template <class KeyT,
class ValueT,
primitive_key PrimitiveKey = is_primitive_key<KeyT>(),
primitive_val PrimitiveAccum = is_primitive_val<ValueT>(),
key_size KeySize = classify_key_size<KeyT>(),
val_size AccumSize = classify_val_size<ValueT>()>
struct sm100_tuning;

// 8-bit key
template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_1, val_size::_1>
{
// ipt_12.tpb_512.trp_0.ld_0.ns_948.dcid_5.l2w_955 1.121279 1.000000 1.114566 1.43765
static constexpr int threads = 512;
static constexpr int items = 12;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_jitter_window_constructor_t<948, 955>;
};

template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_1, val_size::_2>
{
// ipt_14.tpb_512.trp_0.ld_0.ns_1228.dcid_7.l2w_320 1.151229 1.007229 1.151131 1.443520
static constexpr int threads = 512;
static constexpr int items = 14;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_constructor_t<1228, 320>;
};

template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_1, val_size::_4>
{
// ipt_14.tpb_512.trp_0.ld_0.ns_2016.dcid_7.l2w_620 1.165300 1.095238 1.164478 1.266667
static constexpr int threads = 512;
static constexpr int items = 14;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_constructor_t<2016, 620>;
};

template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_1, val_size::_8>
{
// ipt_10.tpb_384.trp_0.ld_0.ns_1728.dcid_5.l2w_980 1.118716 0.997167 1.116537 1.400000
static constexpr int threads = 384;
static constexpr int items = 10;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_jitter_window_constructor_t<1728, 980>;
};

// TODO(gonidelis): Tune for I128.
#if CUB_IS_INT128_ENABLED
// template <class KeyT, class ValueT>
// struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::no, key_size::_1, val_size::_16>
// {
// static constexpr int threads = 288;
// static constexpr int items = 7;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
// using delay_constructor = fixed_delay_constructor_t<344, 1165>;
// };
#endif

// 16-bit key
template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_2, val_size::_1>
{
// ipt_14.tpb_512.trp_0.ld_0.ns_508.dcid_7.l2w_1020 1.171886 0.906530 1.157128 1.457933
static constexpr int threads = 512;
static constexpr int items = 14;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_constructor_t<508, 1020>;
};

template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_2, val_size::_2>
{
// ipt_12.tpb_384.trp_0.ld_0.ns_928.dcid_7.l2w_605 1.166564 0.997579 1.154805 1.406709
static constexpr int threads = 384;
static constexpr int items = 12;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_constructor_t<928, 605>;
};

template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_2, val_size::_4>
{
// ipt_11.tpb_384.trp_0.ld_1.ns_1620.dcid_7.l2w_810 1.144483 1.011085 1.152798 1.393750
static constexpr int threads = 384;
static constexpr int items = 11;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_CA;
using delay_constructor = exponential_backon_constructor_t<1620, 810>;
};

template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_2, val_size::_8>
{
// ipt_10.tpb_384.trp_0.ld_0.ns_1984.dcid_5.l2w_935 1.605554 1.177083 1.564488 1.946224
static constexpr int threads = 384;
static constexpr int items = 10;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_jitter_window_constructor_t<1984, 935>;
};

// TODO(gonidelis): Tune for I128.
#if CUB_IS_INT128_ENABLED
// template <class KeyT, class ValueT>
// struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::no, key_size::_2, val_size::_16>
// {
// static constexpr int threads = 224;
// static constexpr int items = 9;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
// using delay_constructor = fixed_delay_constructor_t<424, 1055>;
// };
#endif

// 32-bit key
template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_4, val_size::_1>
{
// ipt_14.tpb_512.trp_0.ld_0.ns_1136.dcid_7.l2w_605 1.148057 0.848558 1.133064 1.451074
static constexpr int threads = 512;
static constexpr int items = 14;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_constructor_t<1136, 605>;
};

template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_4, val_size::_2>
{
// ipt_11.tpb_384.trp_0.ld_0.ns_656.dcid_7.l2w_825 1.216312 1.090485 1.211800 1.535714
static constexpr int threads = 384;
static constexpr int items = 11;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_constructor_t<656, 825>;
};

// todo(gonidelis): tuning performs very well for medium input size, regresses for large input sizes.
// find better tuning.
template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_4, val_size::_4>
: sm90_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_4, val_size::_4>
{
// // ipt_14.tpb_512.trp_0.ld_0.ns_408.dcid_7.l2w_960 1.136333 0.995833 1.144371 1.448687
// static constexpr int threads = 512;
// static constexpr int items = 14;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
// using delay_constructor = exponential_backon_constructor_t<408, 960>;
};

template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_4, val_size::_8>
{
// ipt_10.tpb_384.trp_0.ld_0.ns_1012.dcid_5.l2w_800 1.164713 1.014819 1.174307 1.526042
static constexpr int threads = 384;
static constexpr int items = 10;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_jitter_window_constructor_t<1012, 800>;
};

// TODO(gonidelis): Tune for I128.
#if CUB_IS_INT128_ENABLED
// template <class KeyT, class ValueT>
// struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::no, key_size::_4, val_size::_16>
// {
// static constexpr int threads = 384;
// static constexpr int items = 7;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
// using delay_constructor = no_delay_constructor_t<1025>;
// };
#endif

// 64-bit key

// todo(gonidelis): tuning regresses for large input sizes. find better tuning.
template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_8, val_size::_1>
: sm90_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_8, val_size::_1>
{
// // ipt_9.tpb_384.trp_0.ld_0.ns_1064.dcid_7.l2w_600 1.085831 0.972452 1.080521 1.397089
// static constexpr int threads = 384;
// static constexpr int items = 9;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
// using delay_constructor = exponential_backon_constructor_t<1064, 600>;
};

template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_8, val_size::_2>
{
// ipt_10.tpb_384.trp_0.ld_0.ns_864.dcid_5.l2w_1130 1.124095 0.985748 1.120262 1.391304
static constexpr int threads = 384;
static constexpr int items = 10;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_jitter_window_constructor_t<864, 1130>;
};

template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_8, val_size::_4>
{
// ipt_10.tpb_384.trp_0.ld_0.ns_772.dcid_5.l2w_665 1.152243 1.019816 1.166636 1.517526
static constexpr int threads = 384;
static constexpr int items = 10;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
using delay_constructor = exponential_backon_jitter_window_constructor_t<772, 665>;
};

// todo(gonidelis): tuning regresses for large input sizes. find better tuning.
template <class KeyT, class ValueT>
struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_8, val_size::_8>
: sm90_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::yes, key_size::_8, val_size::_8>
{
// // ipt_7.tpb_576.trp_0.ld_0.ns_1132.dcid_5.l2w_1115 1.120721 0.977642 1.131594 1.449407
// static constexpr int threads = 576;
// static constexpr int items = 7;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
// using delay_constructor = exponential_backon_jitter_window_constructor_t<1132, 1115>;
};

// TODO(gonidelis): Tune for I128.
#if CUB_IS_INT128_ENABLED
// template <class KeyT, class ValueT>
// struct sm100_tuning<KeyT, ValueT, primitive_key::yes, primitive_val::no, key_size::_8, val_size::_16>
// {
// static constexpr int threads = 256;
// static constexpr int items = 9;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
// using delay_constructor = no_delay_constructor_t<1155>;
// };
#endif

template <class KeyT, class ValueT>
struct policy_hub
{
Expand Down Expand Up @@ -575,7 +823,24 @@ struct policy_hub
using UniqueByKeyPolicyT = decltype(select_agent_policy<sm90_tuning<KeyT, ValueT>>(0));
};

using MaxPolicy = Policy900;
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
{
// Use values from tuning if a specialization exists, otherwise pick Policy900
template <typename Tuning>
static auto select_agent_policy100(int)
-> AgentUniqueByKeyPolicy<Tuning::threads,
Tuning::items,
Tuning::load_algorithm,
Tuning::load_modifier,
BLOCK_SCAN_WARP_SCANS,
typename Tuning::delay_constructor>;
template <typename Tuning>
static auto select_agent_policy100(long) -> typename Policy900::UniqueByKeyPolicyT;

using UniqueByKeyPolicyT = decltype(select_agent_policy100<sm100_tuning<KeyT, ValueT>>(0));
};

using MaxPolicy = Policy1000;
};

} // namespace unique_by_key
Expand Down

0 comments on commit 16386fe

Please sign in to comment.