Skip to content

Commit

Permalink
Use/Test radix sort for int128, half, bfloat16 in Thrust (#2168)
Browse files Browse the repository at this point in the history
int128 was already working but not covered by a test.
  • Loading branch information
bernhardmgruber authored Aug 21, 2024
1 parent 06e334f commit 1e1af8d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 4 deletions.
59 changes: 56 additions & 3 deletions thrust/testing/cuda/sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <cuda/std/limits>

#include <algorithm>
#include <cstdint>
#include <exception>

Expand Down Expand Up @@ -164,9 +165,23 @@ struct TestRadixSortDispatch

void operator()() const {}
};
// TODO(bgruber): use a single test case with a concatenated key list and a cartesion product with the comparators
SimpleUnitTest<TestRadixSortDispatch, IntegralTypes> TestRadixSortDispatchIntegralInstance;
SimpleUnitTest<TestRadixSortDispatch, FloatingPointTypes> TestRadixSortDispatchFPInstance;
SimpleUnitTest<TestRadixSortDispatch,
unittest::concat<IntegralTypes,
FloatingPointTypes
#ifndef _LIBCUDACXX_HAS_NO_INT128
,
unittest::type_list<__int128_t, __uint128_t>
#endif // _LIBCUDACXX_HAS_NO_INT128
#ifdef _CCCL_HAS_NVFP16
,
unittest::type_list<__half>
#endif // _CCCL_HAS_NVFP16
#ifdef _CCCL_HAS_NVBF16
,
unittest::type_list<__nv_bfloat16>
#endif // _CCCL_HAS_NVBF16
>>
TestRadixSortDispatchInstance;

/**
* Copy of CUB testing utility
Expand Down Expand Up @@ -263,3 +278,41 @@ void TestSortWithLargeNumberOfItems()
TestSortWithMagnitude(33);
}
DECLARE_UNITTEST(TestSortWithLargeNumberOfItems);

template <typename T>
struct TestSortAscendingKey
{
void operator()() const
{
constexpr int n = 10000;

thrust::host_vector<T> h_data = unittest::random_integers<T>(n);
thrust::device_vector<T> d_data = h_data;

std::sort(h_data.begin(), h_data.end(), thrust::less<T>{});
thrust::sort(d_data.begin(), d_data.end(), thrust::less<T>{});

ASSERT_EQUAL_QUIET(h_data, d_data);
}
};

SimpleUnitTest<TestSortAscendingKey,
unittest::concat<unittest::type_list<>
#ifndef _LIBCUDACXX_HAS_NO_INT128
,
unittest::type_list<__int128_t, __uint128_t>
#endif
// CTK 12.2 offers __host__ __device__ operators for __half and __nv_bfloat16, so we can use std::sort
#if _CCCL_CUDACC_VER >= 1202000
# if defined(_CCCL_HAS_NVFP16) || !defined(__CUDA_NO_HALF_OPERATORS__) && !defined(__CUDA_NO_HALF_CONVERSIONS__)
,
unittest::type_list<__half>
# endif
# if defined(_CCCL_HAS_NVBF16) \
|| !defined(__CUDA_NO_BFLOAT16_OPERATORS__) && !defined(__CUDA_NO_BFLOAT16_CONVERSIONS__)
,
unittest::type_list<__nv_bfloat16>
# endif
#endif // _CCCL_CUDACC_VER >= 1202000
>>
TestSortAscendingKeyMoreTypes;
15 changes: 15 additions & 0 deletions thrust/testing/unittest/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,19 @@ struct transform2<type_list<T1s...>, type_list<T2s...>, Template>
using type = type_list<typename ApplyTemplate2<Template, T1s, T2s>::type...>;
};

template <typename... Ls>
struct concat;

template <typename L>
struct concat<L>
{
using type = L;
};

template <template <typename...> class L, typename... T1s, typename... T2s, typename... Ls>
struct concat<L<T1s...>, L<T2s...>, Ls...>
{
using type = concat<L<T1s..., T2s...>, Ls...>;
};

} // namespace unittest
11 changes: 10 additions & 1 deletion thrust/thrust/system/cuda/detail/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,16 @@ namespace __smart_sort
template <class Key, class CompareOp>
using can_use_primitive_sort = ::cuda::std::integral_constant<
bool,
::cuda::std::is_arithmetic<Key>::value
(::cuda::std::is_arithmetic<Key>::value
# if defined(_CCCL_HAS_NVFP16) && !defined(__CUDA_NO_HALF_OPERATORS__) && !defined(__CUDA_NO_HALF_CONVERSIONS__)
|| ::cuda::std::is_same<Key, __half>::value
# endif // defined(_CCCL_HAS_NVFP16) && !defined(__CUDA_NO_HALF_OPERATORS__) && !defined(__CUDA_NO_HALF_CONVERSIONS__)
# if defined(_CCCL_HAS_NVBF16) && !defined(__CUDA_NO_BFLOAT16_CONVERSIONS__) \
&& !defined(__CUDA_NO_BFLOAT16_OPERATORS__)
|| ::cuda::std::is_same<Key, __nv_bfloat16>::value
# endif // defined(_CCCL_HAS_NVBF16) && !defined(__CUDA_NO_BFLOAT16_CONVERSIONS__) &&
// !defined(__CUDA_NO_BFLOAT16_OPERATORS__)
)
&& (::cuda::std::is_same<CompareOp, thrust::less<Key>>::value
|| ::cuda::std::is_same<CompareOp, ::cuda::std::less<Key>>::value
|| ::cuda::std::is_same<CompareOp, thrust::less<void>>::value
Expand Down

0 comments on commit 1e1af8d

Please sign in to comment.