diff --git a/cub/cub/block/specializations/block_reduce_warp_reductions.cuh b/cub/cub/block/specializations/block_reduce_warp_reductions.cuh index 7477af7e581..4bef99bb958 100644 --- a/cub/cub/block/specializations/block_reduce_warp_reductions.cuh +++ b/cub/cub/block/specializations/block_reduce_warp_reductions.cuh @@ -46,6 +46,7 @@ #include #include +#include #include CUB_NAMESPACE_BEGIN @@ -92,17 +93,19 @@ struct BlockReduceWarpReductions /// WarpReduce utility type - typedef typename WarpReduce::InternalWarpReduce WarpReduce; + using WarpReduce = typename WarpReduce::InternalWarpReduce; + using WarpReduceStorage = typename WarpReduce::TempStorage; /// Shared memory storage layout type - struct _TempStorage + /// Alignment ensures that loads of `warp_aggregates` can be vectorized + struct alignas(detail::max_alignment_t<16, T, WarpReduceStorage>::value) _TempStorage { - /// Buffer for warp-synchronous reduction - typename WarpReduce::TempStorage warp_reduce[WARPS]; - /// Shared totals from each warp-synchronous reduction T warp_aggregates[WARPS]; + /// Buffer for warp-synchronous reduction + WarpReduceStorage warp_reduce[WARPS]; + /// Shared prefix for the entire thread block T block_prefix; }; diff --git a/cub/cub/util_type.cuh b/cub/cub/util_type.cuh index 9a982c2bb85..e46632118b2 100644 --- a/cub/cub/util_type.cuh +++ b/cub/cub/util_type.cuh @@ -1370,6 +1370,25 @@ template <> struct NumericTraits : BaseTraits struct Traits : NumericTraits::type> {}; +namespace detail +{ + +template <::cuda::std::size_t Alignment, class... T> +struct max_alignment_t; + +template <::cuda::std::size_t Alignment> +struct max_alignment_t +{ + constexpr static ::cuda::std::size_t value = Alignment; +}; + +template <::cuda::std::size_t Alignment, class Head, class... Tail> +struct max_alignment_t +{ + constexpr static ::cuda::std::size_t value = max_alignment_t<(cub::max)(Alignment, alignof(Head)), Tail...>::value; +}; + +} // namespace detail #endif // DOXYGEN_SHOULD_SKIP_THIS diff --git a/cub/test/catch2_test_util_type.cu b/cub/test/catch2_test_util_type.cu index a73de1afa3f..570a3123db2 100644 --- a/cub/test/catch2_test_util_type.cu +++ b/cub/test/catch2_test_util_type.cu @@ -66,3 +66,17 @@ CUB_TEST("Tests non_void_value_t", "[util][type]") ::cuda::std::is_same>::value); } + +namespace { +struct alignas(32) test32_t {}; +} // namespace + +CUB_TEST("Maximal alignment is computed correctly", "[util][type]") +{ + STATIC_REQUIRE(cub::detail::max_alignment_t<8>::value == 8); + STATIC_REQUIRE(cub::detail::max_alignment_t<1, int>::value == alignof(int)); + STATIC_REQUIRE(cub::detail::max_alignment_t<1, int, double>::value == alignof(double)); + STATIC_REQUIRE(cub::detail::max_alignment_t<1, int, double>::value == alignof(double)); + STATIC_REQUIRE(cub::detail::max_alignment_t<1, int, double, test32_t>::value == 32); + STATIC_REQUIRE(cub::detail::max_alignment_t<128, int, double, test32_t>::value == 128); +}