Skip to content

Commit

Permalink
Make thrust iterators work with NVRTC (#3676)
Browse files Browse the repository at this point in the history
* Make thrust iterators work with NVRTC

As a drive-by, all iterator tags used in CUB and Thrust are replaced with ones from libcu++.

Co-authored-by: Michael Schellenberger Costa <[email protected]>
  • Loading branch information
bernhardmgruber and miscco authored Feb 5, 2025
1 parent 14614dd commit 3786a08
Show file tree
Hide file tree
Showing 33 changed files with 129 additions and 105 deletions.
11 changes: 11 additions & 0 deletions cub/test/catch2_test_nvrtc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ TEST_CASE("Test nvrtc", "[test][nvrtc]")
#include <cub/device/dispatch/kernels/scan.cuh>
#include <cub/device/dispatch/kernels/merge_sort.cuh>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/reverse_iterator.h>
#include <thrust/iterator/tabulate_output_iterator.h>
#include <thrust/iterator/transform_input_output_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/iterator/zip_iterator.h>
extern "C" __global__ void kernel(int *ptr, int *errors)
{
constexpr int items_per_thread = 4;
Expand Down
2 changes: 1 addition & 1 deletion thrust/testing/counting_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ _CCCL_DIAG_SUPPRESS_MSVC(4244 4267) // possible loss of data
void test_iterator_traits()
{
using It = cuda::std::iterator_traits<thrust::counting_iterator<int>>;
using category = thrust::detail::iterator_category_with_system_and_traversal<std::random_access_iterator_tag,
using category = thrust::detail::iterator_category_with_system_and_traversal<::cuda::std::random_access_iterator_tag,
thrust::any_system_tag,
thrust::random_access_traversal_tag>;

Expand Down
2 changes: 1 addition & 1 deletion thrust/testing/cuda/adjacent_difference.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct detect_wrong_difference
using value_type = void;
using pointer = void;
using reference = void;
using iterator_category = std::output_iterator_tag;
using iterator_category = ::cuda::std::output_iterator_tag;

bool* flag;

Expand Down
7 changes: 3 additions & 4 deletions thrust/thrust/detail/alignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@
#endif // no system header

#include <cuda/cmath>
#include <cuda/std/cstddef> // For `::cuda::std::size_t` and `::cuda::std::max_align_t`.
#include <cuda/std/type_traits>

#include <cstddef> // For `std::size_t` and `std::max_align_t`.

THRUST_NAMESPACE_BEGIN
namespace detail
{
Expand All @@ -49,7 +48,7 @@ using alignment_of = ::cuda::std::alignment_of<T>;
/// type whose alignment requirement is a divisor of `Align`.
///
/// The behavior is undefined if `Align` is not a power of 2.
template <std::size_t Align>
template <::cuda::std::size_t Align>
struct aligned_type
{
struct alignas(Align) type
Expand All @@ -74,7 +73,7 @@ _CCCL_HOST_DEVICE T aligned_reinterpret_cast(U u)
return reinterpret_cast<T>(reinterpret_cast<void*>(u));
}

_CCCL_HOST_DEVICE inline std::size_t aligned_storage_size(std::size_t n, std::size_t align)
_CCCL_HOST_DEVICE inline ::cuda::std::size_t aligned_storage_size(::cuda::std::size_t n, ::cuda::std::size_t align)
{
return ::cuda::ceil_div(n, align) * align;
}
Expand Down
20 changes: 13 additions & 7 deletions thrust/thrust/detail/allocator_aware_execution_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <thrust/detail/alignment.h>
#include <thrust/detail/execute_with_allocator_fwd.h>

#include <type_traits>
#include <cuda/std/type_traits>

THRUST_NAMESPACE_BEGIN

Expand Down Expand Up @@ -60,31 +60,37 @@ struct allocator_aware_execution_policy
using type = thrust::detail::execute_with_allocator<Allocator, ExecutionPolicyCRTPBase>;
};

_CCCL_EXEC_CHECK_DISABLE
template <typename MemoryResource>
typename execute_with_memory_resource_type<MemoryResource>::type operator()(MemoryResource* mem_res) const
_CCCL_HOST_DEVICE typename execute_with_memory_resource_type<MemoryResource>::type
operator()(MemoryResource* mem_res) const
{
return typename execute_with_memory_resource_type<MemoryResource>::type(mem_res);
}

_CCCL_EXEC_CHECK_DISABLE
template <typename Allocator>
typename execute_with_allocator_type<Allocator&>::type operator()(Allocator& alloc) const
_CCCL_HOST_DEVICE typename execute_with_allocator_type<Allocator&>::type operator()(Allocator& alloc) const
{
return typename execute_with_allocator_type<Allocator&>::type(alloc);
}

_CCCL_EXEC_CHECK_DISABLE
template <typename Allocator>
typename execute_with_allocator_type<Allocator>::type operator()(const Allocator& alloc) const
_CCCL_HOST_DEVICE typename execute_with_allocator_type<Allocator>::type operator()(const Allocator& alloc) const
{
return typename execute_with_allocator_type<Allocator>::type(alloc);
}

// Rvalue-only overload.
// Perfect forwarding would not help here: a const reference still has to be
// copied into a value so that it can be stored in execute_with_allocator.
template <typename Allocator, typename std::enable_if<!std::is_lvalue_reference<Allocator>::value>::type* = nullptr>
typename execute_with_allocator_type<Allocator>::type operator()(Allocator&& alloc) const
_CCCL_EXEC_CHECK_DISABLE
template <typename Allocator,
typename ::cuda::std::enable_if<!::cuda::std::is_lvalue_reference<Allocator>::value>::type* = nullptr>
_CCCL_HOST_DEVICE typename execute_with_allocator_type<Allocator>::type operator()(Allocator&& alloc) const
{
return typename execute_with_allocator_type<Allocator>::type(std::move(alloc));
return typename execute_with_allocator_type<Allocator>::type(::cuda::std::move(alloc));
}
};

Expand Down
8 changes: 6 additions & 2 deletions thrust/thrust/detail/execute_with_allocator_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
# pragma system_header
#endif // no system header

#include <thrust/detail/execute_with_dependencies.h>
#include <thrust/detail/type_traits.h>
#if !_CCCL_COMPILER(NVRTC)
# include <thrust/detail/execute_with_dependencies.h>
#endif // !_CCCL_COMPILER(NVRTC)

THRUST_NAMESPACE_BEGIN

Expand All @@ -53,11 +55,12 @@ _CCCL_SUPPRESS_DEPRECATED_PUSH // because of execute_with_allocator_and_dependen
: alloc(alloc_)
{}

::cuda::std::remove_reference_t<Allocator>& get_allocator()
_CCCL_HOST_DEVICE ::cuda::std::remove_reference_t<Allocator>& get_allocator()
{
return alloc;
}

#if !_CCCL_COMPILER(NVRTC)
template <typename... Dependencies>
CCCL_DEPRECATED _CCCL_HOST execute_with_allocator_and_dependencies<Allocator, BaseSystem, Dependencies...>
after(Dependencies&&... dependencies) const
Expand Down Expand Up @@ -97,6 +100,7 @@ _CCCL_SUPPRESS_DEPRECATED_PUSH // because of execute_with_allocator_and_dependen
{
return {alloc, capture_as_dependency(std::move(dependencies))};
}
#endif // !_CCCL_COMPILER(NVRTC)
};

_CCCL_SUPPRESS_DEPRECATED_POP
Expand Down
14 changes: 0 additions & 14 deletions thrust/thrust/detail/tuple_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,6 @@ struct tuple_transform_functor;
template <typename Tuple, template <typename> class UnaryMetaFunction, typename UnaryFunction, size_t... Is>
struct tuple_transform_functor<Tuple, UnaryMetaFunction, UnaryFunction, thrust::index_sequence<Is...>>
{
static _CCCL_HOST typename tuple_meta_transform<Tuple, UnaryMetaFunction>::type
do_it_on_the_host(const Tuple& t, UnaryFunction f)
{
using XfrmTuple = typename tuple_meta_transform<Tuple, UnaryMetaFunction>::type;

return XfrmTuple(f(thrust::get<Is>(t))...);
}

static _CCCL_HOST_DEVICE typename tuple_meta_transform<Tuple, UnaryMetaFunction>::type
do_it_on_the_host_or_device(const Tuple& t, UnaryFunction f)
{
Expand All @@ -60,12 +52,6 @@ struct tuple_transform_functor<Tuple, UnaryMetaFunction, UnaryFunction, thrust::
}
};

template <template <typename> class UnaryMetaFunction, typename Tuple, typename UnaryFunction>
typename tuple_meta_transform<Tuple, UnaryMetaFunction>::type tuple_host_transform(const Tuple& t, UnaryFunction f)
{
return tuple_transform_functor<Tuple, UnaryMetaFunction, UnaryFunction>::do_it_on_the_host(t, f);
}

template <template <typename> class UnaryMetaFunction, typename Tuple, typename UnaryFunction>
typename tuple_meta_transform<Tuple, UnaryMetaFunction>::type _CCCL_HOST_DEVICE
tuple_host_device_transform(const Tuple& t, UnaryFunction f)
Expand Down
23 changes: 12 additions & 11 deletions thrust/thrust/detail/type_deduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <thrust/detail/preprocessor.h>

#include <type_traits>
#include <utility>
#include <cuda/std/type_traits>
#include <cuda/std/utility>

///////////////////////////////////////////////////////////////////////////////

/// \def THRUST_FWD(x)
/// \brief Performs universal forwarding of a universal reference.
///
#define THRUST_FWD(x) ::std::forward<decltype(x)>(x)
#define THRUST_FWD(x) ::cuda::std::forward<decltype(x)>(x)

/// \def THRUST_MVCAP(x)
/// \brief Capture `x` into a lambda by moving.
/// deprecated [Since 2.8]
#define THRUST_MVCAP(x) x = ::std::move(x)
#define THRUST_MVCAP(x) x = ::cuda::std::move(x)

/// \def THRUST_RETOF(invocable, ...)
/// \brief Expands to the type returned by invoking an instance of the invocable
Expand All @@ -40,9 +41,9 @@
/// deprecated [Since 2.8]
#define THRUST_RETOF(...) THRUST_PP_DISPATCH(THRUST_RETOF, __VA_ARGS__)
/// deprecated [Since 2.8]
#define THRUST_RETOF1(C) decltype(::std::declval<C>()())
#define THRUST_RETOF1(C) decltype(::cuda::std::declval<C>()())
/// deprecated [Since 2.8]
#define THRUST_RETOF2(C, V) decltype(::std::declval<C>()(::std::declval<V>()))
#define THRUST_RETOF2(C, V) decltype(::cuda::std::declval<C>()(::cuda::std::declval<V>()))

/// \def THRUST_RETURNS(...)
/// \brief Expands to a function definition that returns the expression
Expand Down Expand Up @@ -91,11 +92,11 @@
/**/
#else
/// deprecated [Since 2.8]
# define THRUST_DECLTYPE_RETURNS_WITH_SFINAE_CONDITION(condition, ...) \
noexcept(noexcept(__VA_ARGS__))->typename std::enable_if<condition, decltype(__VA_ARGS__)>::type \
{ \
return (__VA_ARGS__); \
} \
# define THRUST_DECLTYPE_RETURNS_WITH_SFINAE_CONDITION(condition, ...) \
noexcept(noexcept(__VA_ARGS__))->typename ::cuda::std::enable_if<condition, decltype(__VA_ARGS__)>::type \
{ \
return (__VA_ARGS__); \
} \
/**/
#endif

Expand Down
2 changes: 1 addition & 1 deletion thrust/thrust/detail/type_traits/minimum_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct primitive_minimum_type<T, T>
struct any_conversion
{
template <typename T>
operator T();
_CCCL_HOST_DEVICE operator T();
};

} // namespace minimum_type_detail
Expand Down
25 changes: 13 additions & 12 deletions thrust/thrust/detail/type_traits/pointer_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
#include <thrust/detail/type_traits/is_thrust_pointer.h>
#include <thrust/iterator/iterator_traits.h>

#include <cstddef>
#include <type_traits>
#include <cuda/std/cstddef>
#include <cuda/std/type_traits>

THRUST_NAMESPACE_BEGIN
namespace detail
Expand Down Expand Up @@ -62,7 +62,7 @@ struct pointer_difference
template <typename T>
struct pointer_difference<T*>
{
using type = std::ptrdiff_t;
using type = ::cuda::std::ptrdiff_t;
};

template <typename Ptr, typename T>
Expand Down Expand Up @@ -91,7 +91,7 @@ template <template <typename, typename, typename, typename...> class Ptr,
typename T>
struct rebind_pointer<Ptr<OldT, Tag, Ref<OldT, RefTail...>, PtrTail...>, T>
{
// static_assert(std::is_same<OldT, Tag>::value, "0");
// static_assert(::cuda::std::is_same<OldT, Tag>::value, "0");
using type = Ptr<T, Tag, Ref<T, RefTail...>, PtrTail...>;
};

Expand All @@ -107,7 +107,7 @@ template <template <typename, typename, typename, typename...> class Ptr,
typename T>
struct rebind_pointer<Ptr<OldT, Tag, Ref<OldT, RefTail...>, DerivedPtr<OldT, DerivedPtrTail...>>, T>
{
// static_assert(std::is_same<OldT, Tag>::value, "1");
// static_assert(::cuda::std::is_same<OldT, Tag>::value, "1");
using type = Ptr<T, Tag, Ref<T, RefTail...>, DerivedPtr<T, DerivedPtrTail...>>;
};

Expand All @@ -117,10 +117,10 @@ template <template <typename, typename, typename, typename...> class Ptr,
typename Tag,
typename... PtrTail,
typename T>
struct rebind_pointer<Ptr<OldT, Tag, typename std::add_lvalue_reference<OldT>::type, PtrTail...>, T>
struct rebind_pointer<Ptr<OldT, Tag, typename ::cuda::std::add_lvalue_reference<OldT>::type, PtrTail...>, T>
{
// static_assert(std::is_same<OldT, Tag>::value, "2");
using type = Ptr<T, Tag, typename std::add_lvalue_reference<T>::type, PtrTail...>;
// static_assert(::cuda::std::is_same<OldT, Tag>::value, "2");
using type = Ptr<T, Tag, typename ::cuda::std::add_lvalue_reference<T>::type, PtrTail...>;
};

// Rebind `thrust::pointer`-like things with native reference types and templated
Expand All @@ -131,11 +131,12 @@ template <template <typename, typename, typename, typename...> class Ptr,
template <typename...> class DerivedPtr,
typename... DerivedPtrTail,
typename T>
struct rebind_pointer<Ptr<OldT, Tag, typename std::add_lvalue_reference<OldT>::type, DerivedPtr<OldT, DerivedPtrTail...>>,
T>
struct rebind_pointer<
Ptr<OldT, Tag, typename ::cuda::std::add_lvalue_reference<OldT>::type, DerivedPtr<OldT, DerivedPtrTail...>>,
T>
{
// static_assert(std::is_same<OldT, Tag>::value, "3");
using type = Ptr<T, Tag, typename std::add_lvalue_reference<T>::type, DerivedPtr<T, DerivedPtrTail...>>;
// static_assert(::cuda::std::is_same<OldT, Tag>::value, "3");
using type = Ptr<T, Tag, typename ::cuda::std::add_lvalue_reference<T>::type, DerivedPtr<T, DerivedPtrTail...>>;
};

namespace pointer_traits_detail
Expand Down
2 changes: 0 additions & 2 deletions thrust/thrust/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
#include <cuda/functional>
#include <cuda/std/functional>

#include <functional>

THRUST_NAMESPACE_BEGIN

/*! \addtogroup predefined_function_objects Predefined Function Objects
Expand Down
2 changes: 1 addition & 1 deletion thrust/thrust/iterator/constant_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class constant_iterator : public detail::constant_iterator_base<Value, Increment
{
/*! \cond
*/
friend class thrust::iterator_core_access;
friend class iterator_core_access;
using super_t = typename detail::constant_iterator_base<Value, Incrementable, System>::type;
using incrementable = typename detail::constant_iterator_base<Value, Incrementable, System>::incrementable;
using base_iterator = typename detail::constant_iterator_base<Value, Incrementable, System>::base_iterator;
Expand Down
3 changes: 2 additions & 1 deletion thrust/thrust/iterator/counting_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_categories.h>
#include <thrust/iterator/iterator_facade.h>
Expand Down Expand Up @@ -142,7 +143,7 @@ class _CCCL_DECLSPEC_EMPTY_BASES counting_iterator
*/
using super_t = typename detail::counting_iterator_base<Incrementable, System, Traversal, Difference>::type;

friend class thrust::iterator_core_access;
friend class iterator_core_access;

public:
using reference = typename super_t::reference;
Expand Down
2 changes: 1 addition & 1 deletion thrust/thrust/iterator/detail/any_system_tag.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct any_system_tag : thrust::execution_policy<any_system_tag>
// allow any_system_tag to convert to any type at all
// XXX make this safer using enable_if<is_tag<T>> upon c++11
template <typename T>
operator T() const
_CCCL_HOST_DEVICE operator T() const
{
return T();
}
Expand Down
4 changes: 2 additions & 2 deletions thrust/thrust/iterator/detail/discard_iterator_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include <thrust/iterator/detail/any_assign.h>
#include <thrust/iterator/iterator_adaptor.h>

#include <cstddef> // for std::ptrdiff_t
#include <cuda/std/cstddef> // for ::cuda::std::ptrdiff_t

THRUST_NAMESPACE_BEGIN

Expand All @@ -47,7 +47,7 @@ struct discard_iterator_base
// but this interferes with zip_iterator<discard_iterator>
using value_type = any_assign;
using reference = any_assign&;
using incrementable = std::ptrdiff_t;
using incrementable = ::cuda::std::ptrdiff_t;

using base_iterator = typename thrust::counting_iterator<incrementable, System, thrust::random_access_traversal_tag>;

Expand Down
Loading

0 comments on commit 3786a08

Please sign in to comment.