Skip to content

Commit

Permalink
fix tuning_scan sm90 config issue (#3236)
Browse files: browse the repository at this point in the history
Co-authored-by: Shijie Chen <[email protected]>
  • Loading branch information
gevtushenko and shijie-nv authored Jan 2, 2025
1 parent 2bdcb7b commit b57e065
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions cub/cub/device/dispatch/tuning/tuning_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,16 @@ constexpr accum_size classify_accum_size()
: accum_size::unknown;
}

template <int Threads, int Items, int L2B, int L2W>
/// Compile-time bundle of scan tuning parameters.
///
/// @tparam AccumT  Accumulator type; only its `sizeof` is inspected here.
/// @tparam Threads Value re-exported as `threads`.
/// @tparam Items   Value re-exported as `items`.
/// @tparam L2B     First argument forwarded to `fixed_delay_constructor_t`.
/// @tparam L2W     Second argument forwarded to `fixed_delay_constructor_t`.
template <class AccumT, int Threads, int Items, int L2B, int L2W>
struct tuning
{
  using delay_constructor = fixed_delay_constructor_t<L2B, L2W>;

  static constexpr int threads = Threads;
  static constexpr int items   = Items;

  // Accumulators of at most 128 bytes keep the plain warp-transpose
  // algorithms; anything larger selects the time-sliced variants.
  static constexpr BlockLoadAlgorithm load_algorithm =
    (sizeof(AccumT) <= 128) ? BLOCK_LOAD_WARP_TRANSPOSE : BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED;
  static constexpr BlockStoreAlgorithm store_algorithm =
    (sizeof(AccumT) <= 128) ? BLOCK_STORE_WARP_TRANSPOSE : BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED;
};

template <class AccumT,
Expand Down Expand Up @@ -206,16 +210,16 @@ template <class AccumT,
struct sm90_tuning;

// clang-format off
template <class T> struct sm90_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_1> : tuning<192, 22, 168, 1140> {};
template <class T> struct sm90_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_2> : tuning<512, 12, 376, 1125> {};
template <class T> struct sm90_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_4> : tuning<128, 24, 648, 1245> {};
template <class T> struct sm90_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_8> : tuning<224, 24, 632, 1290> {};
// Generic sm90 tunings for primitive_op::yes / primitive_accum::yes, one per accumulator size class.
// The accumulator type T is passed through to `tuning` as AccumT; the four integers are, in order,
// Threads, Items, L2B and L2W of `tuning`.
template <class T> struct sm90_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_1> : tuning<T, 192, 22, 168, 1140> {};
template <class T> struct sm90_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_2> : tuning<T, 512, 12, 376, 1125> {};
template <class T> struct sm90_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_4> : tuning<T, 128, 24, 648, 1245> {};
template <class T> struct sm90_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_8> : tuning<T, 224, 24, 632, 1290> {};

template <> struct sm90_tuning<float, primitive_op::yes, primitive_accum::yes, accum_size::_4> : tuning<128, 24, 688, 1140> {};
template <> struct sm90_tuning<double, primitive_op::yes, primitive_accum::yes, accum_size::_8> : tuning<224, 24, 576, 1215> {};
// Full specializations for float and double: same Threads/Items as the generic _4/_8 entries,
// but with different L2B/L2W delay parameters.
template <> struct sm90_tuning<float, primitive_op::yes, primitive_accum::yes, accum_size::_4> : tuning<float, 128, 24, 688, 1140> {};
template <> struct sm90_tuning<double, primitive_op::yes, primitive_accum::yes, accum_size::_8> : tuning<double, 224, 24, 576, 1215> {};

#if CUB_IS_INT128_ENABLED
template <> struct sm90_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_size::_16> : tuning<576, 21, 860, 630> {};
// __int128_t entry: keyed with primitive_accum::no and accum_size::_16; only compiled when CUB_IS_INT128_ENABLED is set.
template <> struct sm90_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_size::_16> : tuning<__int128_t, 576, 21, 860, 630> {};
template <>
struct sm90_tuning<__uint128_t, primitive_op::yes, primitive_accum::no, accum_size::_16>
: sm90_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_size::_16>
Expand Down

0 comments on commit b57e065

Please sign in to comment.