Skip to content

Commit

Permalink
Refactor three_way_partition tuning (#3140)
Browse files Browse the repository at this point in the history
* Drop needless comments
* Move and rename policy_hub
* Drop unneeded namespace qualifications
* Rename DefaultTuning
* Eliminate redundancy
* Swap sm80 and sm90 tuning
  • Loading branch information
bernhardmgruber authored Dec 12, 2024
1 parent 7321a51 commit 0ea508f
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 162 deletions.
20 changes: 10 additions & 10 deletions cub/cub/device/dispatch/dispatch_three_way_partition.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,16 @@ DeviceThreeWayPartitionInitKernel(ScanTileStateT tile_state, int num_tiles, NumS
* Dispatch
******************************************************************************/

template <typename InputIteratorT,
typename FirstOutputIteratorT,
typename SecondOutputIteratorT,
typename UnselectedOutputIteratorT,
typename NumSelectedIteratorT,
typename SelectFirstPartOp,
typename SelectSecondPartOp,
typename OffsetT,
typename SelectedPolicy =
detail::device_three_way_partition_policy_hub<cub::detail::value_t<InputIteratorT>, OffsetT>>
template <
typename InputIteratorT,
typename FirstOutputIteratorT,
typename SecondOutputIteratorT,
typename UnselectedOutputIteratorT,
typename NumSelectedIteratorT,
typename SelectFirstPartOp,
typename SelectSecondPartOp,
typename OffsetT,
typename SelectedPolicy = detail::three_way_partition::policy_hub<cub::detail::value_t<InputIteratorT>, OffsetT>>
struct DispatchThreeWayPartitionIf
{
/*****************************************************************************
Expand Down
249 changes: 97 additions & 152 deletions cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ CUB_NAMESPACE_BEGIN

namespace detail
{

namespace three_way_partition
{

enum class input_size
{
_1,
Expand Down Expand Up @@ -92,246 +90,193 @@ template <class InputT,
class OffsetT,
input_size InputSize = classify_input_size<InputT>(),
offset_size OffsetSize = classify_offset_size<OffsetT>()>
struct sm90_tuning
{
static constexpr int threads = 256;
static constexpr int items = Nominal4BItemsToItems<InputT>(9);

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using AccumPackHelperT = detail::three_way_partition::accumulator_pack_t<OffsetT>;
using AccumPackT = typename AccumPackHelperT::pack_t;
using delay_constructor = detail::default_delay_constructor_t<AccumPackT>;
};
struct sm80_tuning;

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 12;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<445>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
struct sm80_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 12;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<104, 512>;
static constexpr int threads = 256;
static constexpr int items = 12;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = no_delay_constructor_t<910>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
struct sm80_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
{
static constexpr int threads = 320;
static constexpr int items = 12;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<1105>;
static constexpr int threads = 256;
static constexpr int items = 11;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = no_delay_constructor_t<1120>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
struct sm80_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
{
static constexpr int threads = 384;
static constexpr int items = 7;

static constexpr int threads = 224;
static constexpr int items = 11;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<464, 1165>;
using delay_constructor = fixed_delay_constructor_t<264, 1080>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
struct sm80_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
{
static constexpr int threads = 128;
static constexpr int items = 7;

static constexpr int threads = 128;
static constexpr int items = 10;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1040>;
using delay_constructor = fixed_delay_constructor_t<672, 1120>;
};

template <class InputT,
class OffsetT,
input_size InputSize = classify_input_size<InputT>(),
offset_size OffsetSize = classify_offset_size<OffsetT>()>
struct sm90_tuning;

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_8>
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 24;

static constexpr int threads = 256;
static constexpr int items = 12;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<4, 285>;
using delay_constructor = no_delay_constructor_t<445>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_8>
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
{
static constexpr int threads = 640;
static constexpr int items = 24;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<245>;
static constexpr int threads = 256;
static constexpr int items = 12;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
using delay_constructor = fixed_delay_constructor_t<104, 512>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_8>
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 23;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<910>;
static constexpr int threads = 320;
static constexpr int items = 12;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
using delay_constructor = no_delay_constructor_t<1105>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_8>
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 18;

static constexpr int threads = 384;
static constexpr int items = 7;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1145>;
using delay_constructor = fixed_delay_constructor_t<464, 1165>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_8>
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 11;

static constexpr int threads = 128;
static constexpr int items = 7;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1050>;
using delay_constructor = no_delay_constructor_t<1040>;
};

template <class InputT,
class OffsetT,
input_size InputSize = classify_input_size<InputT>(),
offset_size OffsetSize = classify_offset_size<OffsetT>()>
struct sm80_tuning
template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_8>
{
static constexpr int threads = 256;
static constexpr int items = Nominal4BItemsToItems<InputT>(9);

static constexpr int threads = 256;
static constexpr int items = 24;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using AccumPackHelperT = detail::three_way_partition::accumulator_pack_t<OffsetT>;
using AccumPackT = typename AccumPackHelperT::pack_t;
using delay_constructor = detail::default_delay_constructor_t<AccumPackT>;
using delay_constructor = fixed_delay_constructor_t<4, 285>;
};

template <class Input, class OffsetT>
struct sm80_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_8>
{
static constexpr int threads = 256;
static constexpr int items = 12;

static constexpr int threads = 640;
static constexpr int items = 24;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<910>;
using delay_constructor = no_delay_constructor_t<245>;
};

template <class Input, class OffsetT>
struct sm80_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_8>
{
static constexpr int threads = 256;
static constexpr int items = 11;

static constexpr int threads = 256;
static constexpr int items = 23;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1120>;
using delay_constructor = no_delay_constructor_t<910>;
};

template <class Input, class OffsetT>
struct sm80_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_8>
{
static constexpr int threads = 224;
static constexpr int items = 11;

static constexpr int threads = 256;
static constexpr int items = 18;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<264, 1080>;
using delay_constructor = no_delay_constructor_t<1145>;
};

template <class Input, class OffsetT>
struct sm80_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_8>
{
static constexpr int threads = 128;
static constexpr int items = 10;

static constexpr int threads = 256;
static constexpr int items = 11;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<672, 1120>;
using delay_constructor = no_delay_constructor_t<1050>;
};

} // namespace three_way_partition

template <class InputT, class OffsetT>
struct device_three_way_partition_policy_hub
struct policy_hub
{
struct DefaultTuning
template <typename DelayConstructor>
struct DefaultPolicy
{
static constexpr int ITEMS_PER_THREAD = Nominal4BItemsToItems<InputT>(9);

using ThreeWayPartitionPolicy =
cub::AgentThreeWayPartitionPolicy<256,
ITEMS_PER_THREAD,
cub::BLOCK_LOAD_DIRECT,
cub::LOAD_DEFAULT,
cub::BLOCK_SCAN_WARP_SCANS>;
AgentThreeWayPartitionPolicy<256,
Nominal4BItemsToItems<InputT>(9),
BLOCK_LOAD_DIRECT,
LOAD_DEFAULT,
BLOCK_SCAN_WARP_SCANS,
DelayConstructor>;
};

/// SM35
struct Policy350
: DefaultTuning
: DefaultPolicy<fixed_delay_constructor_t<350, 450>>
, ChainedPolicy<350, Policy350, Policy350>
{};

// Use values from tuning if a specialization exists, otherwise pick DefaultPolicy
//
// Preferred overload: the call sites pass the literal 0, which is an int, so this
// candidate is an exact match. Its trailing return type names members of Tuning
// (threads, items, load_algorithm, delay_constructor). If Tuning is an incomplete
// type (e.g. a primary template that is declared but not defined, with no matching
// specialization), substitution into the return type fails and this overload is
// removed from the candidate set (SFINAE) instead of producing a hard error.
// Declared without a body; the visible uses are inside decltype only.
template <typename Tuning>
static auto select_agent_policy(int)
-> AgentThreeWayPartitionPolicy<Tuning::threads,
Tuning::items,
Tuning::load_algorithm,
LOAD_DEFAULT,
BLOCK_SCAN_WARP_SCANS,
typename Tuning::delay_constructor>;

// Fallback overload: its `long` parameter needs an integral conversion from an `int`
// argument, so it ranks below the `int` overload and is selected only when that
// overload is discarded by substitution failure. Its return type does not depend on
// Tuning: it is DefaultPolicy's ThreeWayPartitionPolicy, instantiated with
// default_delay_constructor_t over accumulator_pack_t<OffsetT>::pack_t.
// Declared without a body; the visible uses are inside decltype only.
template <typename Tuning>
static auto select_agent_policy(long) ->
typename DefaultPolicy<
default_delay_constructor_t<typename accumulator_pack_t<OffsetT>::pack_t>>::ThreeWayPartitionPolicy;

struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
{
using tuning = detail::three_way_partition::sm80_tuning<InputT, OffsetT>;

using ThreeWayPartitionPolicy =
AgentThreeWayPartitionPolicy<tuning::threads,
tuning::items,
tuning::load_algorithm,
cub::LOAD_DEFAULT,
cub::BLOCK_SCAN_WARP_SCANS,
typename tuning::delay_constructor>;
using ThreeWayPartitionPolicy = decltype(select_agent_policy<sm80_tuning<InputT, OffsetT>>(0));
};

struct Policy860
: DefaultTuning
: DefaultPolicy<fixed_delay_constructor_t<350, 450>>
, ChainedPolicy<860, Policy860, Policy800>
{};

/// SM90
struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
{
using tuning = detail::three_way_partition::sm90_tuning<InputT, OffsetT>;

using ThreeWayPartitionPolicy =
AgentThreeWayPartitionPolicy<tuning::threads,
tuning::items,
tuning::load_algorithm,
cub::LOAD_DEFAULT,
cub::BLOCK_SCAN_WARP_SCANS,
typename tuning::delay_constructor>;
using ThreeWayPartitionPolicy = decltype(select_agent_policy<sm90_tuning<InputT, OffsetT>>(0));
};

using MaxPolicy = Policy900;
};

} // namespace three_way_partition
} // namespace detail

CUB_NAMESPACE_END

0 comments on commit 0ea508f

Please sign in to comment.