Skip to content

Commit

Permalink
Fix RLE tuning (#3239)
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko authored Jan 3, 2025
1 parent b64e5c1 commit 36e27f7
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions cub/cub/device/dispatch/tuning/tuning_run_length_encode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ struct policy_hub
static constexpr int max_input_bytes = static_cast<int>(::cuda::std::max(sizeof(KeyT), sizeof(LengthT)));
static constexpr int combined_input_bytes = sizeof(KeyT) + sizeof(LengthT);

template <CacheLoadModifier LoadModifier>
struct DefaultPolicy
{
static constexpr int nominal_4B_items_per_thread = 6;
Expand All @@ -252,14 +253,14 @@ struct policy_hub
AgentReduceByKeyPolicy<128,
items,
BLOCK_LOAD_DIRECT,
LOAD_LDG,
LoadModifier,
BLOCK_SCAN_WARP_SCANS,
default_reduce_by_key_delay_constructor_t<LengthT, int>>;
};

// SM35
// Policy registered for 350 in the policy chain. The body is empty: everything comes from
// the bases, DefaultPolicy (which provides ReduceByKeyPolicyT) and ChainedPolicy.
// NOTE(review): this is a rendered diff. The two consecutive `: DefaultPolicy...` lines appear
// to be the pre-change and post-change forms of the same base-specifier, not two bases.
// The new form passes LOAD_LDG explicitly, the value DefaultPolicy used before it took a
// LoadModifier template parameter.
// The chain's third argument is Policy350 itself — presumably this marks the end of the
// chain; confirm against ChainedPolicy.
struct Policy350
: DefaultPolicy
: DefaultPolicy<LOAD_LDG>
, ChainedPolicy<350, Policy350, Policy350>
{};

Expand All @@ -273,7 +274,7 @@ struct policy_hub
BLOCK_SCAN_WARP_SCANS,
typename Tuning::delay_constructor>;
// Declaration-only overload of select_agent_policy taking `long`; its return type names the
// DefaultPolicy-based ReduceByKeyPolicyT.
// NOTE(review): this is a rendered diff. The two `static auto ...` lines appear to be the
// pre-change and post-change forms of the same declaration.
// The new form instantiates DefaultPolicy with LOAD_DEFAULT, whereas Policy350 and Policy860
// instantiate it with LOAD_LDG.
// Presumably this is the fallback used when the Tuning-based overload is not viable — confirm
// against that overload, which is only partly visible in this view.
template <typename Tuning>
static auto select_agent_policy(long) -> typename DefaultPolicy::ReduceByKeyPolicyT;
static auto select_agent_policy(long) -> typename DefaultPolicy<LOAD_DEFAULT>::ReduceByKeyPolicyT;

// SM80
struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
Expand All @@ -283,7 +284,7 @@ struct policy_hub

// SM86
// Policy registered for 860 in the policy chain, with Policy800 as the other policy named in
// ChainedPolicy. The body is empty: the agent policy comes from DefaultPolicy rather than
// from a tuning lookup.
// NOTE(review): this is a rendered diff. The two consecutive `: DefaultPolicy...` lines appear
// to be the pre-change and post-change forms of the same base-specifier, not two bases.
// The new form passes LOAD_LDG explicitly, the value DefaultPolicy used before it took a
// LoadModifier template parameter.
struct Policy860
: DefaultPolicy
: DefaultPolicy<LOAD_LDG>
, ChainedPolicy<860, Policy860, Policy800>
{};

Expand Down Expand Up @@ -433,7 +434,7 @@ struct sm90_tuning<LengthT, __uint128_t, primitive_length::yes, primitive_key::n
template <class LengthT, class KeyT>
struct policy_hub
{
template <BlockLoadAlgorithm BlockLoad, typename DelayConstructorKey>
template <BlockLoadAlgorithm BlockLoad, typename DelayConstructorKey, CacheLoadModifier LoadModifier>
struct DefaultPolicy
{
static constexpr int nominal_4B_items_per_thread = 15;
Expand All @@ -444,15 +445,15 @@ struct policy_hub
AgentRlePolicy<96,
ITEMS_PER_THREAD,
BlockLoad,
LOAD_LDG,
LoadModifier,
true,
BLOCK_SCAN_WARP_SCANS,
default_reduce_by_key_delay_constructor_t<DelayConstructorKey, int>>;
};

// SM35
// Policy registered for 350 in the run-length-encode policy chain. The body is empty:
// everything comes from the bases, DefaultPolicy and ChainedPolicy.
// NOTE(review): this is a rendered diff. The two consecutive `: DefaultPolicy<...>` lines
// appear to be the pre-change and post-change forms of the same base-specifier, not two bases.
// The new form adds LOAD_LDG as the third template argument (LoadModifier), the value
// DefaultPolicy used before it took that parameter.
// The chain's third argument is Policy350 itself — presumably this marks the end of the
// chain; confirm against ChainedPolicy.
struct Policy350
: DefaultPolicy<BLOCK_LOAD_DIRECT, int> // TODO(bgruber): I think we want `LengthT` instead of `int`
: DefaultPolicy<BLOCK_LOAD_DIRECT, int, LOAD_LDG> // TODO(bgruber): I think we want `LengthT` instead of `int`
, ChainedPolicy<350, Policy350, Policy350>
{};

Expand All @@ -467,7 +468,8 @@ struct policy_hub
BLOCK_SCAN_WARP_SCANS,
typename Tuning::delay_constructor>;
// Declaration-only overload of select_agent_policy taking `long`; its return type names the
// DefaultPolicy-based RleSweepPolicyT.
// NOTE(review): this is a rendered diff. The first `static auto ...` line appears to be the
// pre-change declaration, and the two lines after it are its replacement.
// The new form instantiates DefaultPolicy with BLOCK_LOAD_WARP_TRANSPOSE, LengthT and
// LOAD_DEFAULT, whereas Policy350 and Policy860 use BLOCK_LOAD_DIRECT, int and LOAD_LDG.
// Presumably this is the fallback used when the Tuning-based overload is not viable — confirm
// against that overload, which is only partly visible in this view.
template <typename Tuning>
static auto select_agent_policy(long) -> typename DefaultPolicy<BLOCK_LOAD_WARP_TRANSPOSE, LengthT>::RleSweepPolicyT;
static auto select_agent_policy(long) ->
typename DefaultPolicy<BLOCK_LOAD_WARP_TRANSPOSE, LengthT, LOAD_DEFAULT>::RleSweepPolicyT;

// SM80
struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
Expand All @@ -477,7 +479,7 @@ struct policy_hub

// SM86
// Policy registered for 860 in the run-length-encode policy chain, with Policy800 as the other
// policy named in ChainedPolicy. The body is empty: the agent policy comes from DefaultPolicy
// rather than from a tuning lookup.
// NOTE(review): this is a rendered diff. The two consecutive `: DefaultPolicy<...>` lines
// appear to be the pre-change and post-change forms of the same base-specifier, not two bases.
// The new form adds LOAD_LDG as the third template argument (LoadModifier), the value
// DefaultPolicy used before it took that parameter.
struct Policy860
: DefaultPolicy<BLOCK_LOAD_DIRECT, int> // TODO(bgruber): I think we want `LengthT` instead of `int`
: DefaultPolicy<BLOCK_LOAD_DIRECT, int, LOAD_LDG> // TODO(bgruber): I think we want `LengthT` instead of `int`
, ChainedPolicy<860, Policy860, Policy800>
{};

Expand Down

0 comments on commit 36e27f7

Please sign in to comment.