From 1e1af8d4d62ed15cc127cdb112d6b506e27668f7 Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Wed, 21 Aug 2024 09:48:56 +0200 Subject: [PATCH] Use/Test radix sort for int128, half, bfloat16 in Thrust (#2168) int128 was already working but not covered by a test. --- thrust/testing/cuda/sort.cu | 59 +++++++++++++++++++++++-- thrust/testing/unittest/meta.h | 15 +++++++ thrust/thrust/system/cuda/detail/sort.h | 11 ++++- 3 files changed, 81 insertions(+), 4 deletions(-) diff --git a/thrust/testing/cuda/sort.cu b/thrust/testing/cuda/sort.cu index 8e7e5542e79..6962a396404 100644 --- a/thrust/testing/cuda/sort.cu +++ b/thrust/testing/cuda/sort.cu @@ -9,6 +9,7 @@ #include +#include #include #include @@ -164,9 +165,23 @@ struct TestRadixSortDispatch void operator()() const {} }; -// TODO(bgruber): use a single test case with a concatenated key list and a cartesion product with the comparators -SimpleUnitTest TestRadixSortDispatchIntegralInstance; -SimpleUnitTest TestRadixSortDispatchFPInstance; +SimpleUnitTest +#endif // _LIBCUDACXX_HAS_NO_INT128 +#ifdef _CCCL_HAS_NVFP16 + , + unittest::type_list<__half> +#endif // _CCCL_HAS_NVFP16 +#ifdef _CCCL_HAS_NVBF16 + , + unittest::type_list<__nv_bfloat16> +#endif // _CCCL_HAS_NVBF16 + >> + TestRadixSortDispatchInstance; /** * Copy of CUB testing utility @@ -263,3 +278,41 @@ void TestSortWithLargeNumberOfItems() TestSortWithMagnitude(33); } DECLARE_UNITTEST(TestSortWithLargeNumberOfItems); + +template +struct TestSortAscendingKey +{ + void operator()() const + { + constexpr int n = 10000; + + thrust::host_vector h_data = unittest::random_integers(n); + thrust::device_vector d_data = h_data; + + std::sort(h_data.begin(), h_data.end(), thrust::less{}); + thrust::sort(d_data.begin(), d_data.end(), thrust::less{}); + + ASSERT_EQUAL_QUIET(h_data, d_data); + } +}; + +SimpleUnitTest +#ifndef _LIBCUDACXX_HAS_NO_INT128 + , + unittest::type_list<__int128_t, __uint128_t> +#endif +// CTK 12.2 offers __host__ __device__ operators for __half and __nv_bfloat16, so we can use std::sort +#if _CCCL_CUDACC_VER >= 1202000 +# if defined(_CCCL_HAS_NVFP16) || !defined(__CUDA_NO_HALF_OPERATORS__) && !defined(__CUDA_NO_HALF_CONVERSIONS__) + , + unittest::type_list<__half> +# endif +# if defined(_CCCL_HAS_NVBF16) \ + || !defined(__CUDA_NO_BFLOAT16_OPERATORS__) && !defined(__CUDA_NO_BFLOAT16_CONVERSIONS__) + , + unittest::type_list<__nv_bfloat16> +# endif +#endif // _CCCL_CUDACC_VER >= 1202000 + >> + TestSortAscendingKeyMoreTypes; diff --git a/thrust/testing/unittest/meta.h b/thrust/testing/unittest/meta.h index 7fd90fa0149..30cb835d55e 100644 --- a/thrust/testing/unittest/meta.h +++ b/thrust/testing/unittest/meta.h @@ -157,4 +157,19 @@ struct transform2, type_list, Template> using type = type_list::type...>; }; +template +struct concat; + +template +struct concat +{ + using type = L; +}; + +template