Skip to content

Commit

Permalink
Fix block reduce alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Dec 19, 2023
1 parent 2165845 commit 122c4a0
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 5 deletions.
13 changes: 8 additions & 5 deletions cub/cub/block/specializations/block_reduce_warp_reductions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

#include <cub/detail/uninitialized_copy.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
#include <cub/warp/warp_reduce.cuh>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -92,17 +93,19 @@ struct BlockReduceWarpReductions


/// WarpReduce utility type
typedef typename WarpReduce<T, LOGICAL_WARP_SIZE>::InternalWarpReduce WarpReduce;
using WarpReduce = typename WarpReduce<T, LOGICAL_WARP_SIZE>::InternalWarpReduce;
using WarpReduceStorage = typename WarpReduce::TempStorage;

/// Shared memory storage layout type
struct _TempStorage
/// Alignment ensures that loads of `warp_aggregates` can be vectorized
struct alignas(detail::max_alignment_t<16, T, WarpReduceStorage>::value) _TempStorage
{
/// Buffer for warp-synchronous reduction
typename WarpReduce::TempStorage warp_reduce[WARPS];

/// Shared totals from each warp-synchronous reduction
T warp_aggregates[WARPS];

/// Buffer for warp-synchronous reduction
WarpReduceStorage warp_reduce[WARPS];

/// Shared prefix for the entire thread block
T block_prefix;
};
Expand Down
19 changes: 19 additions & 0 deletions cub/cub/util_type.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,25 @@ template <> struct NumericTraits<bool> : BaseTraits<UNSIGNED_INTE
template <typename T>
struct Traits : NumericTraits<typename ::cuda::std::remove_cv<T>::type> {};

namespace detail
{

// Compile-time maximum of `Alignment` and `alignof(U)` for every type `U` in
// the pack `T...`. The result is exposed as the static member `value`.
template <::cuda::std::size_t Alignment, class... T>
struct max_alignment_t;

// Recursive step: fold the alignment of the first type into the running
// maximum, then continue with the remaining types.
template <::cuda::std::size_t Current, class First, class... Rest>
struct max_alignment_t<Current, First, Rest...>
{
  constexpr static ::cuda::std::size_t value =
    max_alignment_t<(Current < alignof(First) ? alignof(First) : Current), Rest...>::value;
};

// Terminating step: no types are left, so the running maximum is the answer.
template <::cuda::std::size_t Current>
struct max_alignment_t<Current>
{
  constexpr static ::cuda::std::size_t value = Current;
};

} // namespace detail

#endif // DOXYGEN_SHOULD_SKIP_THIS

Expand Down
14 changes: 14 additions & 0 deletions cub/test/catch2_test_util_type.cu
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,17 @@ CUB_TEST("Tests non_void_value_t", "[util][type]")
::cuda::std::is_same<int, //
cub::detail::non_void_value_t<non_void_fancy_it, fallback_t>>::value);
}

// Empty helper type with a forced 32-byte alignment; the max_alignment_t
// test in this file uses it as a type whose alignof is exactly 32.
namespace {
struct alignas(32) test32_t {};
} // namespace

// Compile-time checks that max_alignment_t yields the largest value among the
// explicit minimum alignment and the alignof of every listed type.
CUB_TEST("Maximal alignment is computed correctly", "[util][type]")
{
  // Empty type list: the explicit minimum is returned unchanged
  STATIC_REQUIRE(cub::detail::max_alignment_t<8>::value == 8);

  // With a minimum of 1, the listed types determine the result
  STATIC_REQUIRE(cub::detail::max_alignment_t<1, int>::value == alignof(int));
  STATIC_REQUIRE(cub::detail::max_alignment_t<1, int, double>::value == alignof(double));

  // The result must not depend on the order of the types in the pack
  STATIC_REQUIRE(cub::detail::max_alignment_t<1, double, int>::value == alignof(double));

  // An over-aligned user type dominates the fundamental types
  STATIC_REQUIRE(cub::detail::max_alignment_t<1, int, double, test32_t>::value == 32);

  // An explicit minimum larger than every type's alignment wins
  STATIC_REQUIRE(cub::detail::max_alignment_t<128, int, double, test32_t>::value == 128);
}

0 comments on commit 122c4a0

Please sign in to comment.