Skip to content

Commit b68eb6b

Browse files
BeanavilNB4444
authored andcommitted
Expose thrust's contiguous iterator unwrap helpers
1 parent 54dad0f commit b68eb6b

File tree

5 files changed

+46
-55
lines changed

5 files changed

+46
-55
lines changed

testing/is_contiguous_iterator.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ template <typename IteratorT,
144144
struct check_unwrapped_iterator
145145
{
146146
using unwrapped_t = typename std::remove_reference<
147-
decltype(thrust::detail::try_unwrap_contiguous_iterator(
147+
decltype(thrust::try_unwrap_contiguous_iterator(
148148
std::declval<IteratorT>()))>::type;
149149

150150
static constexpr bool value =

thrust/system/cuda/detail/adjacent_difference.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ namespace __adjacent_difference {
187187
std::size_t storage_size = 0;
188188
cudaStream_t stream = cuda_cub::stream(policy);
189189

190-
using UnwrapInputIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<InputIt>;
191-
using UnwrapOutputIt = thrust::detail::try_unwrap_contiguous_iterator_return_t<OutputIt>;
190+
using UnwrapInputIt = thrust::try_unwrap_contiguous_iterator_t<InputIt>;
191+
using UnwrapOutputIt = thrust::try_unwrap_contiguous_iterator_t<OutputIt>;
192192

193193
using InputValueT = thrust::iterator_value_t<UnwrapInputIt>;
194194
using OutputValueT = thrust::iterator_value_t<UnwrapOutputIt>;
@@ -198,8 +198,8 @@ namespace __adjacent_difference {
198198
std::is_pointer<UnwrapOutputIt>::value &&
199199
std::is_same<InputValueT, OutputValueT>::value;
200200

201-
auto first_unwrap = thrust::detail::try_unwrap_contiguous_iterator(first);
202-
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
201+
auto first_unwrap = thrust::try_unwrap_contiguous_iterator(first);
202+
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);
203203

204204
thrust::detail::integral_constant<bool, can_compare_iterators> comparable;
205205

thrust/system/cuda/detail/scan_by_key.h

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,14 @@ ValuesOutIt inclusive_scan_by_key_n(
8181
}
8282

8383
// Convert to raw pointers if possible:
84-
using KeysInUnwrapIt =
85-
thrust::detail::try_unwrap_contiguous_iterator_return_t<KeysInIt>;
86-
using ValuesInUnwrapIt =
87-
thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesInIt>;
88-
using ValuesOutUnwrapIt =
89-
thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesOutIt>;
90-
using AccumT = typename thrust::iterator_traits<ValuesInUnwrapIt>::value_type;
91-
92-
auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys);
93-
auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values);
94-
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
84+
using KeysInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<KeysInIt>;
85+
using ValuesInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesInIt>;
86+
using ValuesOutUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesOutIt>;
87+
using AccumT = typename thrust::iterator_traits<ValuesInUnwrapIt>::value_type;
88+
89+
auto keys_unwrap = thrust::try_unwrap_contiguous_iterator(keys);
90+
auto values_unwrap = thrust::try_unwrap_contiguous_iterator(values);
91+
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);
9592

9693
using Dispatch32 = cub::DispatchScanByKey<KeysInUnwrapIt,
9794
ValuesInUnwrapIt,
@@ -195,16 +192,13 @@ ValuesOutIt exclusive_scan_by_key_n(
195192
}
196193

197194
// Convert to raw pointers if possible:
198-
using KeysInUnwrapIt =
199-
thrust::detail::try_unwrap_contiguous_iterator_return_t<KeysInIt>;
200-
using ValuesInUnwrapIt =
201-
thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesInIt>;
202-
using ValuesOutUnwrapIt =
203-
thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesOutIt>;
204-
205-
auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys);
206-
auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values);
207-
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
195+
using KeysInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<KeysInIt>;
196+
using ValuesInUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesInIt>;
197+
using ValuesOutUnwrapIt = thrust::try_unwrap_contiguous_iterator_t<ValuesOutIt>;
198+
199+
auto keys_unwrap = thrust::try_unwrap_contiguous_iterator(keys);
200+
auto values_unwrap = thrust::try_unwrap_contiguous_iterator(values);
201+
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);
208202

209203
using Dispatch32 = cub::DispatchScanByKey<KeysInUnwrapIt,
210204
ValuesInUnwrapIt,

thrust/system/hip/detail/adjacent_difference.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,8 @@ namespace __adjacent_difference
135135
}
136136

137137
// Check if iterators can be compared
138-
using unwrap_input_iterator
139-
= thrust::detail::try_unwrap_contiguous_iterator_return_t<InputIt>;
140-
using unwrap_output_iterator
141-
= thrust::detail::try_unwrap_contiguous_iterator_return_t<OutputIt>;
138+
using unwrap_input_iterator = thrust::try_unwrap_contiguous_iterator_t<InputIt>;
139+
using unwrap_output_iterator = thrust::try_unwrap_contiguous_iterator_t<OutputIt>;
142140

143141
using input_value_type = thrust::iterator_value_t<unwrap_input_iterator>;
144142
using output_value_type = thrust::iterator_value_t<unwrap_output_iterator>;
@@ -149,8 +147,8 @@ namespace __adjacent_difference
149147
&& std::is_same<input_value_type, output_value_type>::value;
150148

151149
// Unwrap iterators to make them comparable
152-
auto first_unwrap = thrust::detail::try_unwrap_contiguous_iterator(first);
153-
auto result_unwrap = thrust::detail::try_unwrap_contiguous_iterator(result);
150+
auto first_unwrap = thrust::try_unwrap_contiguous_iterator(first);
151+
auto result_unwrap = thrust::try_unwrap_contiguous_iterator(result);
154152

155153
thrust::detail::integral_constant<bool, can_compare_iterators> comparable;
156154

thrust/type_traits/is_contiguous_iterator.h

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -229,23 +229,24 @@ struct contiguous_iterator_traits
229229
using raw_pointer = typename thrust::detail::pointer_traits<
230230
decltype(&*std::declval<Iterator>())>::raw_pointer;
231231
};
232+
} // namespace detail
232233

233-
template <typename Iterator>
234-
using contiguous_iterator_raw_pointer_t =
235-
typename contiguous_iterator_traits<Iterator>::raw_pointer;
234+
//! Converts a contiguous iterator type to its underlying raw pointer type.
235+
template <typename ContiguousIterator>
236+
using unwrap_contiguous_iterator_t = typename detail::contiguous_iterator_traits<ContiguousIterator>::raw_pointer;
236237

237-
// Converts a contiguous iterator to a raw pointer:
238-
template <typename Iterator>
239-
THRUST_HOST_DEVICE
240-
contiguous_iterator_raw_pointer_t<Iterator>
241-
contiguous_iterator_raw_pointer_cast(Iterator it)
238+
//! Converts a contiguous iterator to its underlying raw pointer.
239+
template <typename ContiguousIterator>
240+
THRUST_HOST_DEVICE auto unwrap_contiguous_iterator(ContiguousIterator it)
241+
-> unwrap_contiguous_iterator_t<ContiguousIterator>
242242
{
243-
static_assert(thrust::is_contiguous_iterator<Iterator>::value,
244-
"contiguous_iterator_raw_pointer_cast called with "
245-
"non-contiguous iterator.");
243+
static_assert(thrust::is_contiguous_iterator<ContiguousIterator>::value,
244+
"unwrap_contiguous_iterator called with non-contiguous iterator.");
246245
return thrust::raw_pointer_cast(&*it);
247246
}
248247

248+
namespace detail
249+
{
249250
// Implementation for non-contiguous iterators -- passthrough.
250251
template <typename Iterator,
251252
bool IsContiguous = thrust::is_contiguous_iterator<Iterator>::value>
@@ -260,30 +261,28 @@ struct try_unwrap_contiguous_iterator_impl
260261
template <typename Iterator>
261262
struct try_unwrap_contiguous_iterator_impl<Iterator, true /*is_contiguous*/>
262263
{
263-
using type = contiguous_iterator_raw_pointer_t<Iterator>;
264+
using type = unwrap_contiguous_iterator_t<Iterator>;
264265

265266
static THRUST_HOST_DEVICE type get(Iterator it)
266267
{
267-
return contiguous_iterator_raw_pointer_cast(it);
268+
return unwrap_contiguous_iterator(it);
268269
}
269270
};
271+
} // namespace detail
270272

273+
//! Takes an iterator type and, if it is contiguous, yields the raw pointer type it represents. Otherwise returns the
274+
//! iterator type unmodified.
271275
template <typename Iterator>
272-
using try_unwrap_contiguous_iterator_return_t =
273-
typename try_unwrap_contiguous_iterator_impl<Iterator>::type;
276+
using try_unwrap_contiguous_iterator_t = typename detail::try_unwrap_contiguous_iterator_impl<Iterator>::type;
274277

275-
// Casts to a raw pointer if iterator is marked as contiguous, otherwise returns
276-
// the input iterator.
278+
//! Takes an iterator and, if it is contiguous, unwraps it to the raw pointer it represents. Otherwise returns the
279+
//! iterator unmodified.
277280
template <typename Iterator>
278-
THRUST_HOST_DEVICE
279-
try_unwrap_contiguous_iterator_return_t<Iterator>
280-
try_unwrap_contiguous_iterator(Iterator it)
281+
THRUST_HOST_DEVICE auto try_unwrap_contiguous_iterator(Iterator it) -> try_unwrap_contiguous_iterator_t<Iterator>
281282
{
282-
return try_unwrap_contiguous_iterator_impl<Iterator>::get(it);
283+
return detail::try_unwrap_contiguous_iterator_impl<Iterator>::get(it);
283284
}
284285

285-
} // namespace detail
286-
287286
/*! \endcond
288287
*/
289288

0 commit comments

Comments
 (0)