Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add b200 tunings for reduce.by_key #3610

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 265 additions & 1 deletion cub/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,252 @@ struct sm90_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::no, primitive
using delay_constructor = detail::no_delay_constructor_t<1150>;
};

// Primary template for the sm100 reduce-by-key tunings.
//
// Only declared here (no definition in the visible code): the partial specializations that
// follow supply the tuning constants. The trailing non-type parameters default to the
// classification of KeyT / AccumT, so users only spell <KeyT, AccumT, PrimitiveOp>.
template <typename KeyT,
          typename AccumT,
          primitive_op PrimitiveOp,
          primitive_key PrimitiveKey     = is_primitive_key<KeyT>(),
          primitive_accum PrimitiveAccum = is_primitive_accum<AccumT>(),
          key_size KeySize               = classify_key_size<KeyT>(),
          accum_size AccumSize           = classify_accum_size<AccumT>()>
struct sm100_tuning;

// 8-bit key
// sm100 tuning: primitive op / key / accumulator, key_size::_1 with accum_size::_1.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_1, accum_size::_1>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_13.tpb_576.trp_0.ld_1.ns_2044.dcid_5.l2w_240 1.161888 0.848558 1.134941 1.299109
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 576; // tpb_576
  static constexpr int items                         = 13; // ipt_13
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_CA;
  // Template arguments are <ns_2044, l2w_240>.
  using delay_constructor = exponential_backon_jitter_window_constructor_t<2044, 240>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_1 with accum_size::_2.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_1, accum_size::_2>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_10.tpb_224.trp_0.ld_0.ns_244.dcid_4.l2w_390 1.313932 1.260540 1.319588 1.427374
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int items                         = 10; // ipt_10
  static constexpr int threads                       = 224; // tpb_224
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  // The first argument is the delay and must agree with ns_244 from the tuning string. It was
  // previously 224, which looks like a copy of tpb_224; the other sm100 specializations in this
  // file all pass their ns value here. The second argument is l2w_390.
  using delay_constructor                          = exponential_backoff_jitter_window_constructor_t<244, 390>;
  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_1 with accum_size::_4.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_1, accum_size::_4>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_14.tpb_128.trp_0.ld_0.ns_248.dcid_2.l2w_285 1.118109 1.051534 1.134336 1.326788
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 128; // tpb_128
  static constexpr int items                         = 14; // ipt_14
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_248, l2w_285>.
  using delay_constructor = exponential_backoff_constructor_t<248, 285>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_1 with accum_size::_8.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_1, accum_size::_8>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_19.tpb_128.trp_1.ld_0.ns_132.dcid_1.l2w_540 1.113820 1.002404 1.105014 1.202296
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 128; // tpb_128
  static constexpr int items                         = 19; // ipt_19
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_132, l2w_540>.
  using delay_constructor = fixed_delay_constructor_t<132, 540>;
};

// todo(gonidelis): Add tunings for I128.
// template <class KeyT, class AccumT>
// struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::no, key_size::_1,
// accum_size::_16>
// {
// static constexpr int threads = 128;
// static constexpr int items = 11;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = detail::no_delay_constructor_t<1100>;
// };

// 16-bit key
// sm100 tuning: primitive op / key / accumulator, key_size::_2 with accum_size::_1.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_2, accum_size::_1>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_14.tpb_128.trp_1.ld_0.ns_164.dcid_2.l2w_290 1.239579 1.119705 1.239111 1.313112
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 128; // tpb_128
  static constexpr int items                         = 14; // ipt_14
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_164, l2w_290>.
  using delay_constructor = detail::exponential_backoff_constructor_t<164, 290>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_2 with accum_size::_2.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_2, accum_size::_2>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_14.tpb_256.trp_1.ld_0.ns_180.dcid_2.l2w_975 1.145635 1.012658 1.139956 1.251546
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 256; // tpb_256
  static constexpr int items                         = 14; // ipt_14
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_180, l2w_975>.
  using delay_constructor = exponential_backoff_constructor_t<180, 975>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_2 with accum_size::_4.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_2, accum_size::_4>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_11.tpb_256.trp_0.ld_0.ns_224.dcid_2.l2w_550 1.066293 1.000109 1.073092 1.181818
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 256; // tpb_256
  static constexpr int items                         = 11; // ipt_11
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_224, l2w_550>.
  using delay_constructor = exponential_backoff_constructor_t<224, 550>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_2 with accum_size::_8.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_2, accum_size::_8>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_10.tpb_160.trp_1.ld_0.ns_156.dcid_1.l2w_725 1.045007 1.002105 1.049690 1.141827
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 160; // tpb_160
  static constexpr int items                         = 10; // ipt_10
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_156, l2w_725>.
  using delay_constructor = fixed_delay_constructor_t<156, 725>;
};

// todo(gonidelis): Add tunings for I128.
// template <class KeyT, class AccumT>
// struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::no, key_size::_2,
// accum_size::_16>
// {
// static constexpr int threads = 128;
// static constexpr int items = 11;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = detail::no_delay_constructor_t<1100>;
// };

// 32-bit key
// sm100 tuning: primitive op / key / accumulator, key_size::_4 with accum_size::_1.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_4, accum_size::_1>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_10.tpb_224.trp_0.ld_0.ns_324.dcid_2.l2w_285 1.157217 1.073724 1.166510 1.356940
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 224; // tpb_224
  static constexpr int items                         = 10; // ipt_10
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_324, l2w_285>.
  using delay_constructor = exponential_backoff_constructor_t<324, 285>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_4 with accum_size::_2.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_4, accum_size::_2>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_11.tpb_256.trp_0.ld_0.ns_1984.dcid_5.l2w_115 1.214155 1.128842 1.214093 1.364476
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 256; // tpb_256
  static constexpr int items                         = 11; // ipt_11
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_1984, l2w_115>.
  using delay_constructor = exponential_backon_jitter_window_constructor_t<1984, 115>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_4 with accum_size::_4.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_4, accum_size::_4>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_14.tpb_224.trp_1.ld_0.ns_476.dcid_5.l2w_1005 1.187378 1.119705 1.185397 1.258420
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 224; // tpb_224
  static constexpr int items                         = 14; // ipt_14
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_476, l2w_1005>.
  using delay_constructor = exponential_backon_jitter_window_constructor_t<476, 1005>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_4 with accum_size::_8.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_4, accum_size::_8>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_10.tpb_256.trp_1.ld_0.ns_1868.dcid_7.l2w_145 1.142915 1.020581 1.137459 1.237913
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 256; // tpb_256
  static constexpr int items                         = 10; // ipt_10
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_1868, l2w_145>.
  using delay_constructor = exponential_backon_constructor_t<1868, 145>;
};

// todo(gonidelis): Add tunings for I128.
// template <class KeyT, class AccumT>
// struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::no, key_size::_4,
// accum_size::_16>
// {
// static constexpr int threads = 128;
// static constexpr int items = 11;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = detail::no_delay_constructor_t<1100>;
// };

// 64-bit key
// sm100 tuning: primitive op / key / accumulator, key_size::_8 with accum_size::_1.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_8, accum_size::_1>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_9.tpb_224.trp_1.ld_0.ns_1940.dcid_5.l2w_460 1.157294 1.075650 1.153566 1.250729
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 224; // tpb_224
  static constexpr int items                         = 9; // ipt_9
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_1940, l2w_460>.
  using delay_constructor = exponential_backon_jitter_window_constructor_t<1940, 460>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_8 with accum_size::_2.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_8, accum_size::_2>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_11.tpb_224.trp_1.ld_1.ns_392.dcid_2.l2w_550 1.104034 1.007212 1.099543 1.220401
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 224; // tpb_224
  static constexpr int items                         = 11; // ipt_11
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_CA;
  // Template arguments are <ns_392, l2w_550>.
  using delay_constructor = exponential_backoff_constructor_t<392, 550>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_8 with accum_size::_4.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_8, accum_size::_4>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_9.tpb_224.trp_1.ld_0.ns_244.dcid_2.l2w_475 1.130098 1.000000 1.130661 1.215722
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 224; // tpb_224
  static constexpr int items                         = 9; // ipt_9
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_244, l2w_475>.
  using delay_constructor = exponential_backoff_constructor_t<244, 475>;
};

// sm100 tuning: primitive op / key / accumulator, key_size::_8 with accum_size::_8.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_8, accum_size::_8>
{
  // Tuned variant (the ipt/tpb/ns/l2w fields match the constants below):
  // ipt_9.tpb_224.trp_1.ld_0.ns_196.dcid_2.l2w_340 1.272056 1.142857 1.262499 1.352941
  // NOTE(review): the four trailing numbers are presumably benchmark statistics - confirm.
  static constexpr int threads                       = 224; // tpb_224
  static constexpr int items                         = 9; // ipt_9
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  // Template arguments are <ns_196, l2w_340>.
  using delay_constructor = exponential_backoff_constructor_t<196, 340>;
};

// todo(gonidelis): Add tunings for I128.
// template <class KeyT, class AccumT>
// struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::no, key_size::_8,
// accum_size::_16>
// {
// static constexpr int threads = 128;
// static constexpr int items = 11;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = detail::no_delay_constructor_t<1125>;
// };

// todo(gonidelis): Add tunings for 128-bit key.
// 128-bit key
// template <class KeyT, class AccumT>
// struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::no, key_size::_16,
// accum_size::_1>
// {
// static constexpr int threads = 128;
// static constexpr int items = 11;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = detail::no_delay_constructor_t<1125>;
// };

template <class ReductionOpT, class AccumT, class KeyT>
struct policy_hub
{
Expand Down Expand Up @@ -668,7 +914,25 @@ struct policy_hub
decltype(select_agent_policy<sm90_tuning<KeyT, AccumT, is_primitive_op<ReductionOpT>()>>(0));
};

using MaxPolicy = Policy900;
// Policy for the 1000 tier of the chain; Policy900 is the next link.
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
{
  // Fallback overload. It takes `long`, so a call with the literal `0` only resolves to it
  // when the `int` overload below drops out through substitution failure. In that case the
  // agent policy already selected by Policy900 is reused.
  template <typename TuningT>
  static auto select_agent_policy(long) -> typename Policy900::ReduceByKeyPolicyT;

  // Preferred overload. It is viable only if TuningT exposes threads, items, load_algorithm,
  // load_modifier and delay_constructor, i.e. when a tuning specialization exists.
  template <typename TuningT>
  static auto select_agent_policy(int)
    -> AgentReduceByKeyPolicy<TuningT::threads,
                              TuningT::items,
                              TuningT::load_algorithm,
                              TuningT::load_modifier,
                              BLOCK_SCAN_WARP_SCANS,
                              typename TuningT::delay_constructor>;

  // Passing `0` prefers the `int` overload (tuned values) and falls back to the `long` one.
  using ReduceByKeyPolicyT =
    decltype(select_agent_policy<sm100_tuning<KeyT, AccumT, is_primitive_op<ReductionOpT>()>>(0));
};
using MaxPolicy = Policy1000;
};
} // namespace reduce_by_key
} // namespace detail
Expand Down
Loading