Skip to content

Commit

Permalink
Fix ReduceByKey tuning (#3240)
Browse files: browse the repository at this point in the history
  • Loading branch information
gevtushenko authored Jan 3, 2025
1 parent fd2a15d commit b64e5c1
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions cub/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -613,6 +613,7 @@ struct policy_hub
static constexpr int max_input_bytes = static_cast<int>(::cuda::std::max(sizeof(KeyT), sizeof(AccumT)));
static constexpr int combined_input_bytes = sizeof(KeyT) + sizeof(AccumT);

template <CacheLoadModifier LoadModifier>
struct DefaultPolicy
{
static constexpr int nominal_4B_items_per_thread = 6;
Expand All @@ -627,13 +628,13 @@ struct policy_hub
AgentReduceByKeyPolicy<128,
items_per_thread,
BLOCK_LOAD_DIRECT,
LOAD_LDG,
LoadModifier,
BLOCK_SCAN_WARP_SCANS,
default_reduce_by_key_delay_constructor_t<AccumT, int>>;
};

struct Policy350
: DefaultPolicy
: DefaultPolicy<LOAD_LDG>
, ChainedPolicy<350, Policy350, Policy350>
{};

Expand All @@ -648,7 +649,7 @@ struct policy_hub
typename Tuning::delay_constructor>;

template <typename Tuning>
static auto select_agent_policy(long) -> typename DefaultPolicy::ReduceByKeyPolicyT;
static auto select_agent_policy(long) -> typename DefaultPolicy<LOAD_DEFAULT>::ReduceByKeyPolicyT;

struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
{
Expand All @@ -657,7 +658,7 @@ struct policy_hub
};

struct Policy860
: DefaultPolicy
: DefaultPolicy<LOAD_LDG>
, ChainedPolicy<860, Policy860, Policy800>
{};

Expand Down

0 comments on commit b64e5c1

Please sign in to comment.