Skip to content

Commit

Permalink
Add b200 tunings for reduce.by_key
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis authored and bernhardmgruber committed Jan 30, 2025
1 parent 38983eb commit 8d73bb7
Showing 1 changed file with 265 additions and 1 deletion.
266 changes: 265 additions & 1 deletion cub/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,252 @@ struct sm90_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::no, primitive
using delay_constructor = detail::no_delay_constructor_t<1150>;
};

// Primary template for SM100 reduce-by-key tunings. It is declared but never defined, so only
// the partial specializations below provide tuning members. The trailing classification
// parameters are derived from KeyT / AccumT by default, so users only spell the first three.
template <typename KeyT,
          typename AccumT,
          primitive_op PrimitiveOp,
          primitive_key PrimitiveKey     = is_primitive_key<KeyT>(),
          primitive_accum PrimitiveAccum = is_primitive_accum<AccumT>(),
          key_size KeySize               = classify_key_size<KeyT>(),
          accum_size AccumSize           = classify_accum_size<AccumT>()>
struct sm100_tuning;

// 8-bit key
// SM100 tuning: primitive reduction operator, primitive 1-byte key, primitive 1-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_1,
                    accum_size::_1>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_13.tpb_576.trp_0.ld_1.ns_2044.dcid_5.l2w_240 1.161888 0.848558 1.134941 1.299109
  static constexpr int threads                       = 576;
  static constexpr int items                         = 13;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_CA;
  using delay_constructor = exponential_backon_jitter_window_constructor_t<2044, 240>;
};

// SM100 tuning: primitive reduction operator, primitive 1-byte key, primitive 2-byte accumulator.
template <class KeyT, class AccumT>
struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::yes, key_size::_1, accum_size::_2>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_10.tpb_224.trp_0.ld_0.ns_244.dcid_4.l2w_390 1.313932 1.260540 1.319588 1.427374
  static constexpr int items                         = 10;
  static constexpr int threads                       = 224;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  // The first delay argument must equal the record's `ns_244` field (it previously read 224,
  // which duplicated the `threads` value instead of the tuned delay).
  using delay_constructor                          = exponential_backoff_jitter_window_constructor_t<244, 390>;
  static constexpr CacheLoadModifier load_modifier = LOAD_DEFAULT;
};

// SM100 tuning: primitive reduction operator, primitive 1-byte key, primitive 4-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_1,
                    accum_size::_4>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_14.tpb_128.trp_0.ld_0.ns_248.dcid_2.l2w_285 1.118109 1.051534 1.134336 1.326788
  static constexpr int threads                       = 128;
  static constexpr int items                         = 14;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = exponential_backoff_constructor_t<248, 285>;
};

// SM100 tuning: primitive reduction operator, primitive 1-byte key, primitive 8-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_1,
                    accum_size::_8>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_19.tpb_128.trp_1.ld_0.ns_132.dcid_1.l2w_540 1.113820 1.002404 1.105014 1.202296
  static constexpr int threads                       = 128;
  static constexpr int items                         = 19;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = fixed_delay_constructor_t<132, 540>;
};

// todo(gonidelis): Add tunings for I128.
// template <class KeyT, class AccumT>
// struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::no, key_size::_1,
// accum_size::_16>
// {
// static constexpr int threads = 128;
// static constexpr int items = 11;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = detail::no_delay_constructor_t<1100>;
// };

// 16-bit key
// SM100 tuning: primitive reduction operator, primitive 2-byte key, primitive 1-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_2,
                    accum_size::_1>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_14.tpb_128.trp_1.ld_0.ns_164.dcid_2.l2w_290 1.239579 1.119705 1.239111 1.313112
  static constexpr int threads                       = 128;
  static constexpr int items                         = 14;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = detail::exponential_backoff_constructor_t<164, 290>;
};

// SM100 tuning: primitive reduction operator, primitive 2-byte key, primitive 2-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_2,
                    accum_size::_2>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_14.tpb_256.trp_1.ld_0.ns_180.dcid_2.l2w_975 1.145635 1.012658 1.139956 1.251546
  static constexpr int threads                       = 256;
  static constexpr int items                         = 14;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = exponential_backoff_constructor_t<180, 975>;
};

// SM100 tuning: primitive reduction operator, primitive 2-byte key, primitive 4-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_2,
                    accum_size::_4>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_11.tpb_256.trp_0.ld_0.ns_224.dcid_2.l2w_550 1.066293 1.000109 1.073092 1.181818
  static constexpr int threads                       = 256;
  static constexpr int items                         = 11;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = exponential_backoff_constructor_t<224, 550>;
};

// SM100 tuning: primitive reduction operator, primitive 2-byte key, primitive 8-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_2,
                    accum_size::_8>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_10.tpb_160.trp_1.ld_0.ns_156.dcid_1.l2w_725 1.045007 1.002105 1.049690 1.141827
  static constexpr int threads                       = 160;
  static constexpr int items                         = 10;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = fixed_delay_constructor_t<156, 725>;
};

// todo(gonidelis): Add tunings for I128.
// template <class KeyT, class AccumT>
// struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::no, key_size::_2,
// accum_size::_16>
// {
// static constexpr int threads = 128;
// static constexpr int items = 11;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = detail::no_delay_constructor_t<1100>;
// };

// 32-bit key
// SM100 tuning: primitive reduction operator, primitive 4-byte key, primitive 1-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_4,
                    accum_size::_1>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_10.tpb_224.trp_0.ld_0.ns_324.dcid_2.l2w_285 1.157217 1.073724 1.166510 1.356940
  static constexpr int threads                       = 224;
  static constexpr int items                         = 10;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = exponential_backoff_constructor_t<324, 285>;
};

// SM100 tuning: primitive reduction operator, primitive 4-byte key, primitive 2-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_4,
                    accum_size::_2>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_11.tpb_256.trp_0.ld_0.ns_1984.dcid_5.l2w_115 1.214155 1.128842 1.214093 1.364476
  static constexpr int threads                       = 256;
  static constexpr int items                         = 11;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = exponential_backon_jitter_window_constructor_t<1984, 115>;
};

// SM100 tuning: primitive reduction operator, primitive 4-byte key, primitive 4-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_4,
                    accum_size::_4>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_14.tpb_224.trp_1.ld_0.ns_476.dcid_5.l2w_1005 1.187378 1.119705 1.185397 1.258420
  static constexpr int threads                       = 224;
  static constexpr int items                         = 14;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = exponential_backon_jitter_window_constructor_t<476, 1005>;
};

// SM100 tuning: primitive reduction operator, primitive 4-byte key, primitive 8-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_4,
                    accum_size::_8>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_10.tpb_256.trp_1.ld_0.ns_1868.dcid_7.l2w_145 1.142915 1.020581 1.137459 1.237913
  static constexpr int threads                       = 256;
  static constexpr int items                         = 10;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = exponential_backon_constructor_t<1868, 145>;
};

// todo(gonidelis): Add tunings for I128.
// template <class KeyT, class AccumT>
// struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::no, key_size::_4,
// accum_size::_16>
// {
// static constexpr int threads = 128;
// static constexpr int items = 11;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = detail::no_delay_constructor_t<1100>;
// };

// 64-bit key
// SM100 tuning: primitive reduction operator, primitive 8-byte key, primitive 1-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_8,
                    accum_size::_1>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_9.tpb_224.trp_1.ld_0.ns_1940.dcid_5.l2w_460 1.157294 1.075650 1.153566 1.250729
  static constexpr int threads                       = 224;
  static constexpr int items                         = 9;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = exponential_backon_jitter_window_constructor_t<1940, 460>;
};

// SM100 tuning: primitive reduction operator, primitive 8-byte key, primitive 2-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_8,
                    accum_size::_2>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_11.tpb_224.trp_1.ld_1.ns_392.dcid_2.l2w_550 1.104034 1.007212 1.099543 1.220401
  static constexpr int threads                       = 224;
  static constexpr int items                         = 11;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_CA;
  using delay_constructor = exponential_backoff_constructor_t<392, 550>;
};

// SM100 tuning: primitive reduction operator, primitive 8-byte key, primitive 4-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_8,
                    accum_size::_4>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_9.tpb_224.trp_1.ld_0.ns_244.dcid_2.l2w_475 1.130098 1.000000 1.130661 1.215722
  static constexpr int threads                       = 224;
  static constexpr int items                         = 9;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = exponential_backoff_constructor_t<244, 475>;
};

// SM100 tuning: primitive reduction operator, primitive 8-byte key, primitive 8-byte accumulator.
template <typename KeyT, typename AccumT>
struct sm100_tuning<KeyT,
                    AccumT,
                    primitive_op::yes,
                    primitive_key::yes,
                    primitive_accum::yes,
                    key_size::_8,
                    accum_size::_8>
{
  // Tuning-search record the members below were transcribed from:
  // ipt_9.tpb_224.trp_1.ld_0.ns_196.dcid_2.l2w_340 1.272056 1.142857 1.262499 1.352941
  static constexpr int threads                       = 224;
  static constexpr int items                         = 9;
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
  static constexpr CacheLoadModifier load_modifier   = LOAD_DEFAULT;
  using delay_constructor = exponential_backoff_constructor_t<196, 340>;
};

// todo(gonidelis): Add tunings for I128.
// template <class KeyT, class AccumT>
// struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::no, key_size::_8,
// accum_size::_16>
// {
// static constexpr int threads = 128;
// static constexpr int items = 11;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = detail::no_delay_constructor_t<1125>;
// };

// todo(gonidelis): Add tunings for 128-bit key.
// 128-bit key
// template <class KeyT, class AccumT>
// struct sm100_tuning<KeyT, AccumT, primitive_op::yes, primitive_key::yes, primitive_accum::no, key_size::_16,
// accum_size::_1>
// {
// static constexpr int threads = 128;
// static constexpr int items = 11;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = detail::no_delay_constructor_t<1125>;
// };

template <class ReductionOpT, class AccumT, class KeyT>
struct policy_hub
{
Expand Down Expand Up @@ -668,7 +914,25 @@ struct policy_hub
decltype(select_agent_policy<sm90_tuning<KeyT, AccumT, is_primitive_op<ReductionOpT>()>>(0));
};

using MaxPolicy = Policy900;
// Policy registered with ChainedPolicy at version 1000, chaining to Policy900. The agent policy
// is taken from the matching sm100_tuning specialization when one exists; otherwise the Policy900
// agent policy is reused unchanged.
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
{
  // Fallback overload: always viable, but calling it with `0` needs an int -> long conversion,
  // so it is only chosen when the overload below is removed from the candidate set.
  template <class TuningT>
  static auto select_agent_policy(long) -> typename Policy900::ReduceByKeyPolicyT;

  // Preferred overload (exact match for `0`): builds the agent policy from the tuning's members.
  // Substitution fails, discarding this overload, when TuningT does not provide those members,
  // i.e. when no tuning specialization exists for the requested types.
  template <class TuningT>
  static auto select_agent_policy(int) -> AgentReduceByKeyPolicy<
    TuningT::threads,
    TuningT::items,
    TuningT::load_algorithm,
    TuningT::load_modifier,
    BLOCK_SCAN_WARP_SCANS,
    typename TuningT::delay_constructor>;

  // Resolve the overload set once for the hub's KeyT / AccumT / ReductionOpT.
  using ReduceByKeyPolicyT =
    decltype(select_agent_policy<sm100_tuning<KeyT, AccumT, is_primitive_op<ReductionOpT>()>>(0));
};
using MaxPolicy = Policy1000;
};
} // namespace reduce_by_key
} // namespace detail
Expand Down

0 comments on commit 8d73bb7

Please sign in to comment.