Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix block reduce alignment #1233

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions cub/cub/block/specializations/block_reduce_warp_reductions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

#include <cub/detail/uninitialized_copy.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
#include <cub/warp/warp_reduce.cuh>

CUB_NAMESPACE_BEGIN
Expand Down Expand Up @@ -92,17 +93,19 @@ struct BlockReduceWarpReductions


/// WarpReduce utility type
typedef typename WarpReduce<T, LOGICAL_WARP_SIZE>::InternalWarpReduce WarpReduce;
using WarpReduce = typename WarpReduce<T, LOGICAL_WARP_SIZE>::InternalWarpReduce;
using WarpReduceStorage = typename WarpReduce::TempStorage;

/// Shared memory storage layout type
struct _TempStorage
/// Alignment ensures that loads of `warp_aggregates` can be vectorized
struct alignas(detail::max_alignment_t<16, T, WarpReduceStorage>::value) _TempStorage
{
/// Buffer for warp-synchronous reduction
typename WarpReduce::TempStorage warp_reduce[WARPS];

/// Shared totals from each warp-synchronous reduction
T warp_aggregates[WARPS];

/// Buffer for warp-synchronous reduction
WarpReduceStorage warp_reduce[WARPS];

/// Shared prefix for the entire thread block
T block_prefix;
};
Expand Down
19 changes: 19 additions & 0 deletions cub/cub/util_type.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,25 @@ template <> struct NumericTraits<bool> : BaseTraits<UNSIGNED_INTE
template <typename T>
struct Traits : NumericTraits<typename ::cuda::std::remove_cv<T>::type> {};

namespace detail
{

template <::cuda::std::size_t Alignment, class... T>
struct max_alignment_t;

template <::cuda::std::size_t Alignment>
struct max_alignment_t<Alignment>
{
constexpr static ::cuda::std::size_t value = Alignment;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion:

we usually go with the ordering of static constexpr because static is the more relevant information here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remark: clang-format can handle such orderings using the QualifierOrder style option, which we seem to not set in our .clang-format.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened: #1748

};

template <::cuda::std::size_t Alignment, class Head, class... Tail>
struct max_alignment_t<Alignment, Head, Tail...>
{
constexpr static ::cuda::std::size_t value = max_alignment_t<(cub::max)(Alignment, alignof(Head)), Tail...>::value;
};

} // namespace detail

#endif // DOXYGEN_SHOULD_SKIP_THIS

Expand Down
14 changes: 14 additions & 0 deletions cub/test/catch2_test_util_type.cu
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,17 @@ CUB_TEST("Tests non_void_value_t", "[util][type]")
::cuda::std::is_same<int, //
cub::detail::non_void_value_t<non_void_fancy_it, fallback_t>>::value);
}

namespace {
struct alignas(32) test32_t {};
} // namespace

CUB_TEST("Maximal alignment is computed correctly", "[util][type]")
{
STATIC_REQUIRE(cub::detail::max_alignment_t<8>::value == 8);
STATIC_REQUIRE(cub::detail::max_alignment_t<1, int>::value == alignof(int));
STATIC_REQUIRE(cub::detail::max_alignment_t<1, int, double>::value == alignof(double));
STATIC_REQUIRE(cub::detail::max_alignment_t<1, int, double>::value == alignof(double));
STATIC_REQUIRE(cub::detail::max_alignment_t<1, int, double, test32_t>::value == 32);
STATIC_REQUIRE(cub::detail::max_alignment_t<128, int, double, test32_t>::value == 128);
}
Loading