Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e17f497

Browse files
committedFeb 15, 2025··
Replace cub::Traits by numeric_limits and deprecate
* Consistently use ::cuda::std::numeric_limits in CUB Fixes: #3381
1 parent 98c9205 commit e17f497

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+449
-433
lines changed
 

‎c2h/generators.cu

+4-25
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,17 @@
4040
#include <thrust/scan.h>
4141
#include <thrust/tabulate.h>
4242

43-
#include <cuda/std/type_traits>
43+
#include <cuda/type_traits>
4444

4545
#include <cstdint>
4646

47+
#include <c2h/bfloat16.cuh>
4748
#include <c2h/custom_type.h>
4849
#include <c2h/device_policy.h>
4950
#include <c2h/extended_types.h>
5051
#include <c2h/fill_striped.h>
5152
#include <c2h/generators.h>
53+
#include <c2h/half.cuh>
5254
#include <c2h/vector.h>
5355

5456
#if C2H_HAS_CURAND
@@ -118,30 +120,7 @@ private:
118120
c2h::device_vector<float> m_distribution;
119121
};
120122

121-
// TODO(bgruber): modelled after cub::Traits. We should generalize this somewhere into libcu++.
122-
template <typename T>
123-
struct is_floating_point : ::cuda::std::is_floating_point<T>
124-
{};
125-
#if _CCCL_HAS_NVFP16()
126-
template <>
127-
struct is_floating_point<__half> : ::cuda::std::true_type
128-
{};
129-
#endif // _CCCL_HAS_NVFP16()
130-
#if _CCCL_HAS_NVBF16()
131-
template <>
132-
struct is_floating_point<__nv_bfloat16> : ::cuda::std::true_type
133-
{};
134-
#endif // _CCCL_HAS_NVBF16()
135-
#if _CCCL_HAS_NVFP8()
136-
template <>
137-
struct is_floating_point<__nv_fp8_e4m3> : ::cuda::std::true_type
138-
{};
139-
template <>
140-
struct is_floating_point<__nv_fp8_e5m2> : ::cuda::std::true_type
141-
{};
142-
#endif // _CCCL_HAS_NVFP8()
143-
144-
template <typename T, bool = is_floating_point<T>::value>
123+
template <typename T, bool = ::cuda::is_floating_point_v<T>>
145124
struct random_to_item_t
146125
{
147126
float m_min;

‎c2h/include/c2h/bfloat16.cuh

+15-12
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ struct bfloat16_t
212212
}
213213
};
214214

215+
#ifdef __GNUC__
216+
# pragma GCC diagnostic pop
217+
#endif
218+
215219
/******************************************************************************
216220
* I/O stream overloads
217221
******************************************************************************/
@@ -230,18 +234,17 @@ inline std::ostream& operator<<(std::ostream& out, const __nv_bfloat16& x)
230234
}
231235

232236
/******************************************************************************
233-
* Traits overloads
237+
* traits and limits
234238
******************************************************************************/
235239

236240
_LIBCUDACXX_BEGIN_NAMESPACE_STD
237241
template <>
238242
struct __is_extended_floating_point<bfloat16_t> : true_type
239243
{};
240-
241-
#ifndef _CCCL_NO_VARIABLE_TEMPLATES
244+
#ifndef _CCCL_NO_INLINE_VARIABLES
242245
template <>
243246
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<bfloat16_t> = true;
244-
#endif // _CCCL_NO_VARIABLE_TEMPLATES
247+
#endif // _CCCL_NO_INLINE_VARIABLES
245248

246249
template <>
247250
class __numeric_limits_impl<bfloat16_t, __numeric_limits_type::__floating_point>
@@ -264,13 +267,13 @@ public:
264267
};
265268
_LIBCUDACXX_END_NAMESPACE_STD
266269

267-
_CCCL_SUPPRESS_DEPRECATED_PUSH
268270
template <>
269-
struct CUB_NS_QUALIFIER::NumericTraits<bfloat16_t>
270-
: CUB_NS_QUALIFIER::BaseTraits<FLOATING_POINT, unsigned short, bfloat16_t>
271-
{};
272-
_CCCL_SUPPRESS_DEPRECATED_POP
271+
struct CUB_NS_QUALIFIER::detail::unsigned_bits<bfloat16_t, void>
272+
{
273+
using type = unsigned short;
274+
};
273275

274-
#ifdef __GNUC__
275-
# pragma GCC diagnostic pop
276-
#endif
276+
// template <>
277+
// struct CUB_NS_QUALIFIER::detail::NumericTraits<bfloat16_t>
278+
// : CUB_NS_QUALIFIER::detail::BaseTraits<FLOATING_POINT, unsigned short, bfloat16_t>
279+
// {};

0 commit comments

Comments
 (0)
Please sign in to comment.