Skip to content

Commit fb83b4a

Browse files
Several improvements to zip_iterator/zip_function (#1710)
* Improve zip_iterator documentation * Re-expose zip_iterator's IteratorTuple * Test zip_iterator construction from iterator tuple * Expose zip_function's underlying function * Add default ctor for zip_function * Simplify zip_function documentation example
1 parent 3104dd0 commit fb83b4a

File tree

4 files changed

+56
-54
lines changed

4 files changed

+56
-54
lines changed

thrust/testing/zip_function.cu

+14
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ struct SumThreeTuple
2828
THRUST_DECLTYPE_RETURNS(thrust::get<0>(x) + thrust::get<1>(x) + thrust::get<2>(x))
2929
}; // end SumThreeTuple
3030

31+
template <typename T>
32+
struct TestZipFunctionCtor
33+
{
34+
void operator()()
35+
{
36+
ASSERT_EQUAL(thrust::zip_function<SumThree>()(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
37+
ASSERT_EQUAL(thrust::zip_function<SumThree>(SumThree{})(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
38+
# ifdef __cpp_deduction_guides
39+
ASSERT_EQUAL(thrust::zip_function(SumThree{})(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
40+
# endif // __cpp_deduction_guides
41+
}
42+
};
43+
SimpleUnitTest<TestZipFunctionCtor, type_list<int>> TestZipFunctionCtorInstance;
44+
3145
template <typename T>
3246
struct TestZipFunctionTransform
3347
{

thrust/testing/zip_iterator.cu

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct TestZipIteratorManipulation
3535

3636
// test construction
3737
ZipIterator iter0 = make_zip_iterator(t);
38+
ASSERT_EQUAL(true, iter0 == ZipIterator{t});
3839

3940
ASSERT_EQUAL_QUIET(v0.begin(), get<0>(iter0.get_iterator_tuple()));
4041
ASSERT_EQUAL_QUIET(v1.begin(), get<1>(iter0.get_iterator_tuple()));

thrust/thrust/iterator/zip_iterator.h

+16-24
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,20 @@ THRUST_NAMESPACE_BEGIN
6969
* #include <thrust/tuple.h>
7070
* #include <thrust/device_vector.h>
7171
* ...
72-
* thrust::device_vector<int> int_v(3);
73-
* int_v[0] = 0; int_v[1] = 1; int_v[2] = 2;
72+
* thrust::device_vector<int> int_v{0, 1, 2};
73+
* thrust::device_vector<float> float_v{0.0f, 1.0f, 2.0f};
74+
* thrust::device_vector<char> char_v{'a', 'b', 'c'};
7475
*
75-
* thrust::device_vector<float> float_v(3);
76-
* float_v[0] = 0.0f; float_v[1] = 1.0f; float_v[2] = 2.0f;
76+
* // aliases for iterators
77+
* using IntIterator = thrust::device_vector<int>::iterator;
78+
* using FloatIterator = thrust::device_vector<float>::iterator;
79+
* using CharIterator = thrust::device_vector<char>::iterator;
7780
*
78-
* thrust::device_vector<char> char_v(3);
79-
* char_v[0] = 'a'; char_v[1] = 'b'; char_v[2] = 'c';
80-
*
81-
* // typedef these iterators for shorthand
82-
* typedef thrust::device_vector<int>::iterator IntIterator;
83-
* typedef thrust::device_vector<float>::iterator FloatIterator;
84-
* typedef thrust::device_vector<char>::iterator CharIterator;
85-
*
86-
* // typedef a tuple of these iterators
87-
* typedef thrust::tuple<IntIterator, FloatIterator, CharIterator> IteratorTuple;
81+
* // alias for a tuple of these iterators
82+
* using IteratorTuple = thrust::tuple<IntIterator, FloatIterator, CharIterator>;
8883
*
8984
* // typedef the zip_iterator of this tuple
90-
* typedef thrust::zip_iterator<IteratorTuple> ZipIterator;
85+
* using ZipIterator = thrust::zip_iterator<IteratorTuple>;
9186
*
9287
* // finally, create the zip_iterator
9388
* ZipIterator iter(thrust::make_tuple(int_v.begin(), float_v.begin(), char_v.begin()));
@@ -116,15 +111,8 @@ THRUST_NAMESPACE_BEGIN
116111
*
117112
* int main()
118113
* {
119-
* thrust::device_vector<int> int_in(3), int_out(3);
120-
* int_in[0] = 0;
121-
* int_in[1] = 1;
122-
* int_in[2] = 2;
123-
*
124-
* thrust::device_vector<float> float_in(3), float_out(3);
125-
* float_in[0] = 0.0f;
126-
* float_in[1] = 10.0f;
127-
* float_in[2] = 20.0f;
114+
* thrust::device_vector<int> int_in{0, 1, 2}, int_out(3);
115+
* thrust::device_vector<float> float_in{0.0f, 10.0f, 20.0f}, float_out(3);
128116
*
129117
* thrust::copy(thrust::make_zip_iterator(thrust::make_tuple(int_in.begin(), float_in.begin())),
130118
* thrust::make_zip_iterator(thrust::make_tuple(int_in.end(), float_in.end())),
@@ -146,6 +134,10 @@ template <typename IteratorTuple>
146134
class zip_iterator : public detail::zip_iterator_base<IteratorTuple>::type
147135
{
148136
public:
137+
/*! The underlying iterator tuple type. Alias to zip_iterator's first template argument.
138+
*/
139+
using iterator_tuple = IteratorTuple;
140+
149141
/*! Default constructor does nothing.
150142
*/
151143
#if defined(_CCCL_COMPILER_MSVC_2017)

thrust/thrust/zip_function.h

+25-30
Original file line numberDiff line numberDiff line change
@@ -95,54 +95,40 @@ _CCCL_HOST_DEVICE auto apply_impl(Function&& func, Tuple&& args, index_sequence<
9595
* #include <thrust/zip_function.h>
9696
*
9797
* struct SumTuple {
98-
* float operator()(Tuple tup) {
99-
* return std::get<0>(tup) + std::get<1>(tup) + std::get<2>(tup);
98+
* float operator()(auto tup) const {
99+
* return thrust::get<0>(tup) + thrust::get<1>(tup) + thrust::get<2>(tup);
100100
* }
101101
* };
102102
* struct SumArgs {
103-
* float operator()(float a, float b, float c) {
103+
* float operator()(float a, float b, float c) const {
104104
* return a + b + c;
105105
* }
106106
* };
107107
*
108108
* int main() {
109-
* thrust::device_vector<float> A(3);
110-
* thrust::device_vector<float> B(3);
111-
* thrust::device_vector<float> C(3);
109+
* thrust::device_vector<float> A{0.f, 1.f, 2.f};
110+
* thrust::device_vector<float> B{1.f, 2.f, 3.f};
111+
* thrust::device_vector<float> C{2.f, 3.f, 4.f};
112112
* thrust::device_vector<float> D(3);
113-
* A[0] = 0.f; A[1] = 1.f; A[2] = 2.f;
114-
* B[0] = 1.f; B[1] = 2.f; B[2] = 3.f;
115-
* C[0] = 2.f; C[1] = 3.f; C[2] = 4.f;
116113
*
117-
* // The following four invocations of transform are equivalent
114+
* auto begin = thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin()));
115+
* auto end = thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end()));
116+
*
117+
* // The following four invocations of transform are equivalent:
118118
* // Transform with 3-tuple
119-
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
120-
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
121-
* D.begin(),
122-
* SumTuple{});
119+
* thrust::transform(begin, end, D.begin(), SumTuple{});
123120
*
124121
* // Transform with 3 parameters
125122
* thrust::zip_function<SumArgs> adapted{};
126-
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
127-
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
128-
* D.begin(),
129-
* adapted);
123+
* thrust::transform(begin, end, D.begin(), adapted);
130124
*
131125
* // Transform with 3 parameters with convenience function
132-
* thrust::zip_function<SumArgs> adapted{};
133-
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
134-
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
135-
* D.begin(),
136-
* thrust::make_zip_function(SumArgs{}));
126+
* thrust::transform(begin, end, D.begin(), thrust::make_zip_function(SumArgs{}));
137127
*
138128
* // Transform with 3 parameters with convenience function and lambda
139-
* thrust::zip_function<SumArgs> adapted{};
140-
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
141-
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
142-
* D.begin(),
143-
* thrust::make_zip_function([] (float a, float b, float c) {
144-
* return a + b + c;
145-
* }));
129+
* thrust::transform(begin, end, D.begin(), thrust::make_zip_function([] (float a, float b, float c) {
130+
* return a + b + c;
131+
* }));
146132
* return 0;
147133
* }
148134
* \endcode
@@ -154,6 +140,9 @@ template <typename Function>
154140
class zip_function
155141
{
156142
public:
143+
//! Default constructs the contained function object.
144+
zip_function() = default;
145+
157146
_CCCL_HOST_DEVICE zip_function(Function func)
158147
: func(std::move(func))
159148
{}
@@ -181,6 +170,12 @@ class zip_function
181170

182171
# endif // _CCCL_STD_VER
183172

173+
//! Returns a reference to the underlying function.
174+
_CCCL_HOST_DEVICE Function& underlying_function() const
175+
{
176+
return func;
177+
}
178+
184179
private:
185180
mutable Function func;
186181
};

0 commit comments

Comments
 (0)