Skip to content

Commit 8d61de9

Browse files
Cleanup util_arch
1 parent 284e104 commit 8d61de9

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

cub/cub/util_arch.cuh

+16-15
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@
4747
#include <cub/util_macro.cuh>
4848
#include <cub/util_namespace.cuh>
4949

50+
#include <cuda/cmath>
51+
#include <cuda/std/__algorithm/max.h>
52+
#include <cuda/std/__algorithm/min.h>
53+
5054
// Legacy include; this functionality used to be defined in here.
5155
#include <cub/detail/detect_cuda_runtime.cuh>
5256

@@ -143,27 +147,24 @@ namespace detail
143147
static constexpr ::cuda::std::size_t max_smem_per_block = 48 * 1024;
144148
} // namespace detail
145149

146-
template <int NOMINAL_4B_BLOCK_THREADS, int NOMINAL_4B_ITEMS_PER_THREAD, typename T>
150+
template <int Nominal4ByteBlockThreads, int Nominal4ByteItemsPerThread, typename T>
147151
struct RegBoundScaling
148152
{
149-
enum
150-
{
151-
ITEMS_PER_THREAD = CUB_MAX(1, NOMINAL_4B_ITEMS_PER_THREAD * 4 / CUB_MAX(4, sizeof(T))),
152-
BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS,
153-
((cub::detail::max_smem_per_block / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32),
154-
};
153+
static constexpr int ITEMS_PER_THREAD =
154+
::cuda::std::max(1, Nominal4ByteItemsPerThread * 4 / ::cuda::std::max(4, int{sizeof(T)}));
155+
static constexpr int BLOCK_THREADS =
156+
::cuda::std::min(Nominal4ByteBlockThreads,
157+
::cuda::ceil_div(int{detail::max_smem_per_block} / (int{sizeof(T)} * ITEMS_PER_THREAD), 32) * 32);
155158
};
156159

157-
template <int NOMINAL_4B_BLOCK_THREADS, int NOMINAL_4B_ITEMS_PER_THREAD, typename T>
160+
template <int Nominal4ByteBlockThreads, int Nominal4ByteItemsPerThread, typename T>
158161
struct MemBoundScaling
159162
{
160-
enum
161-
{
162-
ITEMS_PER_THREAD =
163-
CUB_MAX(1, CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(T), NOMINAL_4B_ITEMS_PER_THREAD * 2)),
164-
BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS,
165-
((cub::detail::max_smem_per_block / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32),
166-
};
163+
static constexpr int ITEMS_PER_THREAD = ::cuda::std::max(
164+
1, ::cuda::std::min(Nominal4ByteItemsPerThread * 4 / int{sizeof(T)}, Nominal4ByteItemsPerThread * 2));
165+
static constexpr int BLOCK_THREADS =
166+
::cuda::std::min(Nominal4ByteBlockThreads,
167+
::cuda::ceil_div(int{detail::max_smem_per_block} / (int{sizeof(T)} * ITEMS_PER_THREAD), 32) * 32);
167168
};
168169

169170
#endif // Do not document

0 commit comments

Comments
 (0)