Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes
Browse files Browse the repository at this point in the history
bernhardmgruber committed Jan 30, 2025
1 parent 0b01e7e commit 207b50c
Showing 26 changed files with 143 additions and 90 deletions.
18 changes: 10 additions & 8 deletions c2h/generators.cu
Original file line number Diff line number Diff line change
@@ -44,11 +44,13 @@

#include <cstdint>

#include <c2h/bfloat16.cuh>
#include <c2h/custom_type.h>
#include <c2h/device_policy.h>
#include <c2h/extended_types.h>
#include <c2h/fill_striped.h>
#include <c2h/generators.h>
#include <c2h/half.cuh>
#include <c2h/vector.h>

#if C2H_HAS_CURAND
@@ -455,15 +457,15 @@ template void
init_key_segments(const c2h::device_vector<std::uint32_t>& segment_offsets, float* out, std::size_t element_size);
template void init_key_segments(
const c2h::device_vector<std::uint32_t>& segment_offsets, custom_type_state_t* out, std::size_t element_size);
#ifdef _LIBCUDACXX_HAS_NVFP16
#ifdef TEST_HALF_T
template void
init_key_segments(const c2h::device_vector<std::uint32_t>& segment_offsets, half_t* out, std::size_t element_size);
#endif // _LIBCUDACXX_HAS_NVFP16
#endif // TEST_HALF_T

#ifdef _LIBCUDACXX_HAS_NVBF16
#ifdef TEST_BF_T
template void
init_key_segments(const c2h::device_vector<std::uint32_t>& segment_offsets, bfloat16_t* out, std::size_t element_size);
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // TEST_BF_T
} // namespace detail

template <typename T>
@@ -529,15 +531,15 @@ INSTANTIATE(double);
INSTANTIATE(bool);
INSTANTIATE(char);

#ifdef _CCCL_HAS_NVFP16
#ifdef TEST_HALF_T
INSTANTIATE(half_t);
INSTANTIATE(__half);
#endif // _CCCL_HAS_NVFP16
#endif // TEST_HALF_T

#ifdef _CCCL_HAS_NVBF16
#ifdef TEST_BF_T
INSTANTIATE(bfloat16_t);
INSTANTIATE(__nv_bfloat16);
#endif // _CCCL_HAS_NVBF16
#endif // TEST_BF_T

#undef INSTANTIATE_RND
#undef INSTANTIATE_MOD
27 changes: 1 addition & 26 deletions c2h/include/c2h/test_util_vec.h
Original file line number Diff line number Diff line change
@@ -290,7 +290,7 @@ C2H_VEC_OVERLOAD(ulonglong, unsigned long long)
C2H_VEC_OVERLOAD(float, float)
C2H_VEC_OVERLOAD(double, double)

// Specialize cub::NumericTraits<T> and cuda::std::numeric_limits for vector types.
// Specialize cuda::std::numeric_limits for vector types.

# define REPEAT_TO_LIST_1(a) a
# define REPEAT_TO_LIST_2(a) a, a
@@ -299,31 +299,6 @@ C2H_VEC_OVERLOAD(double, double)
# define REPEAT_TO_LIST(N, a) _CCCL_PP_CAT(REPEAT_TO_LIST_, N)(a)

# define C2H_VEC_TRAITS_OVERLOAD_IMPL(T, BaseT, N) \
CUB_NAMESPACE_BEGIN \
namespace detail \
{ \
template <> \
struct NumericTraits<T> \
{ \
static constexpr Category CATEGORY = NOT_A_NUMBER; \
enum \
{ \
PRIMITIVE = false, \
NULL_TYPE = false, \
}; \
static __host__ __device__ T Max() \
{ \
T retval = {REPEAT_TO_LIST(N, ::cuda::std::numeric_limits<BaseT>::max())}; \
return retval; \
} \
static __host__ __device__ T Lowest() \
{ \
T retval = {REPEAT_TO_LIST(N, ::cuda::std::numeric_limits<BaseT>::lowest())}; \
return retval; \
} \
}; \
} \
CUB_NAMESPACE_END \
_LIBCUDACXX_BEGIN_NAMESPACE_STD \
template <> \
class numeric_limits<T> \
2 changes: 1 addition & 1 deletion cub/cub/agent/agent_reduce.cuh
Original file line number Diff line number Diff line change
@@ -164,7 +164,7 @@ struct AgentReduce
// pointer to a primitive type
static constexpr bool ATTEMPT_VECTORIZATION =
(VECTOR_LOAD_LENGTH > 1) && (ITEMS_PER_THREAD % VECTOR_LOAD_LENGTH == 0)
&& (::cuda::std::is_pointer<InputIteratorT>::value) && Traits<InputT>::PRIMITIVE;
&& (::cuda::std::is_pointer<InputIteratorT>::value) && is_primitive<InputT>::value;

static constexpr CacheLoadModifier LOAD_MODIFIER = AgentReducePolicy::LOAD_MODIFIER;

2 changes: 1 addition & 1 deletion cub/cub/agent/agent_reduce_by_key.cuh
Original file line number Diff line number Diff line change
@@ -228,7 +228,7 @@ struct AgentReduceByKey
// Whether or not the scan operation has a zero-valued identity value (true
// if we're performing addition on a primitive type)
static constexpr int HAS_IDENTITY_ZERO =
(std::is_same<ReductionOpT, ::cuda::std::plus<>>::value) && (Traits<AccumT>::PRIMITIVE);
(std::is_same<ReductionOpT, ::cuda::std::plus<>>::value) && is_primitive<AccumT>::value;

// Cache-modified Input iterator wrapper type (for applying cache modifier)
// for keys Wrap the native input pointer with
2 changes: 1 addition & 1 deletion cub/cub/agent/agent_segment_fixup.cuh
Original file line number Diff line number Diff line change
@@ -171,7 +171,7 @@ struct AgentSegmentFixup

// Whether or not the scan operation has a zero-valued identity value
// (true if we're performing addition on a primitive type)
HAS_IDENTITY_ZERO = (std::is_same<ReductionOpT, ::cuda::std::plus<>>::value) && (Traits<ValueT>::PRIMITIVE),
HAS_IDENTITY_ZERO = (std::is_same<ReductionOpT, ::cuda::std::plus<>>::value) && is_primitive<ValueT>::value,
};

// Cache-modified Input iterator wrapper type (for applying cache modifier) for keys
2 changes: 1 addition & 1 deletion cub/cub/agent/agent_sub_warp_merge_sort.cuh
Original file line number Diff line number Diff line change
@@ -161,7 +161,7 @@ class AgentSubWarpSort

_CCCL_DEVICE static bool get_oob_default(Int2Type<true> /* is bool */)
{
// Traits<KeyT>::MAX_KEY for `bool` is 0xFF which is different from `true` and makes
// key_traits<KeyT>::max_key for `bool` is 0xFF which is different from `true` and makes
// comparison with oob unreliable.
return !IS_DESCENDING;
}
12 changes: 4 additions & 8 deletions cub/cub/agent/single_pass_scan_operators.cuh
Original file line number Diff line number Diff line change
@@ -486,14 +486,14 @@ using default_no_delay_t = default_no_delay_constructor_t::delay_t;

template <class T>
using default_delay_constructor_t =
::cuda::std::_If<Traits<T>::PRIMITIVE, fixed_delay_constructor_t<350, 450>, default_no_delay_constructor_t>;
::cuda::std::_If<is_primitive<T>::value, fixed_delay_constructor_t<350, 450>, default_no_delay_constructor_t>;

template <class T>
using default_delay_t = typename default_delay_constructor_t<T>::delay_t;

template <class KeyT, class ValueT>
using default_reduce_by_key_delay_constructor_t =
::cuda::std::_If<(Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16),
::cuda::std::_If<is_primitive<ValueT>::value && (sizeof(ValueT) + sizeof(KeyT) < 16),
reduce_by_key_delay_constructor_t<350, 450>,
default_delay_constructor_t<KeyValuePair<KeyT, ValueT>>>;

@@ -545,10 +545,8 @@ struct tile_state_with_memory_order
/**
* Tile status interface.
*/
_CCCL_SUPPRESS_DEPRECATED_PUSH
template <typename T, bool SINGLE_WORD = Traits<T>::PRIMITIVE>
template <typename T, bool SINGLE_WORD = detail::is_primitive<T>::value>
struct ScanTileState;
_CCCL_SUPPRESS_DEPRECATED_POP

/**
* Tile status interface specialized for scan status and value types
@@ -950,12 +948,10 @@ struct ScanTileState<T, false>
* Tile status interface for reduction by key.
*
*/
_CCCL_SUPPRESS_DEPRECATED_PUSH
template <typename ValueT,
typename KeyT,
bool SINGLE_WORD = (Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16)>
bool SINGLE_WORD = detail::is_primitive<ValueT>::value && (sizeof(ValueT) + sizeof(KeyT) < 16)>
struct ReduceByKeyScanTileState;
_CCCL_SUPPRESS_DEPRECATED_POP

/**
* Tile status interface for reduction by key, specialized for scan status and value types that
55 changes: 52 additions & 3 deletions cub/cub/block/radix_rank_sort_operations.cuh
Original file line number Diff line number Diff line change
@@ -76,6 +76,7 @@ CUB_NAMESPACE_BEGIN
template <typename KeyT, bool IsFP = ::cuda::is_floating_point_v<KeyT>>
struct BaseDigitExtractor
{
// TODO(bgruber): sanity check, remove eventually
_CCCL_SUPPRESS_DEPRECATED_PUSH
static_assert(Traits<KeyT>::CATEGORY != FLOATING_POINT, "");
_CCCL_SUPPRESS_DEPRECATED_POP
@@ -91,6 +92,7 @@ struct BaseDigitExtractor
template <typename KeyT>
struct BaseDigitExtractor<KeyT, true>
{
// TODO(bgruber): sanity check, remove eventually
_CCCL_SUPPRESS_DEPRECATED_PUSH
static_assert(Traits<KeyT>::CATEGORY == FLOATING_POINT, "");
_CCCL_SUPPRESS_DEPRECATED_POP
@@ -214,7 +216,7 @@ struct is_fundamental_type
};

template <class T>
struct is_fundamental_type<T, ::cuda::std::void_t<typename Traits<T>::UnsignedBits>>
struct is_fundamental_type<T, ::cuda::std::void_t<typename unsigned_bits<T>::type>>
{
static constexpr bool value = true;
};
@@ -262,6 +264,45 @@ struct bit_ordered_inversion_policy_t
}
};

template <typename T, typename SFINAE = void>
struct key_traits;

template <typename T>
struct key_traits<T,
::cuda::std::enable_if_t<(::cuda::std::is_integral<T>::value && ::cuda::std::is_unsigned<T>::value)
# if CUB_IS_INT128_ENABLED
|| ::cuda::std::is_same<T, __uint128_t>::value
# endif
>>
{
using unsigned_bits = unsigned_bits_t<T>;
static constexpr unsigned_bits lowest_key = unsigned_bits(0);
static constexpr unsigned_bits max_key = unsigned_bits(-1);
};

template <typename T>
struct key_traits<T,
::cuda::std::enable_if_t<(::cuda::std::is_integral<T>::value && ::cuda::std::is_signed<T>::value)
# if CUB_IS_INT128_ENABLED
|| ::cuda::std::is_same<T, __int128_t>::value>
# endif
>
{
using unsigned_bits = unsigned_bits_t<T>;
static constexpr unsigned_bits high_bit = unsigned_bits(1) << ((sizeof(unsigned_bits) * CHAR_BIT) - 1);
static constexpr unsigned_bits lowest_key = high_bit;
static constexpr unsigned_bits max_key = unsigned_bits(-1) ^ lowest_key;
};

template <typename T>
struct key_traits<T, ::cuda::std::enable_if_t<::cuda::is_floating_point<T>::value>>
{
using unsigned_bits = unsigned_bits_t<T>;
static constexpr unsigned_bits high_bit = unsigned_bits(1) << ((sizeof(unsigned_bits) * CHAR_BIT) - 1);
static constexpr unsigned_bits lowest_key = unsigned_bits(-1);
static constexpr unsigned_bits max_key = unsigned_bits(-1) ^ high_bit;
};

template <class T, bool = is_fundamental_type<T>::value>
struct traits_t
{
@@ -274,12 +315,20 @@ struct traits_t

static _CCCL_HOST_DEVICE bit_ordered_type min_raw_binary_key(detail::identity_decomposer_t)
{
return Traits<T>::LOWEST_KEY;
// TODO(bgruber): sanity check, remove eventually
_CCCL_SUPPRESS_DEPRECATED_PUSH
static_assert(key_traits<T>::lowest_key == Traits<T>::LOWEST_KEY, "");
_CCCL_SUPPRESS_DEPRECATED_POP
return key_traits<T>::lowest_key;
}

static _CCCL_HOST_DEVICE bit_ordered_type max_raw_binary_key(detail::identity_decomposer_t)
{
return Traits<T>::MAX_KEY;
// TODO(bgruber): sanity check, remove eventually
_CCCL_SUPPRESS_DEPRECATED_PUSH
static_assert(key_traits<T>::max_key == Traits<T>::MAX_KEY, "");
_CCCL_SUPPRESS_DEPRECATED_POP
return key_traits<T>::max_key;
}

static _CCCL_HOST_DEVICE int default_end_bit(detail::identity_decomposer_t)
2 changes: 1 addition & 1 deletion cub/cub/device/dispatch/tuning/tuning_histogram.cuh
Original file line number Diff line number Diff line change
@@ -72,7 +72,7 @@ enum class counter_size
template <class T>
constexpr primitive_sample is_primitive_sample()
{
return Traits<T>::PRIMITIVE ? primitive_sample::yes : primitive_sample::no;
return is_primitive<T>::value ? primitive_sample::yes : primitive_sample::no;
}

template <class CounterT>
4 changes: 2 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh
Original file line number Diff line number Diff line change
@@ -90,13 +90,13 @@ enum class accum_size
template <class T>
constexpr primitive_key is_primitive_key()
{
return Traits<T>::PRIMITIVE ? primitive_key::yes : primitive_key::no;
return is_primitive<T>::value ? primitive_key::yes : primitive_key::no;
}

template <class T>
constexpr primitive_accum is_primitive_accum()
{
return Traits<T>::PRIMITIVE ? primitive_accum::yes : primitive_accum::no;
return is_primitive<T>::value ? primitive_accum::yes : primitive_accum::no;
}

template <class ReductionOpT>
4 changes: 2 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_run_length_encode.cuh
Original file line number Diff line number Diff line change
@@ -82,13 +82,13 @@ enum class length_size
template <class T>
constexpr primitive_key is_primitive_key()
{
return Traits<T>::PRIMITIVE ? primitive_key::yes : primitive_key::no;
return is_primitive<T>::value ? primitive_key::yes : primitive_key::no;
}

template <class T>
constexpr primitive_length is_primitive_length()
{
return Traits<T>::PRIMITIVE ? primitive_length::yes : primitive_length::no;
return is_primitive<T>::value ? primitive_length::yes : primitive_length::no;
}

template <class KeyT>
2 changes: 1 addition & 1 deletion cub/cub/device/dispatch/tuning/tuning_scan.cuh
Original file line number Diff line number Diff line change
@@ -88,7 +88,7 @@ enum class accum_size
template <class AccumT>
constexpr primitive_accum is_primitive_accum()
{
return Traits<AccumT>::PRIMITIVE ? primitive_accum::yes : primitive_accum::no;
return is_primitive<AccumT>::value ? primitive_accum::yes : primitive_accum::no;
}

template <class ScanOpT>
2 changes: 1 addition & 1 deletion cub/cub/device/dispatch/tuning/tuning_scan_by_key.cuh
Original file line number Diff line number Diff line change
@@ -91,7 +91,7 @@ enum class key_size
template <class AccumT>
constexpr primitive_accum is_primitive_accum()
{
return Traits<AccumT>::PRIMITIVE ? primitive_accum::yes : primitive_accum::no;
return is_primitive<AccumT>::value ? primitive_accum::yes : primitive_accum::no;
}

template <class ScanOpT>
2 changes: 1 addition & 1 deletion cub/cub/device/dispatch/tuning/tuning_select_if.cuh
Original file line number Diff line number Diff line change
@@ -514,7 +514,7 @@ struct sm90_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4
template <class InputT>
constexpr primitive is_primitive()
{
return Traits<InputT>::PRIMITIVE ? primitive::yes : primitive::no;
return detail::is_primitive<InputT>::value ? primitive::yes : primitive::no;
}

template <class FlagT>
4 changes: 2 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh
Original file line number Diff line number Diff line change
@@ -85,13 +85,13 @@ enum class val_size
template <class T>
constexpr primitive_key is_primitive_key()
{
return Traits<T>::PRIMITIVE ? primitive_key::yes : primitive_key::no;
return is_primitive<T>::value ? primitive_key::yes : primitive_key::no;
}

template <class T>
constexpr primitive_val is_primitive_val()
{
return Traits<T>::PRIMITIVE ? primitive_val::yes : primitive_val::no;
return is_primitive<T>::value ? primitive_val::yes : primitive_val::no;
}

template <class KeyT>
2 changes: 1 addition & 1 deletion cub/cub/thread/thread_load.cuh
Original file line number Diff line number Diff line change
@@ -355,7 +355,7 @@ template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE T
ThreadLoad(const T* ptr, Int2Type<LOAD_VOLATILE> /*modifier*/, Int2Type<true> /*is_pointer*/)
{
return ThreadLoadVolatilePointer(ptr, Int2Type<Traits<T>::PRIMITIVE>());
return ThreadLoadVolatilePointer(ptr, Int2Type<detail::is_primitive<T>::value>{});
}

/**
2 changes: 1 addition & 1 deletion cub/cub/thread/thread_store.cuh
Original file line number Diff line number Diff line change
@@ -321,7 +321,7 @@ template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStore(T* ptr, T val, Int2Type<STORE_VOLATILE> /*modifier*/, Int2Type<true> /*is_pointer*/)
{
ThreadStoreVolatilePtr(ptr, val, Int2Type<Traits<T>::PRIMITIVE>());
ThreadStoreVolatilePtr(ptr, val, Int2Type<detail::is_primitive<T>::value>{});
}

/**
Loading

0 comments on commit 207b50c

Please sign in to comment.