Skip to content

Commit

Permalink
Fix transform_iterator<identity> and drop result_of_adaptable_function (
Browse files Browse the repository at this point in the history
#3652)

Thrust's transform_iterator relied on thrust::identity<T>::result_type via result_of_adaptable_function to avoid dangling references. However, thrust::identity<void>, cuda::std::identity, and the placeholder expression _1 have the same issue without a workaround. This change centralizes and adds workarounds for all of these by introducing a new trait thrust::detail::transform_iterator_reference, which decays the return value type of any of the aforementioned function objects.
  • Loading branch information
bernhardmgruber committed Feb 5, 2025
1 parent 99d48d3 commit 72223bd
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 95 deletions.
24 changes: 13 additions & 11 deletions thrust/testing/transform_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,9 @@ void TestTransformIteratorReferenceAndValueType()
static_assert(is_same<decltype(it_tr_tid)::value_type, bool>::value, "");
(void) it_tr_tid;

auto it_tr_cid = thrust::make_transform_iterator(it, cuda::std::__identity{});
static_assert(is_same<decltype(it_tr_cid)::reference, bool&&>::value, ""); // inferred, like forward
auto it_tr_cid = thrust::make_transform_iterator(it, cuda::std::identity{});
static_assert(is_same<decltype(it_tr_cid)::reference, bool>::value, ""); // special handling by
// transform_iterator_reference
static_assert(is_same<decltype(it_tr_cid)::value_type, bool>::value, "");
(void) it_tr_cid;
}
Expand Down Expand Up @@ -196,8 +197,9 @@ void TestTransformIteratorReferenceAndValueType()
static_assert(is_same<decltype(it_tr_tid)::value_type, bool>::value, "");
(void) it_tr_tid;

auto it_tr_cid = thrust::make_transform_iterator(it, cuda::std::__identity{});
static_assert(is_same<decltype(it_tr_cid)::reference, bool&&>::value, ""); // inferred, like forward
auto it_tr_cid = thrust::make_transform_iterator(it, cuda::std::identity{});
static_assert(is_same<decltype(it_tr_cid)::reference, bool>::value, ""); // special handling by
// transform_iterator_reference
static_assert(is_same<decltype(it_tr_cid)::value_type, bool>::value, "");
(void) it_tr_cid;
}
Expand Down Expand Up @@ -234,8 +236,9 @@ void TestTransformIteratorReferenceAndValueType()
static_assert(is_same<decltype(it_tr_tid)::value_type, bool>::value, "");
(void) it_tr_tid;

auto it_tr_cid = thrust::make_transform_iterator(it, cuda::std::__identity{});
static_assert(is_same<decltype(it_tr_cid)::reference, bool&&>::value, ""); // inferred, like forward
auto it_tr_cid = thrust::make_transform_iterator(it, cuda::std::identity{});
static_assert(is_same<decltype(it_tr_cid)::reference, bool>::value, ""); // special handling by
// transform_iterator_reference
static_assert(is_same<decltype(it_tr_cid)::value_type, bool>::value, "");
(void) it_tr_cid;
}
Expand All @@ -247,11 +250,10 @@ void TestTransformIteratorIdentity()
thrust::device_vector<int> v(3, 42);

ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), thrust::identity<int>{}), 42);
// FIXME(bgruber): fix transform_iterator to get these tests compiling:
// ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), thrust::identity<>{}), 42);
// ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), cuda::std::identity{}), 42);
// using namespace thrust::placeholders;
// ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), _1), 42);
ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), thrust::identity<>{}), 42);
ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), cuda::std::identity{}), 42);
using namespace thrust::placeholders;
ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), _1), 42);
}

DECLARE_UNITTEST(TestTransformIteratorIdentity);
7 changes: 0 additions & 7 deletions thrust/thrust/detail/functional/actor.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
# pragma system_header
#endif // no system header
#include <thrust/detail/type_deduction.h>
#include <thrust/detail/type_traits/result_of_adaptable_function.h>
#include <thrust/tuple.h>

#include <cuda/std/type_traits>
Expand Down Expand Up @@ -211,11 +210,5 @@ _CCCL_HOST_DEVICE auto compose(Eval e, const SubExpr1& subexpr1, const SubExpr2&
{{::cuda::std::move(e)}, make_actor(subexpr1), make_actor(subexpr2)}};
}
} // namespace functional

template <typename Eval, typename... Args>
struct result_of_adaptable_function<functional::actor<Eval>(Args...)>
{
using type = decltype(::cuda::std::declval<functional::actor<Eval>>()(::cuda::std::declval<Args>()...));
};
} // namespace detail
THRUST_NAMESPACE_END
1 change: 0 additions & 1 deletion thrust/thrust/detail/functional/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
#include <thrust/detail/type_traits/result_of_adaptable_function.h>
#include <thrust/functional.h>
#include <thrust/tuple.h>

Expand Down
67 changes: 0 additions & 67 deletions thrust/thrust/detail/type_traits/result_of_adaptable_function.h

This file was deleted.

44 changes: 35 additions & 9 deletions thrust/thrust/iterator/detail/transform_iterator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
#include <thrust/detail/type_traits/result_of_adaptable_function.h>
#include <thrust/detail/functional/actor.h>
#include <thrust/functional.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_traits.h>

#include <cuda/std/functional>
#include <cuda/std/type_traits>

THRUST_NAMESPACE_BEGIN
Expand All @@ -39,19 +41,43 @@ class transform_iterator;
namespace detail
{

template <class UnaryFunc, class Iterator>
struct transform_iterator_reference
{
// by default, dereferencing the iterator yields the same as the function.
using type = decltype(::cuda::std::declval<UnaryFunc>()(::cuda::std::declval<iterator_value_t<Iterator>>()));
};

// for certain function objects, we need to tweak the reference type. Notably, identity functions must decay to values.
// See the implementation of transform_iterator<...>::dereference() for several comments on why this is necessary.
template <typename T, class Iterator>
struct transform_iterator_reference<identity<T>, Iterator>
{
using type = T;
};
template <class Iterator>
struct transform_iterator_reference<identity<>, Iterator>
{
using type = iterator_value_t<Iterator>;
};
template <class Iterator>
struct transform_iterator_reference<::cuda::std::identity, Iterator>
{
using type = iterator_value_t<Iterator>;
};
template <typename Eval, class Iterator>
struct transform_iterator_reference<functional::actor<Eval>, Iterator>
{
using type = ::cuda::std::remove_reference_t<decltype(::cuda::std::declval<functional::actor<Eval>>()(
::cuda::std::declval<iterator_value_t<Iterator>>()))>;
};

// Type function to compute the iterator_adaptor instantiation to be used for transform_iterator
template <class UnaryFunc, class Iterator, class Reference, class Value>
struct make_transform_iterator_base
{
private:
// FIXME(bgruber): the next line should be correct, but thrust::identity<T> lies and advertises a ::return_type of T,
// while its operator() returns const T& (which __invoke_of correctly detects), which causes transform_iterator to
// crash (or cause UB) during dereferencing. Check the test `thrust.test.dereference` for the OMP and TBB backends.
// using wrapped_func_ret_t = ::cuda::std::__invoke_of<UnaryFunc, iterator_value_t<Iterator>>;
using wrapped_func_ret_t = result_of_adaptable_function<UnaryFunc(iterator_value_t<Iterator>)>;

// By default, dereferencing the iterator yields the same as the function.
using reference = typename ia_dflt_help<Reference, wrapped_func_ret_t>::type;
using reference = typename ia_dflt_help<Reference, transform_iterator_reference<UnaryFunc, Iterator>>::type;
using value_type = typename ia_dflt_help<Value, ::cuda::std::remove_cvref<reference>>::type;

public:
Expand Down

0 comments on commit 72223bd

Please sign in to comment.