Skip to content

Commit b47cddc

Browse files
BeanavilNB4444
authored andcommitted
Several improvements to zip_iterator/zip_function
1 parent b68eb6b commit b47cddc

File tree

4 files changed

+59
-56
lines changed

4 files changed

+59
-56
lines changed

testing/zip_function.cu

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ struct SumThreeTuple
2828
THRUST_DECLTYPE_RETURNS(thrust::get<0>(x) + thrust::get<1>(x) + thrust::get<2>(x))
2929
}; // end SumThreeTuple
3030

31+
template <typename T>
32+
struct TestZipFunctionCtor
33+
{
34+
void operator()()
35+
{
36+
ASSERT_EQUAL(thrust::zip_function<SumThree>()(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
37+
ASSERT_EQUAL(thrust::zip_function<SumThree>(SumThree{})(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
38+
# ifdef __cpp_deduction_guides
39+
ASSERT_EQUAL(thrust::zip_function(SumThree{})(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
40+
# endif // __cpp_deduction_guides
41+
}
42+
};
43+
SimpleUnitTest<TestZipFunctionCtor, type_list<int>> TestZipFunctionCtorInstance;
44+
3145
template <typename T>
3246
struct TestZipFunctionTransform
3347
{

testing/zip_iterator.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ template<typename T>
3232

3333
// test construction
3434
ZipIterator iter0 = make_zip_iterator(t);
35+
ASSERT_EQUAL(true, iter0 == ZipIterator{t});
3536

3637
ASSERT_EQUAL_QUIET(v0.begin(), get<0>(iter0.get_iterator_tuple()));
3738
ASSERT_EQUAL_QUIET(v1.begin(), get<1>(iter0.get_iterator_tuple()));

thrust/iterator/zip_iterator.h

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -62,25 +62,20 @@ THRUST_NAMESPACE_BEGIN
6262
* #include <thrust/tuple.h>
6363
* #include <thrust/device_vector.h>
6464
* ...
65-
* thrust::device_vector<int> int_v(3);
66-
* int_v[0] = 0; int_v[1] = 1; int_v[2] = 2;
65+
* thrust::device_vector<int> int_v{0, 1, 2};
66+
* thrust::device_vector<float> float_v{0.0f, 1.0f, 2.0f};
67+
* thrust::device_vector<char> char_v{'a', 'b', 'c'};
6768
*
68-
* thrust::device_vector<float> float_v(3);
69-
* float_v[0] = 0.0f; float_v[1] = 1.0f; float_v[2] = 2.0f;
69+
* // aliases for iterators
70+
* using IntIterator = thrust::device_vector<int>::iterator;
71+
* using FloatIterator = thrust::device_vector<float>::iterator;
72+
* using CharIterator = thrust::device_vector<char>::iterator;
7073
*
71-
* thrust::device_vector<char> char_v(3);
72-
* char_v[0] = 'a'; char_v[1] = 'b'; char_v[2] = 'c';
73-
*
74-
* // typedef these iterators for shorthand
75-
* typedef thrust::device_vector<int>::iterator IntIterator;
76-
* typedef thrust::device_vector<float>::iterator FloatIterator;
77-
* typedef thrust::device_vector<char>::iterator CharIterator;
78-
*
79-
* // typedef a tuple of these iterators
80-
* typedef thrust::tuple<IntIterator, FloatIterator, CharIterator> IteratorTuple;
74+
* // alias for a tuple of these iterators
75+
* using IteratorTuple = thrust::tuple<IntIterator, FloatIterator, CharIterator>;
8176
*
8277
* // typedef the zip_iterator of this tuple
83-
* typedef thrust::zip_iterator<IteratorTuple> ZipIterator;
78+
* using ZipIterator = thrust::zip_iterator<IteratorTuple>;
8479
*
8580
* // finally, create the zip_iterator
8681
* ZipIterator iter(thrust::make_tuple(int_v.begin(), float_v.begin(), char_v.begin()));
@@ -109,15 +104,8 @@ THRUST_NAMESPACE_BEGIN
109104
*
110105
* int main()
111106
* {
112-
* thrust::device_vector<int> int_in(3), int_out(3);
113-
* int_in[0] = 0;
114-
* int_in[1] = 1;
115-
* int_in[2] = 2;
116-
*
117-
* thrust::device_vector<float> float_in(3), float_out(3);
118-
* float_in[0] = 0.0f;
119-
* float_in[1] = 10.0f;
120-
* float_in[2] = 20.0f;
107+
* thrust::device_vector<int> int_in{0, 1, 2}, int_out(3);
108+
* thrust::device_vector<float> float_in{0.0f, 10.0f, 20.0f}, float_out(3);
121109
*
122110
* thrust::copy(thrust::make_zip_iterator(thrust::make_tuple(int_in.begin(), float_in.begin())),
123111
* thrust::make_zip_iterator(thrust::make_tuple(int_in.end(), float_in.end())),
@@ -140,6 +128,10 @@ template <typename IteratorTuple>
140128
: public detail::zip_iterator_base<IteratorTuple>::type
141129
{
142130
public:
131+
/*! The underlying iterator tuple type. Alias to zip_iterator's first template argument.
132+
*/
133+
using iterator_tuple = IteratorTuple;
134+
143135
/*! Default constructor does nothing.
144136
*/
145137
#if THRUST_HOST_COMPILER == THRUST_HOST_COMPILER_MSVC && THRUST_MSVC_VERSION < 1920

thrust/zip_function.h

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -94,54 +94,40 @@ THRUST_DECLTYPE_RETURNS(
9494
* #include <thrust/zip_function.h>
9595
*
9696
* struct SumTuple {
97-
* float operator()(Tuple tup) {
98-
* return std::get<0>(tup) + std::get<1>(tup) + std::get<2>(tup);
97+
* float operator()(auto tup) const {
98+
* return thrust::get<0>(tup) + thrust::get<1>(tup) + thrust::get<2>(tup);
9999
* }
100100
* };
101101
* struct SumArgs {
102-
* float operator()(float a, float b, float c) {
102+
* float operator()(float a, float b, float c) const {
103103
* return a + b + c;
104104
* }
105105
* };
106106
*
107107
* int main() {
108-
* thrust::device_vector<float> A(3);
109-
* thrust::device_vector<float> B(3);
110-
* thrust::device_vector<float> C(3);
108+
* thrust::device_vector<float> A{0.f, 1.f, 2.f};
109+
* thrust::device_vector<float> B{1.f, 2.f, 3.f};
110+
* thrust::device_vector<float> C{2.f, 3.f, 4.f};
111111
* thrust::device_vector<float> D(3);
112-
* A[0] = 0.f; A[1] = 1.f; A[2] = 2.f;
113-
* B[0] = 1.f; B[1] = 2.f; B[2] = 3.f;
114-
* C[0] = 2.f; C[1] = 3.f; C[2] = 4.f;
115112
*
116-
* // The following four invocations of transform are equivalent
113+
* auto begin = thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin()));
114+
* auto end = thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end()));
115+
*
116+
* // The following four invocations of transform are equivalent:
117117
* // Transform with 3-tuple
118-
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
119-
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
120-
* D.begin(),
121-
* SumTuple{});
118+
* thrust::transform(begin, end, D.begin(), SumTuple{});
122119
*
123120
* // Transform with 3 parameters
124121
* thrust::zip_function<SumArgs> adapted{};
125-
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
126-
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
127-
* D.begin(),
128-
* adapted);
122+
* thrust::transform(begin, end, D.begin(), adapted);
129123
*
130124
* // Transform with 3 parameters with convenience function
131-
* thrust::zip_function<SumArgs> adapted{};
132-
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
133-
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
134-
* D.begin(),
135-
* thrust::make_zip_function(SumArgs{}));
125+
* thrust::transform(begin, end, D.begin(), thrust::make_zip_function(SumArgs{}));
136126
*
137127
* // Transform with 3 parameters with convenience function and lambda
138-
* thrust::zip_function<SumArgs> adapted{};
139-
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
140-
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
141-
* D.begin(),
142-
* thrust::make_zip_function([] (float a, float b, float c) {
143-
* return a + b + c;
144-
* }));
128+
* thrust::transform(begin, end, D.begin(), thrust::make_zip_function([] (float a, float b, float c) {
129+
* return a + b + c;
130+
* }));
145131
* return 0;
146132
* }
147133
* \endcode
@@ -153,9 +139,13 @@ template <typename Function>
153139
class zip_function
154140
{
155141
public:
142+
//! Default constructs the contained function object.
143+
zip_function() = default;
144+
156145
/*! Constructs a \p zip_function with the provided function object \p func. */
157-
THRUST_HOST_DEVICE
158-
zip_function(Function func) : func(std::move(func)) {}
146+
THRUST_HOST_DEVICE zip_function(Function func)
147+
: func(std::move(func))
148+
{}
159149

160150
/*! Applies the N-ary function object to elements of the tuple \p args. */
161151
// Add workaround for decltype(auto) on C++11-only compilers:
@@ -183,6 +173,12 @@ class zip_function
183173

184174
#endif // THRUST_CPP_DIALECT
185175

176+
//! Returns a reference to the underlying function.
177+
THRUST_HOST_DEVICE Function& underlying_function() const
178+
{
179+
return func;
180+
}
181+
186182
private:
187183
mutable Function func;
188184
};

0 commit comments

Comments
 (0)