Skip to content

Commit daab0a4

Browse files
Cleanup util_arch (#2773)
1 parent 4f2efaf commit daab0a4

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

cub/cub/util_arch.cuh

+16 -15
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,10 @@
4747
#include <cub/util_macro.cuh>
4848
#include <cub/util_namespace.cuh>
4949

50+
#include <cuda/cmath>
51+
#include <cuda/std/__algorithm/max.h>
52+
#include <cuda/std/__algorithm/min.h>
53+
5054
// Legacy include; this functionality used to be defined in here.
5155
#include <cub/detail/detect_cuda_runtime.cuh>
5256

@@ -113,27 +117,24 @@ namespace detail
113117
static constexpr ::cuda::std::size_t max_smem_per_block = 48 * 1024;
114118
} // namespace detail
115119

// Diff hunk for RegBoundScaling. Gutter entries ending in `-` mark lines removed by this
// commit (the anonymous-enum definition), entries ending in `+` mark the added replacement
// (static constexpr int members), and two fused numbers mark unchanged context lines.
// The template parameters are also renamed from NOMINAL_4B_* to Nominal4Byte*.
116-
template <int NOMINAL_4B_BLOCK_THREADS, int NOMINAL_4B_ITEMS_PER_THREAD, typename T>
120+
template <int Nominal4ByteBlockThreads, int Nominal4ByteItemsPerThread, typename T>
117121
struct RegBoundScaling
118122
{
119-
enum
120-
{
121-
ITEMS_PER_THREAD = CUB_MAX(1, NOMINAL_4B_ITEMS_PER_THREAD * 4 / CUB_MAX(4, sizeof(T))),
122-
BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS,
123-
((cub::detail::max_smem_per_block / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32),
124-
};
// Added form of ITEMS_PER_THREAD: the nominal items-per-thread value times 4, divided by
// max(4, sizeof(T)), then clamped to at least 1. Because the divisor never drops below 4,
// the result never exceeds the nominal value.
// int{sizeof(T)} keeps the whole expression in int; sizeof yields a size_t.
123+
static constexpr int ITEMS_PER_THREAD =
124+
::cuda::std::max(1, Nominal4ByteItemsPerThread * 4 / ::cuda::std::max(4, int{sizeof(T)}));
// Added form of BLOCK_THREADS: the smaller of the nominal block-thread count and
// max_smem_per_block / (sizeof(T) * ITEMS_PER_THREAD) rounded up to a multiple of 32.
// ::cuda::ceil_div(x, 32) * 32 stands in for the removed `(x + 31) / 32 * 32` expression.
125+
static constexpr int BLOCK_THREADS =
126+
::cuda::std::min(Nominal4ByteBlockThreads,
127+
::cuda::ceil_div(int{detail::max_smem_per_block} / (int{sizeof(T)} * ITEMS_PER_THREAD), 32) * 32);
125128
};
126129

// Diff hunk for MemBoundScaling. Gutter entries ending in `-` mark lines removed by this
// commit (the anonymous-enum definition), entries ending in `+` mark the added replacement
// (static constexpr int members), and two fused numbers mark unchanged context lines.
// The template parameters are also renamed from NOMINAL_4B_* to Nominal4Byte*.
127-
template <int NOMINAL_4B_BLOCK_THREADS, int NOMINAL_4B_ITEMS_PER_THREAD, typename T>
130+
template <int Nominal4ByteBlockThreads, int Nominal4ByteItemsPerThread, typename T>
128131
struct MemBoundScaling
129132
{
130-
enum
131-
{
132-
ITEMS_PER_THREAD =
133-
CUB_MAX(1, CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(T), NOMINAL_4B_ITEMS_PER_THREAD * 2)),
134-
BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS,
135-
((cub::detail::max_smem_per_block / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32),
136-
};
// Added form of ITEMS_PER_THREAD: the nominal items-per-thread value times 4, divided by
// sizeof(T), capped at twice the nominal value, then clamped to at least 1.
// Unlike RegBoundScaling the divisor is not floored at 4, so types narrower than 4 bytes
// can get more items per thread than nominal (up to the 2x cap).
// int{sizeof(T)} keeps the whole expression in int; sizeof yields a size_t.
133+
static constexpr int ITEMS_PER_THREAD = ::cuda::std::max(
134+
1, ::cuda::std::min(Nominal4ByteItemsPerThread * 4 / int{sizeof(T)}, Nominal4ByteItemsPerThread * 2));
// Added form of BLOCK_THREADS: the smaller of the nominal block-thread count and
// max_smem_per_block / (sizeof(T) * ITEMS_PER_THREAD) rounded up to a multiple of 32.
// ::cuda::ceil_div(x, 32) * 32 stands in for the removed `(x + 31) / 32 * 32` expression.
135+
static constexpr int BLOCK_THREADS =
136+
::cuda::std::min(Nominal4ByteBlockThreads,
137+
::cuda::ceil_div(int{detail::max_smem_per_block} / (int{sizeof(T)} * ITEMS_PER_THREAD), 32) * 32);
137138
};
138139

139140
#endif // Do not document

0 commit comments

Comments
 (0)