fix thread-reduce performance regression (#2944)
fbusato authored Nov 27, 2024
1 parent ab87e54 commit 27d8c87
Showing 2 changed files with 134 additions and 30 deletions.
17 changes: 17 additions & 0 deletions cub/cub/thread/thread_operators.cuh
@@ -702,27 +702,44 @@ struct CubOperatorToSimdOperator<::cuda::minimum<>, T>
  using simd_type = typename type::simd_type;
};

+ template <typename T>
+ struct CubOperatorToSimdOperator<::cuda::minimum<T>, T> : CubOperatorToSimdOperator<::cuda::minimum<>, T>
+ {};

template <typename T>
struct CubOperatorToSimdOperator<::cuda::maximum<>, T>
{
  using type      = SimdMax<T>;
  using simd_type = typename type::simd_type;
};

+ template <typename T>
+ struct CubOperatorToSimdOperator<::cuda::maximum<T>, T> : CubOperatorToSimdOperator<::cuda::maximum<>, T>
+ {};

template <typename T>
struct CubOperatorToSimdOperator<::cuda::std::plus<>, T>
{
  using type      = SimdSum<T>;
  using simd_type = typename type::simd_type;
};

+ template <typename T>
+ struct CubOperatorToSimdOperator<::cuda::std::plus<T>, T> : CubOperatorToSimdOperator<::cuda::std::plus<>, T>
+ {};

template <typename T>
struct CubOperatorToSimdOperator<::cuda::std::multiplies<>, T>
{
  using type      = SimdMul<T>;
  using simd_type = typename type::simd_type;
};

+ template <typename T>
+ struct CubOperatorToSimdOperator<::cuda::std::multiplies<T>, T>
+     : CubOperatorToSimdOperator<::cuda::std::multiplies<>, T>
+ {};

template <typename ReduceOp, typename T>
using cub_operator_to_simd_operator_t = typename CubOperatorToSimdOperator<ReduceOp, T>::type;

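The added specializations above are the heart of the fix: the typed functor `::cuda::minimum<T>` is a different type from the transparent `::cuda::minimum<>`, so a trait keyed only on the transparent spelling silently skips the SIMD mapping for typed operators. A minimal compile-time sketch of the pattern, with a hypothetical `ToSimdOp` trait standing in for `CubOperatorToSimdOperator`:

```cpp
// Sketch only; ToSimdOp is a stand-in, not the CUB trait.
#include <cuda/functional>

template <typename Op, typename T>
struct ToSimdOp
{
  static constexpr bool has_simd = false; // generic operator: no SIMD mapping
};

template <typename T>
struct ToSimdOp<cuda::minimum<>, T> // transparent functor: SIMD mapping exists
{
  static constexpr bool has_simd = true;
};

// The pattern this commit adds: the typed functor inherits the transparent one.
template <typename T>
struct ToSimdOp<cuda::minimum<T>, T> : ToSimdOp<cuda::minimum<>, T>
{};

static_assert(ToSimdOp<cuda::minimum<>, int>::has_simd, "matched before the fix");
static_assert(ToSimdOp<cuda::minimum<int>, int>::has_simd, "matches only with the fix");
```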
147 changes: 117 additions & 30 deletions cub/cub/thread/thread_reduce.cuh
@@ -229,17 +229,27 @@ namespace internal
template <typename T, typename ReductionOp>
struct enable_generic_simd_reduction_traits
{
-   static constexpr bool value = cub::detail::is_one_of<T, ::cuda::std::int16_t, ::cuda::std::uint16_t>()
-                              && cub::detail::is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::maximum<>>();
+   static constexpr bool value =
+     cub::detail::is_one_of<T, ::cuda::std::int16_t, ::cuda::std::uint16_t>()
+     && cub::detail::
+       is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::minimum<T>, ::cuda::maximum<>, ::cuda::maximum<T>>();
};

# if defined(_CCCL_HAS_NVFP16)

template <typename ReductionOp>
struct enable_generic_simd_reduction_traits<__half, ReductionOp>
{
-   static constexpr bool value = cub::detail::
-     is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::maximum<>, ::cuda::std::plus<>, ::cuda::std::multiplies<>>();
+   static constexpr bool value = cub::detail::is_one_of<
+     ReductionOp,
+     ::cuda::minimum<>,
+     ::cuda::minimum<__half>,
+     ::cuda::maximum<>,
+     ::cuda::maximum<__half>,
+     ::cuda::std::plus<>,
+     ::cuda::std::plus<__half>,
+     ::cuda::std::multiplies<>,
+     ::cuda::std::multiplies<__half>>();
};
# endif // defined(_CCCL_HAS_NVFP16)
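These traits, and everything below, funnel through `cub::detail::is_one_of`. Its definition lies outside this diff; a plausible reconstruction (an assumption, not the exact CUB source) is a constexpr fold over a type pack:

```cpp
// Presumed shape of cub::detail::is_one_of (assumption; not copied from CUB):
// true iff T is the same type as at least one of Ts.
#include <cuda/functional>
#include <cuda/std/type_traits>

template <typename T, typename... Ts>
constexpr bool is_one_of()
{
  return (cuda::std::is_same<T, Ts>::value || ...); // C++17 fold expression
}

// The typed spelling only qualifies once it is listed explicitly:
static_assert(!is_one_of<cuda::minimum<int>, cuda::minimum<>>(), "old list: no match");
static_assert(is_one_of<cuda::minimum<int>, cuda::minimum<>, cuda::minimum<int>>(), "new list: match");
```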

@@ -248,8 +258,16 @@ struct enable_generic_simd_reduction_traits<__half, ReductionOp>
template <typename ReductionOp>
struct enable_generic_simd_reduction_traits<__nv_bfloat16, ReductionOp>
{
-   static constexpr bool value = cub::detail::
-     is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::maximum<>, ::cuda::std::plus<>, ::cuda::std::multiplies<>>();
+   static constexpr bool value = cub::detail::is_one_of<
+     ReductionOp,
+     ::cuda::minimum<>,
+     ::cuda::minimum<__nv_bfloat16>,
+     ::cuda::maximum<>,
+     ::cuda::maximum<__nv_bfloat16>,
+     ::cuda::std::plus<>,
+     ::cuda::std::plus<__nv_bfloat16>,
+     ::cuda::std::multiplies<>,
+     ::cuda::std::multiplies<__nv_bfloat16>>();
};

# endif // defined(_CCCL_HAS_NVBF16)
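For the 16-bit integer types gated by these traits, the payoff is packed math: two 16-bit lanes fit in one 32-bit register and can be reduced with one SIMD video intrinsic per pair. An illustrative device-side sketch (not CUB's `ThreadReduceSimd`; 4-byte alignment of the input is assumed):

```cpp
#include <cstdint>

// Illustrative only: reduce four int16_t with a lane-wise SIMD minimum, then a
// scalar finish. __vmins2 computes a per-lane signed 16-bit minimum.
__device__ std::int16_t simd_min4(const std::int16_t (&in)[4])
{
  const auto* packed = reinterpret_cast<const unsigned*>(in); // two lanes per word
  const unsigned m   = __vmins2(packed[0], packed[1]);
  const auto lo      = static_cast<std::int16_t>(m & 0xFFFFu);
  const auto hi      = static_cast<std::int16_t>(m >> 16);
  return lo < hi ? lo : hi;
}
```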
@@ -269,15 +287,24 @@ _CCCL_NODISCARD _CCCL_DEVICE constexpr bool enable_sm90_simd_reduction()
  using cub::detail::is_one_of;
  // ::cuda::std::plus<> not handled: IADD3 always produces less instructions than VIADD2
  return is_one_of<T, ::cuda::std::int16_t, ::cuda::std::uint16_t>() && //
-        is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::maximum<>>() && Length >= 10;
+        is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::minimum<T>, ::cuda::maximum<>, ::cuda::maximum<T>>()
+          && Length >= 10;
}

template <typename T, typename ReductionOp, int Length>
_CCCL_NODISCARD _CCCL_DEVICE constexpr bool enable_sm80_simd_reduction()
{
  using cub::detail::is_one_of;
  using ::cuda::std::is_same;
- return is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::maximum<>, ::cuda::std::plus<>, ::cuda::std::multiplies<>>()
+ return is_one_of<ReductionOp,
+                  ::cuda::minimum<>,
+                  ::cuda::minimum<T>,
+                  ::cuda::maximum<>,
+                  ::cuda::maximum<T>,
+                  ::cuda::std::plus<>,
+                  ::cuda::std::plus<T>,
+                  ::cuda::std::multiplies<>,
+                  ::cuda::std::multiplies<T>>()
      && Length >= 4
# if defined(_CCCL_HAS_NVFP16) && defined(_CCCL_HAS_NVBF16)
      && (is_same<T, __half>::value || is_same<T, __nv_bfloat16>::value)
@@ -295,7 +322,12 @@ _CCCL_NODISCARD _CCCL_DEVICE constexpr bool enable_sm70_simd_reduction()
  using cub::detail::is_one_of;
  using ::cuda::std::is_same;
# if defined(_CCCL_HAS_NVFP16)
- return is_same<T, __half>::value && is_one_of<ReductionOp, ::cuda::std::plus<>, ::cuda::std::multiplies<>>()
+ return is_same<T, __half>::value
+     && is_one_of<ReductionOp,
+                  ::cuda::std::plus<>,
+                  ::cuda::std::plus<T>,
+                  ::cuda::std::multiplies<>,
+                  ::cuda::std::multiplies<T>>()
      && Length >= 4;
# else
  return false;
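The SM70 gate corresponds to hardware `__half2` arithmetic: plus and multiplies can combine two `__half` lanes per instruction. A hedged sketch of the idea (again not CUB's actual SIMD reduction, and assuming the input is aligned for `__half2` loads):

```cpp
#include <cuda_fp16.h>

// Illustrative only: sum eight __half values two lanes at a time with __hadd2,
// then combine the low and high lanes of the accumulator.
__device__ __half simd_sum8(const __half (&in)[8])
{
  const auto* v     = reinterpret_cast<const __half2*>(in); // four packed pairs
  const __half2 acc = __hadd2(__hadd2(v[0], v[1]), __hadd2(v[2], v[3]));
  return __hadd(__low2half(acc), __high2half(acc)); // horizontal finish
}
```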
@@ -344,14 +376,21 @@ template <typename T, typename ReductionOp>
struct enable_ternary_reduction_sm90
{
  static constexpr bool value =
-   cub::detail::is_one_of<T, ::cuda::std::int32_t, ::cuda::std::uint32_t, ::cuda::std::int64_t, ::cuda::std::uint64_t>
-   && cub::detail::is_one_of<ReductionOp,
-                             ::cuda::minimum<>,
-                             ::cuda::maximum<>,
-                             ::cuda::std::plus<>,
-                             ::cuda::std::bit_and<>,
-                             ::cuda::std::bit_or<>,
-                             ::cuda::std::bit_xor<>>();
+   cub::detail::is_one_of<T, ::cuda::std::int32_t, ::cuda::std::uint32_t>()
+   && cub::detail::is_one_of<
+     ReductionOp,
+     ::cuda::minimum<>,
+     ::cuda::minimum<T>,
+     ::cuda::maximum<>,
+     ::cuda::maximum<T>,
+     ::cuda::std::plus<>,
+     ::cuda::std::plus<T>,
+     ::cuda::std::bit_and<>,
+     ::cuda::std::bit_and<T>,
+     ::cuda::std::bit_or<>,
+     ::cuda::std::bit_or<T>,
+     ::cuda::std::bit_xor<>,
+     ::cuda::std::bit_xor<T>>();
};

# if defined(_CCCL_HAS_NVFP16)
@@ -360,7 +399,13 @@ template <typename ReductionOp>
struct enable_ternary_reduction_sm90<__half2, ReductionOp>
{
  static constexpr bool value =
-   cub::detail::is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::maximum<>, SimdMin<__half>, SimdMax<__half>>();
+   cub::detail::is_one_of<ReductionOp,
+                          ::cuda::minimum<>,
+                          ::cuda::minimum<__half2>,
+                          ::cuda::maximum<>,
+                          ::cuda::maximum<__half2>,
+                          SimdMin<__half>,
+                          SimdMax<__half>>();
};

# endif // defined(_CCCL_HAS_NVFP16)
@@ -370,8 +415,14 @@ struct enable_ternary_reduction_sm90<__half2, ReductionOp>
template <typename ReductionOp>
struct enable_ternary_reduction_sm90<__nv_bfloat162, ReductionOp>
{
-   static constexpr bool value = cub::detail::
-     is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::maximum<>, SimdMin<__nv_bfloat16>, SimdMax<__nv_bfloat16>>();
+   static constexpr bool value =
+     cub::detail::is_one_of<ReductionOp,
+                            ::cuda::minimum<>,
+                            ::cuda::minimum<__nv_bfloat162>,
+                            ::cuda::maximum<>,
+                            ::cuda::maximum<__nv_bfloat162>,
+                            SimdMin<__nv_bfloat16>,
+                            SimdMax<__nv_bfloat16>>();
};

# endif // defined(_CCCL_HAS_NVBF16)
@@ -394,10 +445,11 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE _CCCL_CONSTEXPR_CXX14 bool enable
    NV_PROVIDES_SM_90,
    (return enable_ternary_reduction_sm90<T, ReductionOp>::value;),
    NV_PROVIDES_SM_50,
-   (return is_one_of<AccumT, ::cuda::std::int32_t, ::cuda::std::uint32_t, ::cuda::std::int64_t,
-                     ::cuda::std::uint64_t>()
-        && is_one_of<ReductionOp, ::cuda::std::plus<>, ::cuda::std::bit_and<>, ::cuda::std::bit_or<>,
-                     ::cuda::std::bit_xor<>>();),
+   (return is_one_of<AccumT, ::cuda::std::int32_t, ::cuda::std::uint32_t>()
+        && is_one_of<ReductionOp, ::cuda::std::plus<>, ::cuda::std::plus<T>,
+                     ::cuda::std::bit_and<>, ::cuda::std::bit_and<T>,
+                     ::cuda::std::bit_or<>, ::cuda::std::bit_or<T>,
+                     ::cuda::std::bit_xor<>, ::cuda::std::bit_xor<T>>();),
    NV_ANY_TARGET,
    (return false;)
  );
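What `enable_ternary_reduction` selects is a reduction tree with fan-in three, which maps onto three-input integer instructions such as the IADD3 cited in the comment further up. A shape-only sketch (not CUB's `ThreadReduceTernaryTree`):

```cpp
// Illustrative fan-in-3 reduction over nine elements: depth ~log3(N) instead
// of ~log2(N); each op(op(x, y), z) group can lower to one three-input
// instruction (e.g. IADD3 for integer plus).
template <typename T, typename Op>
__device__ T reduce9_ternary(const T (&x)[9], Op op)
{
  const T a = op(op(x[0], x[1]), x[2]);
  const T b = op(op(x[3], x[4]), x[5]);
  const T c = op(op(x[6], x[7]), x[8]);
  return op(op(a, b), c);
}
```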
@@ -415,12 +467,19 @@ _CCCL_NODISCARD _CCCL_DEVICE constexpr bool enable_promotion()
  return ::cuda::std::is_integral<T>::value && sizeof(T) <= 2
      && is_one_of<ReductionOp,
                   ::cuda::std::plus<>,
+                  ::cuda::std::plus<T>,
                   ::cuda::std::multiplies<>,
+                  ::cuda::std::multiplies<T>,
                   ::cuda::std::bit_and<>,
+                  ::cuda::std::bit_and<T>,
                   ::cuda::std::bit_or<>,
+                  ::cuda::std::bit_or<T>,
                   ::cuda::std::bit_xor<>,
+                  ::cuda::std::bit_xor<T>,
                   ::cuda::maximum<>,
-                  ::cuda::minimum<>>();
+                  ::cuda::maximum<T>,
+                  ::cuda::minimum<>,
+                  ::cuda::minimum<T>>();
}
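`enable_promotion` widens sub-word integral accumulation to `int` (the `PromT` alias in `ThreadReduce` below), so the narrowing back to 16 bits happens once at the end rather than after every combine step. In miniature (illustrative, host-side):

```cpp
#include <cstdint>

// Illustrative only: a promoted int accumulator defers the narrowing to a
// single cast at the end, mirroring what the PromT = int selection buys.
std::int16_t sum4_promoted(const std::int16_t (&in)[4])
{
  int acc = 0; // promoted accumulator
  for (int i = 0; i < 4; ++i)
  {
    acc += in[i];
  }
  return static_cast<std::int16_t>(acc); // narrow once
}
```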

/***********************************************************************************************************************
@@ -551,18 +610,46 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT ThreadReduce(const Input&
  using cub::internal::enable_simd_reduction;
  using cub::internal::enable_ternary_reduction;
  using PromT = ::cuda::std::_If<enable_promotion<Input, ReductionOp, AccumT>(), int, AccumT>;
+ _CCCL_IF_CONSTEXPR (!cub::detail::is_one_of<
+                       ReductionOp,
+                       ::cuda::std::plus<>,
+                       ::cuda::std::plus<ValueT>,
+                       ::cuda::std::multiplies<>,
+                       ::cuda::std::multiplies<ValueT>,
+                       ::cuda::std::bit_and<>,
+                       ::cuda::std::bit_and<ValueT>,
+                       ::cuda::std::bit_or<>,
+                       ::cuda::std::bit_or<ValueT>,
+                       ::cuda::std::bit_xor<>,
+                       ::cuda::std::bit_xor<ValueT>,
+                       ::cuda::maximum<>,
+                       ::cuda::maximum<ValueT>,
+                       ::cuda::minimum<>,
+                       ::cuda::minimum<ValueT>,
+                       cub::internal::SimdMin<ValueT>,
+                       cub::internal::SimdMax<ValueT>>())
+ {
+   return cub::internal::ThreadReduceSequential<AccumT>(input, reduction_op);
+ }
+ _CCCL_IF_CONSTEXPR (cub::detail::is_one_of<ReductionOp, ::cuda::std::plus<>, ::cuda::std::plus<ValueT>>()
+                     && cub::detail::is_one_of<ValueT, int, ::cuda::std::uint32_t>())
+ {
+   // clang-format off
+   NV_IF_TARGET(NV_PROVIDES_SM_90,
+     (return cub::internal::ThreadReduceSequential<AccumT>(input, reduction_op);),
+     (return cub::internal::ThreadReduceTernaryTree<PromT>(input, reduction_op);)
+   );
+   // clang-format on
+ }
  if (enable_simd_reduction<Input, ReductionOp, AccumT>())
  {
    return cub::internal::ThreadReduceSimd(input, reduction_op);
  }
- else if (enable_ternary_reduction<Input, ReductionOp, PromT>())
+ if (enable_ternary_reduction<Input, ReductionOp, PromT>())
  {
    return cub::internal::ThreadReduceTernaryTree<PromT>(input, reduction_op);
  }
- else
- {
-   return cub::internal::ThreadReduceBinaryTree<PromT>(input, reduction_op);
- }
+ return cub::internal::ThreadReduceBinaryTree<PromT>(input, reduction_op);
}

//! @brief Reduction over statically-sized array-like types, seeded with the specified @p prefix.
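End to end, the fix means a typed operator now reaches the same optimized dispatch as its transparent counterpart. A usage sketch against this file (call shape and namespace taken from the diff; `ThreadReduce` here is the internal entry point being patched, not a documented public API):

```cpp
#include <cub/thread/thread_reduce.cuh>
#include <cuda/functional>

// Before this commit, the typed cuda::minimum<int> missed the is_one_of lists
// and fell back to a slower path; now both spellings take the same route.
__device__ int thread_min4(const int (&items)[4])
{
  return cub::internal::ThreadReduce(items, cuda::minimum<int>{});
}
```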
