Skip to content

Commit a1d8b31

Browse files
Expose thrust's contiguous iterator unwrap helpers (#1717)
Lifts the following functions into the public API and renames them: * contiguous_iterator_raw_pointer_t -> unwrap_contiguous_iterator_t * contiguous_iterator_raw_pointer_cast -> unwrap_contiguous_iterator * try_unwrap_contiguous_iterator_return_t -> try_unwrap_contiguous_iterator_t * try_unwrap_contiguous_iterator Fixes: #1711 Co-authored-by: Michael Schellenberger Costa <[email protected]>
1 parent fb83b4a commit a1d8b31

File tree

4 files changed

+40
-35
lines changed

4 files changed

+40
-35
lines changed

thrust/testing/is_contiguous_iterator.cu

+2-2
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ struct expect_passthrough
8989
template <typename IteratorT, typename PointerT, typename expected_unwrapped_type /* = expect_[pointer|passthrough] */>
9090
struct check_unwrapped_iterator
9191
{
92-
using unwrapped_t = typename std::remove_reference<decltype(thrust::detail::try_unwrap_contiguous_iterator(
93-
std::declval<IteratorT>()))>::type;
92+
using unwrapped_t = ::cuda::std::__libcpp_remove_reference_t<decltype(thrust::try_unwrap_contiguous_iterator(
93+
cuda::std::declval<IteratorT>()))>;
9494

9595
static constexpr bool value =
9696
std::is_same<expected_unwrapped_type, expect_pointer>::value

thrust/thrust/system/cuda/detail/adjacent_difference.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -151,17 +151,17 @@ adjacent_difference(execution_policy<Derived>& policy, InputIt first, InputIt la
151151
std::size_t storage_size = 0;
152152
cudaStream_t stream = cuda_cub::stream(policy);
153153

154-
using UnwrapInputIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<InputIt>;
155-
using UnwrapOutputIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<OutputIt>;
154+
using UnwrapInputIt = thrust::try_unwrap_contiguous_iterator_t<InputIt>;
155+
using UnwrapOutputIt = thrust::try_unwrap_contiguous_iterator_t<OutputIt>;
156156

157157
using InputValueT = thrust::iterator_value_t<UnwrapInputIt>;
158158
using OutputValueT = thrust::iterator_value_t<UnwrapOutputIt>;
159159

160160
constexpr bool can_compare_iterators = std::is_pointer<UnwrapInputIt>::value && std::is_pointer<UnwrapOutputIt>::value
161161
&& std::is_same<InputValueT, OutputValueT>::value;
162162

163-
auto first_unwrap = thrust::detail::try_unwrap_contiguous_iterator(first);
164-
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
163+
auto first_unwrap = thrust::try_unwrap_contiguous_iterator(first);
164+
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);
165165

166166
thrust::detail::integral_constant<bool, can_compare_iterators> comparable;
167167

thrust/thrust/system/cuda/detail/scan_by_key.h

+12-12
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ _CCCL_HOST_DEVICE ValuesOutIt inclusive_scan_by_key_n(
8585
}
8686

8787
// Convert to raw pointers if possible:
88-
using KeysInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<KeysInIt>;
89-
using ValuesInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesInIt>;
90-
using ValuesOutUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesOutIt>;
88+
using KeysInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<KeysInIt>;
89+
using ValuesInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesInIt>;
90+
using ValuesOutUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesOutIt>;
9191
using AccumT = typename thrust::iterator_traits<ValuesInUnwrapIt>::value_type;
9292

93-
auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys);
94-
auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values);
95-
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
93+
auto keys_unwrap = thrust::try_unwrap_contiguous_iterator(keys);
94+
auto values_unwrap = thrust::try_unwrap_contiguous_iterator(values);
95+
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);
9696

9797
using Dispatch32 = cub::DispatchScanByKey<
9898
KeysInUnwrapIt,
@@ -195,13 +195,13 @@ _CCCL_HOST_DEVICE ValuesOutIt exclusive_scan_by_key_n(
195195
}
196196

197197
// Convert to raw pointers if possible:
198-
using KeysInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<KeysInIt>;
199-
using ValuesInUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesInIt>;
200-
using ValuesOutUnwrapIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesOutIt>;
198+
using KeysInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<KeysInIt>;
199+
using ValuesInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesInIt>;
200+
using ValuesOutUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesOutIt>;
201201

202-
auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys);
203-
auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values);
204-
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
202+
auto keys_unwrap = thrust::try_unwrap_contiguous_iterator(keys);
203+
auto values_unwrap = thrust::try_unwrap_contiguous_iterator(values);
204+
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);
205205

206206
using Dispatch32 = cub::DispatchScanByKey<
207207
KeysInUnwrapIt,

thrust/thrust/type_traits/is_contiguous_iterator.h

+22-17
Original file line numberDiff line numberDiff line change
@@ -221,20 +221,24 @@ struct contiguous_iterator_traits
221221

222222
using raw_pointer = typename thrust::detail::pointer_traits<decltype(&*std::declval<Iterator>())>::raw_pointer;
223223
};
224+
} // namespace detail
224225

225-
template <typename Iterator>
226-
using contiguous_iterator_raw_pointer_t = typename contiguous_iterator_traits<Iterator>::raw_pointer;
226+
//! Converts a contiguous iterator type to its underlying raw pointer type.
227+
template <typename ContiguousIterator>
228+
using unwrap_contiguous_iterator_t = typename detail::contiguous_iterator_traits<ContiguousIterator>::raw_pointer;
227229

228-
// Converts a contiguous iterator to a raw pointer:
229-
template <typename Iterator>
230-
_CCCL_HOST_DEVICE contiguous_iterator_raw_pointer_t<Iterator> contiguous_iterator_raw_pointer_cast(Iterator it)
230+
//! Converts a contiguous iterator to its underlying raw pointer.
231+
template <typename ContiguousIterator>
232+
_CCCL_HOST_DEVICE auto unwrap_contiguous_iterator(ContiguousIterator it)
233+
-> unwrap_contiguous_iterator_t<ContiguousIterator>
231234
{
232-
static_assert(thrust::is_contiguous_iterator<Iterator>::value,
233-
"contiguous_iterator_raw_pointer_cast called with "
234-
"non-contiguous iterator.");
235+
static_assert(thrust::is_contiguous_iterator<ContiguousIterator>::value,
236+
"unwrap_contiguous_iterator called with non-contiguous iterator.");
235237
return thrust::raw_pointer_cast(&*it);
236238
}
237239

240+
namespace detail
241+
{
238242
// Implementation for non-contiguous iterators -- passthrough.
239243
template <typename Iterator, bool IsContiguous = thrust::is_contiguous_iterator<Iterator>::value>
240244
struct try_unwrap_contiguous_iterator_impl
@@ -251,27 +255,28 @@ struct try_unwrap_contiguous_iterator_impl
251255
template <typename Iterator>
252256
struct try_unwrap_contiguous_iterator_impl<Iterator, true /*is_contiguous*/>
253257
{
254-
using type = contiguous_iterator_raw_pointer_t<Iterator>;
258+
using type = unwrap_contiguous_iterator_t<Iterator>;
255259

256260
static _CCCL_HOST_DEVICE type get(Iterator it)
257261
{
258-
return contiguous_iterator_raw_pointer_cast(it);
262+
return unwrap_contiguous_iterator(it);
259263
}
260264
};
265+
} // namespace detail
261266

267+
//! Takes an iterator type and, if it is contiguous, yields the raw pointer type it represents. Otherwise returns the
268+
//! iterator type unmodified.
262269
template <typename Iterator>
263-
using try_unwrap_contiguous_iterator_return_t = typename try_unwrap_contiguous_iterator_impl<Iterator>::type;
270+
using try_unwrap_contiguous_iterator_t = typename detail::try_unwrap_contiguous_iterator_impl<Iterator>::type;
264271

265-
// Casts to a raw pointer if iterator is marked as contiguous, otherwise returns
266-
// the input iterator.
272+
//! Takes an iterator and, if it is contiguous, unwraps it to the raw pointer it represents. Otherwise returns the
273+
//! iterator unmodified.
267274
template <typename Iterator>
268-
_CCCL_HOST_DEVICE try_unwrap_contiguous_iterator_return_t<Iterator> try_unwrap_contiguous_iterator(Iterator it)
275+
_CCCL_HOST_DEVICE auto try_unwrap_contiguous_iterator(Iterator it) -> try_unwrap_contiguous_iterator_t<Iterator>
269276
{
270-
return try_unwrap_contiguous_iterator_impl<Iterator>::get(it);
277+
return detail::try_unwrap_contiguous_iterator_impl<Iterator>::get(it);
271278
}
272279

273-
} // namespace detail
274-
275280
/*! \endcond
276281
*/
277282

0 commit comments

Comments
 (0)