Skip to content

Commit

Permalink
Replace CUB iterators by Thrust ones (#3480)
Browse files Browse the repository at this point in the history
Also consider thrust::discard_iterator's value_type void

Fixes: #3261
  • Loading branch information
bernhardmgruber authored Feb 5, 2025
1 parent 0a578d5 commit f4545e0
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 52 deletions.
9 changes: 3 additions & 6 deletions cub/cub/device/device_run_length_encode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
#include <cub/device/dispatch/dispatch_reduce_by_key.cuh>
#include <cub/device/dispatch/dispatch_rle.cuh>
#include <cub/device/dispatch/tuning/tuning_run_length_encode.cuh>
#include <cub/iterator/constant_input_iterator.cuh>

#include <thrust/iterator/constant_iterator.h>

#include <iterator>

Expand Down Expand Up @@ -200,17 +201,14 @@ struct DeviceRunLengthEncode
using length_t = cub::detail::non_void_value_t<LengthsOutputIteratorT, offset_t>;

// Generator type for providing 1s values for run-length reduction
_CCCL_SUPPRESS_DEPRECATED_PUSH
using lengths_input_iterator_t = ConstantInputIterator<length_t, offset_t>;
_CCCL_SUPPRESS_DEPRECATED_POP
using lengths_input_iterator_t = THRUST_NS_QUALIFIER::constant_iterator<length_t, offset_t>;

using accum_t = ::cuda::std::__accumulator_t<reduction_op, length_t, length_t>;

using key_t = cub::detail::non_void_value_t<UniqueOutputIteratorT, cub::detail::value_t<InputIteratorT>>;

using policy_t = detail::rle::encode::policy_hub<accum_t, key_t>;

_CCCL_SUPPRESS_DEPRECATED_PUSH
return DispatchReduceByKey<
InputIteratorT,
UniqueOutputIteratorT,
Expand All @@ -232,7 +230,6 @@ struct DeviceRunLengthEncode
reduction_op(),
num_items,
stream);
_CCCL_SUPPRESS_DEPRECATED_POP
}

//! @rst
Expand Down
17 changes: 2 additions & 15 deletions cub/cub/device/dispatch/dispatch_streaming_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

#include <cub/device/dispatch/dispatch_reduce.cuh>
#include <cub/iterator/arg_index_input_iterator.cuh>
#include <cub/iterator/constant_input_iterator.cuh>

#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/tabulate_output_iterator.h>

Expand All @@ -25,8 +25,6 @@

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document

// suppress deprecation warnings for ConstantInputIterator
_CCCL_SUPPRESS_DEPRECATED_PUSH
CUB_NAMESPACE_BEGIN

namespace detail::reduce
Expand Down Expand Up @@ -189,12 +187,6 @@ template <typename InputIteratorT,
detail::reduce::policy_hub<KeyValuePair<PerPartitionOffsetT, InitT>, PerPartitionOffsetT, ReductionOpT>>
struct dispatch_streaming_arg_reduce_t
{
# if _CCCL_COMPILER(NVHPC)
// NVHPC fails to suppress a deprecation when the alias is inside the function below, so we put it here and span a
// deprecation suppression region across the entire file as well
using constant_offset_it_t = ConstantInputIterator<GlobalOffsetT>;
# endif // _CCCL_COMPILER(NVHPC)

// Internal dispatch routine for computing a device-wide argument extremum, like `ArgMin` and `ArgMax`
//
// @param[in] d_temp_storage
Expand Down Expand Up @@ -234,11 +226,7 @@ struct dispatch_streaming_arg_reduce_t
cudaStream_t stream)
{
// Constant iterator to provide the offset of the current partition for the user-provided input iterator
# if !_CCCL_COMPILER(NVHPC)
_CCCL_SUPPRESS_DEPRECATED_PUSH
using constant_offset_it_t = ConstantInputIterator<GlobalOffsetT>;
_CCCL_SUPPRESS_DEPRECATED_POP
# endif
using constant_offset_it_t = THRUST_NS_QUALIFIER::constant_iterator<GlobalOffsetT>;

// Wrapped input iterator to produce index-value tuples, i.e., <PerPartitionOffsetT, InputT>-tuples
// We make sure to offset the user-provided input iterator by the current partition's offset
Expand Down Expand Up @@ -382,7 +370,6 @@ struct dispatch_streaming_arg_reduce_t
};

} // namespace detail::reduce
_CCCL_SUPPRESS_DEPRECATED_POP
CUB_NAMESPACE_END

#endif // !_CCCL_DOXYGEN_INVOKED
10 changes: 9 additions & 1 deletion cub/cub/util_type.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

#include <cub/detail/uninitialized_copy.cuh>

#include <thrust/iterator/discard_iterator.h>

#include <cuda/std/cstdint>
#include <cuda/std/limits>
#include <cuda/std/type_traits>
Expand Down Expand Up @@ -107,7 +109,13 @@ struct non_void_value_impl
template <typename It, typename FallbackT>
struct non_void_value_impl<It, FallbackT, false>
{
using type = ::cuda::std::_If<::cuda::std::is_void<value_t<It>>::value, FallbackT, value_t<It>>;
// We treat thrust::discard_iterator's value_type as `void` as well, so users can switch from
// cub::DiscardOutputIterator to thrust::discard_iterator.
using type =
::cuda::std::_If<::cuda::std::is_void<value_t<It>>::value
|| ::cuda::std::is_same<value_t<It>, THRUST_NS_QUALIFIER::discard_iterator<>::value_type>::value,
FallbackT,
value_t<It>>;
};

/**
Expand Down
4 changes: 0 additions & 4 deletions cub/test/catch2_test_device_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@
#include <c2h/custom_type.h>
#include <c2h/extended_types.h>

// need to suppress deprecation warnings for ConstantInputIterator in the cudafe1.stub.c file, so there is no matching
// _CCCL_SUPPRESS_DEPRECATED_POP at the end of this file
_CCCL_SUPPRESS_DEPRECATED_PUSH

DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::Reduce, device_reduce);
DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::Sum, device_sum);
DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::Min, device_min);
Expand Down
3 changes: 0 additions & 3 deletions cub/test/catch2_test_device_reduce_fp_inf.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::ArgMin, device_arg_min_old);
DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::ArgMax, device_arg_max_old);
_CCCL_SUPPRESS_DEPRECATED_POP

// suppress deprecation of ConstantInputIterator in cudafe1.stub.c file
_CCCL_SUPPRESS_DEPRECATED_PUSH

// %PARAM% TEST_LAUNCH lid 0:1

C2H_TEST("Device reduce arg{min,max} works with inf items", "[reduce][device]")
Expand Down
3 changes: 0 additions & 3 deletions cub/test/catch2_test_device_reduce_large_offsets.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::ArgMin, device_arg_min);
DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::Max, device_max);
DECLARE_LAUNCH_WRAPPER(cub::DeviceReduce::ArgMax, device_arg_max);

// suppress deprecation of ConstantInputIterator in cudafe1.stub.c file
_CCCL_SUPPRESS_DEPRECATED_PUSH

// %PARAM% TEST_LAUNCH lid 0:1:2

// List of offset types to test
Expand Down
14 changes: 0 additions & 14 deletions cub/test/catch2_test_device_run_length_encode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@
*
******************************************************************************/

#include <cuda/__cccl_config>

#if _CCCL_COMPILER(NVHPC)
// to suppress warnings for CountingInputIterator
_CCCL_SUPPRESS_DEPRECATED_PUSH
#endif // _CCCL_COMPILER(NVHPC)

#include "insert_nested_NVTX_range_guard.h"
// above header needs to be included first

Expand All @@ -50,9 +43,6 @@ _CCCL_SUPPRESS_DEPRECATED_PUSH

DECLARE_LAUNCH_WRAPPER(cub::DeviceRunLengthEncode::Encode, run_length_encode);

// suppress deprecation of ConstantInputIterator in cudafe1.stub.c file
_CCCL_SUPPRESS_DEPRECATED_PUSH

// %PARAM% TEST_LAUNCH lid 0:1:2

using all_types =
Expand Down Expand Up @@ -274,7 +264,3 @@ C2H_TEST("DeviceRunLengthEncode::Encode can handle leading NaN", "[device][run_l
REQUIRE(out_counts == reference_counts);
REQUIRE(out_num_runs == reference_num_runs);
}

#if _CCCL_COMPILER(NVHPC)
_CCCL_SUPPRESS_DEPRECATED_POP
#endif // _CCCL_COMPILER(NVHPC)
11 changes: 5 additions & 6 deletions cub/test/catch2_test_util_type.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@
*
******************************************************************************/

#include <cub/iterator/counting_input_iterator.cuh>
#include <cub/iterator/discard_output_iterator.cuh>
#include <cub/util_type.cuh>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>

#include <cuda/std/type_traits>

#include <c2h/catch2_test_helper.h>
#include <c2h/extended_types.h>

C2H_TEST("Tests non_void_value_t", "[util][type]")
{
_CCCL_SUPPRESS_DEPRECATED_PUSH
using fallback_t = float;
using void_fancy_it = cub::DiscardOutputIterator<std::size_t>;
using non_void_fancy_it = cub::CountingInputIterator<int>;
using void_fancy_it = thrust::discard_iterator<std::size_t>;
using non_void_fancy_it = thrust::counting_iterator<int>;

// falls back for const void*
STATIC_REQUIRE(::cuda::std::is_same<fallback_t, //
Expand All @@ -62,7 +62,6 @@ C2H_TEST("Tests non_void_value_t", "[util][type]")
// works for a fancy iterator that has int as value type
STATIC_REQUIRE(::cuda::std::is_same<int, //
cub::detail::non_void_value_t<non_void_fancy_it, fallback_t>>::value);
_CCCL_SUPPRESS_DEPRECATED_POP
}

CUB_DEFINE_DETECT_NESTED_TYPE(cat_detect, cat);
Expand Down

0 comments on commit f4545e0

Please sign in to comment.