Skip to content

Commit

Permalink
Replace cub::Traits by numeric_limits and deprecate
Browse files Browse the repository at this point in the history
* Consistently use ::cuda::std::numeric_limits in CUB

Fixes: NVIDIA#3381
  • Loading branch information
bernhardmgruber committed Feb 24, 2025
1 parent eca53a2 commit a15fd10
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 70 deletions.
22 changes: 13 additions & 9 deletions c2h/include/c2h/bfloat16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@ struct bfloat16_t
}
};

#ifdef __GNUC__
# pragma GCC diagnostic pop
#endif

/******************************************************************************
* I/O stream overloads
******************************************************************************/
Expand All @@ -230,18 +234,16 @@ inline std::ostream& operator<<(std::ostream& out, const __nv_bfloat16& x)
}

/******************************************************************************
* Traits overloads
* traits and limits
******************************************************************************/

_LIBCUDACXX_BEGIN_NAMESPACE_STD
template <>
struct __is_extended_floating_point<bfloat16_t> : true_type
{};

#ifndef _CCCL_NO_VARIABLE_TEMPLATES
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<bfloat16_t> = true;
#endif // _CCCL_NO_VARIABLE_TEMPLATES

template <>
class numeric_limits<bfloat16_t>
Expand All @@ -267,10 +269,12 @@ public:
_LIBCUDACXX_END_NAMESPACE_STD

template <>
struct CUB_NS_QUALIFIER::NumericTraits<bfloat16_t>
: CUB_NS_QUALIFIER::BaseTraits<FLOATING_POINT, true, unsigned short, bfloat16_t>
{};
struct CUB_NS_QUALIFIER::detail::unsigned_bits<bfloat16_t, void>
{
using type = unsigned short;
};

#ifdef __GNUC__
# pragma GCC diagnostic pop
#endif
// template <>
// struct CUB_NS_QUALIFIER::detail::NumericTraits<bfloat16_t>
// : CUB_NS_QUALIFIER::detail::BaseTraits<FLOATING_POINT, true, unsigned short, bfloat16_t>
// {};
22 changes: 13 additions & 9 deletions c2h/include/c2h/half.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ struct half_t
}
};

#ifdef __GNUC__
# pragma GCC diagnostic pop
#endif

/******************************************************************************
* I/O stream overloads
******************************************************************************/
Expand All @@ -325,18 +329,16 @@ inline std::ostream& operator<<(std::ostream& out, const __half& x)
}

/******************************************************************************
* Traits overloads
* traits and limits
******************************************************************************/

_LIBCUDACXX_BEGIN_NAMESPACE_STD
template <>
struct __is_extended_floating_point<half_t> : true_type
{};

#ifndef _CCCL_NO_VARIABLE_TEMPLATES
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<half_t> = true;
#endif // _CCCL_NO_VARIABLE_TEMPLATES

template <>
class numeric_limits<half_t>
Expand All @@ -362,10 +364,12 @@ public:
_LIBCUDACXX_END_NAMESPACE_STD

template <>
struct CUB_NS_QUALIFIER::NumericTraits<half_t>
: CUB_NS_QUALIFIER::BaseTraits<FLOATING_POINT, true, unsigned short, half_t>
{};
struct CUB_NS_QUALIFIER::detail::unsigned_bits<half_t, void>
{
using type = unsigned short;
};

#ifdef __GNUC__
# pragma GCC diagnostic pop
#endif
// template <>
// struct CUB_NS_QUALIFIER::detail::NumericTraits<half_t>
// : CUB_NS_QUALIFIER::detail::BaseTraits<FLOATING_POINT, true, unsigned short, half_t>
// {};
2 changes: 1 addition & 1 deletion cub/cub/agent/agent_sub_warp_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class AgentSubWarpSort

_CCCL_DEVICE static bool get_oob_default(::cuda::std::true_type /* is bool */)
{
// Traits<KeyT>::MAX_KEY for `bool` is 0xFF which is different from `true` and makes
// key_traits<KeyT>::max_key for `bool` is 0xFF which is different from `true` and makes
// comparison with oob unreliable.
return !IS_DESCENDING;
}
Expand Down
36 changes: 21 additions & 15 deletions cub/cub/block/radix_rank_sort_operations.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
#include <cuda/std/__algorithm/min.h>
#include <cuda/std/cstdint>
#include <cuda/std/tuple>
#include <cuda/std/type_traits>
#include <cuda/type_traits>
#include <cuda/type_traits>

CUB_NAMESPACE_BEGIN
Expand All @@ -77,8 +77,7 @@ CUB_NAMESPACE_BEGIN
template <typename KeyT, bool IsFP = ::cuda::is_floating_point_v<KeyT>>
struct BaseDigitExtractor
{
using TraitsT = Traits<KeyT>;
using UnsignedBits = typename TraitsT::UnsignedBits;
using UnsignedBits = typename key_traits<KeyT>::unsigned_bits;

static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits ProcessFloatMinusZero(UnsignedBits key)
{
Expand All @@ -89,14 +88,13 @@ struct BaseDigitExtractor
template <typename KeyT>
struct BaseDigitExtractor<KeyT, true>
{
using TraitsT = Traits<KeyT>;
using UnsignedBits = typename TraitsT::UnsignedBits;
using UnsignedBits = typename key_traits<KeyT>::unsigned_bits;

static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits ProcessFloatMinusZero(UnsignedBits key)
{
UnsignedBits TWIDDLED_MINUS_ZERO_BITS =
TraitsT::TwiddleIn(UnsignedBits(1) << UnsignedBits(8 * sizeof(UnsignedBits) - 1));
UnsignedBits TWIDDLED_ZERO_BITS = TraitsT::TwiddleIn(0);
key_traits<KeyT>::twiddle_in(UnsignedBits(1) << UnsignedBits(8 * sizeof(UnsignedBits) - 1));
UnsignedBits TWIDDLED_ZERO_BITS = key_traits<KeyT>::twiddle_in(0);
return key == TWIDDLED_MINUS_ZERO_BITS ? TWIDDLED_ZERO_BITS : key;
}
};
Expand Down Expand Up @@ -209,7 +207,7 @@ struct is_fundamental_type
};

template <class T>
struct is_fundamental_type<T, ::cuda::std::void_t<typename Traits<T>::UnsignedBits>>
struct is_fundamental_type<T, ::cuda::std::void_t<typename key_traits<T>::unsigned_bits>>
{
static constexpr bool value = true;
};
Expand All @@ -233,23 +231,23 @@ using decomposer_check_t = is_tuple_of_references_to_fundamental_types_t<invoke_
template <class T>
struct bit_ordered_conversion_policy_t
{
using bit_ordered_type = typename Traits<T>::UnsignedBits;
using bit_ordered_type = typename key_traits<T>::unsigned_bits;

static _CCCL_HOST_DEVICE bit_ordered_type to_bit_ordered(detail::identity_decomposer_t, bit_ordered_type val)
{
return Traits<T>::TwiddleIn(val);
return key_traits<T>::twiddle_in(val);
}

static _CCCL_HOST_DEVICE bit_ordered_type from_bit_ordered(detail::identity_decomposer_t, bit_ordered_type val)
{
return Traits<T>::TwiddleOut(val);
return key_traits<T>::twiddle_out(val);
}
};

template <class T>
struct bit_ordered_inversion_policy_t
{
using bit_ordered_type = typename Traits<T>::UnsignedBits;
using bit_ordered_type = typename key_traits<T>::unsigned_bits;

static _CCCL_HOST_DEVICE bit_ordered_type inverse(detail::identity_decomposer_t, bit_ordered_type val)
{
Expand All @@ -260,7 +258,7 @@ struct bit_ordered_inversion_policy_t
template <class T, bool = is_fundamental_type<T>::value>
struct traits_t
{
using bit_ordered_type = typename Traits<T>::UnsignedBits;
using bit_ordered_type = typename key_traits<T>::unsigned_bits;
using bit_ordered_conversion_policy = bit_ordered_conversion_policy_t<T>;
using bit_ordered_inversion_policy = bit_ordered_inversion_policy_t<T>;

Expand All @@ -269,12 +267,20 @@ struct traits_t

static _CCCL_HOST_DEVICE bit_ordered_type min_raw_binary_key(detail::identity_decomposer_t)
{
return Traits<T>::LOWEST_KEY;
// TODO(bgruber): sanity check, remove eventually
_CCCL_SUPPRESS_DEPRECATED_PUSH
static_assert(key_traits<T>::lowest_key == Traits<T>::LOWEST_KEY, "");
_CCCL_SUPPRESS_DEPRECATED_POP
return key_traits<T>::lowest_key;
}

static _CCCL_HOST_DEVICE bit_ordered_type max_raw_binary_key(detail::identity_decomposer_t)
{
return Traits<T>::MAX_KEY;
// TODO(bgruber): sanity check, remove eventually
_CCCL_SUPPRESS_DEPRECATED_PUSH
static_assert(key_traits<T>::max_key == Traits<T>::MAX_KEY, "");
_CCCL_SUPPRESS_DEPRECATED_POP
return key_traits<T>::max_key;
}

static _CCCL_HOST_DEVICE int default_end_bit(detail::identity_decomposer_t)
Expand Down
Loading

0 comments on commit a15fd10

Please sign in to comment.