Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add b200 policies for partition.three_way #3708

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 127 additions & 1 deletion cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,114 @@ struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_8>
using delay_constructor = no_delay_constructor_t<1050>;
};

// Primary template for the SM100 three-way-partition tunings. It is only declared here; the
// tuning values live in the partial specializations that follow, keyed on the input-size and
// offset-size classes. The last two parameters default to the classification of InputT and
// OffsetT, so users can name a tuning as sm100_tuning<InputT, OffsetT>.
template <class InputT,
class OffsetT,
input_size InputSize = classify_input_size<InputT>(),
offset_size OffsetSize = classify_offset_size<OffsetT>()>
struct sm100_tuning;

// This tuning regressed during validation, so we disabled it and fall back to the SM90 tuning
// template <class Input, class OffsetT>
// struct sm100_tuning<Input, OffsetT, input_size::_1, offset_size::_4>
// {
// // trp_0.ipt_12.tpb_256.ns_792.dcid_6.l2w_365 1.063960 0.978016 1.072833 1.301435
// static constexpr int items = 12;
// static constexpr int threads = 256;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
// using delay_constructor = exponential_backon_jitter_constructor_t<792, 365>;
// };

// This tuning regressed during validation, so we disabled it and fall back to the SM90 tuning
// template <class Input, class OffsetT>
// struct sm100_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
// {
// // trp_1.ipt_14.tpb_288.ns_496.dcid_6.l2w_400 1.170449 1.123515 1.170428 1.252066
// static constexpr int items = 14;
// static constexpr int threads = 288;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = exponential_backon_jitter_constructor_t<496, 400>;
// };

// SM100 tuning for input_size::_4 with offset_size::_4.
template <class InputT, class OffsetT>
struct sm100_tuning<InputT, OffsetT, input_size::_4, offset_size::_4>
{
  // Tuning record these values were taken from (kept verbatim); its tpb/ipt/ns/l2w fields
  // match threads, items and the delay_constructor arguments below:
  // trp_0.ipt_11.tpb_512.ns_72.dcid_6.l2w_840 1.261035 1.069054 1.243873 1.394013
  static constexpr int threads = 512;
  static constexpr int items   = 11;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  using delay_constructor = exponential_backon_jitter_constructor_t<72, 840>;
};

// SM100 tuning for input_size::_8 with offset_size::_4.
template <class InputT, class OffsetT>
struct sm100_tuning<InputT, OffsetT, input_size::_8, offset_size::_4>
{
  // Tuning record these values were taken from (kept verbatim); its tpb/ipt/ns/l2w fields
  // match threads, items and the delay_constructor arguments below:
  // trp_1.ipt_10.tpb_256.ns_8.dcid_6.l2w_845 1.137286 1.105647 1.140905 1.194373
  static constexpr int threads = 256;
  static constexpr int items   = 10;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  using delay_constructor = exponential_backon_jitter_constructor_t<8, 845>;
};

// todo(gonidelis): Add tunings for I128.
// template <class Input, class OffsetT>
// struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
// {
// static constexpr int threads = 128;
// static constexpr int items = 7;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = no_delay_constructor_t<1040>;
// };

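// NOTE(review): this tuning is disabled without a stated reason. Presumably it regressed during
// validation like the disabled offset_size::_4 tunings above; confirm and document the reason.
// template <class Input, class OffsetT>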
// struct sm100_tuning<Input, OffsetT, input_size::_1, offset_size::_8>
// {
// // trp_1.ipt_20.tpb_768.ns_444.dcid_5.l2w_330 1.510085 0.887070 1.446621 1.982442
// static constexpr int items = 20;
// static constexpr int threads = 768;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = exponential_backon_jitter_window_constructor_t<444, 330>;
// };

// SM100 tuning for input_size::_2 with offset_size::_8.
template <class InputT, class OffsetT>
struct sm100_tuning<InputT, OffsetT, input_size::_2, offset_size::_8>
{
  // Tuning record these values were taken from (kept verbatim); its tpb/ipt/ns/l2w fields
  // match threads, items and the delay_constructor arguments below:
  // trp_1.ipt_20.tpb_768.ns_544.dcid_5.l2w_500 1.064438 1.000000 1.069149 1.200658
  static constexpr int threads = 768;
  static constexpr int items   = 20;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  using delay_constructor = exponential_backon_jitter_window_constructor_t<544, 500>;
};

// SM100 tuning for input_size::_4 with offset_size::_8.
template <class InputT, class OffsetT>
struct sm100_tuning<InputT, OffsetT, input_size::_4, offset_size::_8>
{
  // Tuning record these values were taken from (kept verbatim); its tpb/ipt/ns/l2w fields
  // match threads, items and the delay_constructor arguments below:
  // trp_1.ipt_15.tpb_768.ns_144.dcid_6.l2w_280 1.099504 1.002083 1.095122 1.352941
  static constexpr int threads = 768;
  static constexpr int items   = 15;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  using delay_constructor = exponential_backon_jitter_constructor_t<144, 280>;
};

// SM100 tuning for input_size::_8 with offset_size::_8.
template <class InputT, class OffsetT>
struct sm100_tuning<InputT, OffsetT, input_size::_8, offset_size::_8>
{
  // Tuning record these values were taken from (kept verbatim); its tpb/ipt/ns/l2w fields
  // match threads, items and the delay_constructor arguments below:
  // trp_1.ipt_14.tpb_320.ns_872.dcid_7.l2w_620 1.083194 1.000000 1.078944 1.315789
  static constexpr int threads = 320;
  static constexpr int items   = 14;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  using delay_constructor = exponential_backon_constructor_t<872, 620>;
};

// todo(gonidelis): Add tunings for I128.
// template <class Input, class OffsetT>
// struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_8>
// {
// static constexpr int threads = 128;
// static constexpr int items = 7;
// static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
// using delay_constructor = no_delay_constructor_t<1040>;
// };

template <class InputT, class OffsetT>
struct policy_hub
{
Expand Down Expand Up @@ -273,7 +381,25 @@ struct policy_hub
using ThreeWayPartitionPolicy = decltype(select_agent_policy<sm90_tuning<InputT, OffsetT>>(0));
};

using MaxPolicy = Policy900;
// Policy for the 1000 slot of the policy chain; chains to Policy900.
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
{
  // Fallback overload: yields Policy900's agent policy. It takes `long`, so for the literal
  // argument `0` it is only chosen when the `int` overload below drops out of overload
  // resolution.
  template <typename TuningT>
  static auto select_agent_policy100(long) -> typename Policy900::ThreeWayPartitionPolicy;

  // Preferred overload: builds the agent policy from the members of TuningT. If TuningT has
  // no usable definition (no sm100_tuning specialization exists for this input/offset size),
  // substitution into the return type fails and this overload is discarded.
  template <typename TuningT>
  static auto select_agent_policy100(int)
    -> AgentThreeWayPartitionPolicy<TuningT::threads,
                                    TuningT::items,
                                    TuningT::load_algorithm,
                                    LOAD_DEFAULT,
                                    BLOCK_SCAN_WARP_SCANS,
                                    typename TuningT::delay_constructor>;

  // Use values from the SM100 tuning if a specialization exists, otherwise pick Policy900's.
  using ThreeWayPartitionPolicy = decltype(select_agent_policy100<sm100_tuning<InputT, OffsetT>>(0));
};

using MaxPolicy = Policy1000;
};
} // namespace three_way_partition
} // namespace detail
Expand Down
Loading