@@ -97,159 +97,8 @@ template <typename Group, size_t Extent> class group_with_scratchpad {
97
97
sycl::span<std::byte, Extent> get_memory () const { return scratch; }
98
98
};
99
99
100
- // Default sorter provided by the first version of the extension specification.
101
- template <typename Compare = std::less<>> class default_sorter {
102
- Compare comp;
103
- sycl::span<std::byte> scratch;
104
-
105
- public:
106
- template <size_t Extent>
107
- default_sorter (sycl::span<std::byte, Extent> scratch_,
108
- Compare comp_ = Compare())
109
- : comp(comp_), scratch(scratch_) {}
110
-
111
- template <typename Group, typename Ptr>
112
- void operator ()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
113
- [[maybe_unused]] Ptr last) {
114
- #ifdef __SYCL_DEVICE_ONLY__
115
- using T = typename sycl::detail::GetValueType<Ptr>::type;
116
- size_t n = std::distance (first, last);
117
- T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
118
- sycl::detail::merge_sort (g, first, n, comp, scratch_begin);
119
- #else
120
- throw sycl::exception (
121
- std::error_code (PI_ERROR_INVALID_DEVICE, sycl::sycl_category ()),
122
- " default_sorter constructor is not supported on host device." );
123
- #endif
124
- }
125
-
126
- template <typename Group, typename T>
127
- T operator ()([[maybe_unused]] Group g, T val) {
128
- #ifdef __SYCL_DEVICE_ONLY__
129
- std::size_t local_id = g.get_local_linear_id ();
130
- auto range_size = g.get_local_range ().size ();
131
- T *scratch_begin = sycl::detail::align_scratch<T>(
132
- scratch, g, /* output storage and temporary storage */ 2 * range_size);
133
- scratch_begin[local_id] = val;
134
- sycl::detail::merge_sort (g, scratch_begin, range_size, comp,
135
- scratch_begin + range_size);
136
- val = scratch_begin[local_id];
137
- #else
138
- throw sycl::exception (
139
- std::error_code (PI_ERROR_INVALID_DEVICE, sycl::sycl_category ()),
140
- " default_sorter operator() is not supported on host device." );
141
- #endif
142
- return val;
143
- }
144
-
145
- template <typename T>
146
- static constexpr size_t memory_required (sycl::memory_scope,
147
- size_t range_size) {
148
- return range_size * sizeof (T) + alignof (T);
149
- }
150
-
151
- template <typename T, int dim = 1 >
152
- static constexpr size_t memory_required (sycl::memory_scope scope,
153
- sycl::range<dim> r) {
154
- return 2 * memory_required<T>(scope, r.size ());
155
- }
156
- };
157
-
158
100
// Ordering requested from the sorters declared in this header.
enum class sorting_order {
  ascending, // first enumerator, underlying value 0
  descending // second enumerator, underlying value 1
};
159
101
160
- namespace detail {
161
-
162
- template <typename T, sorting_order = sorting_order::ascending>
163
- struct ConvertToComp {
164
- using Type = std::less<T>;
165
- };
166
-
167
- template <typename T> struct ConvertToComp <T, sorting_order::descending> {
168
- using Type = std::greater<T>;
169
- };
170
- } // namespace detail
171
-
172
- // Radix sorter provided by the first version of the extension specification.
173
- template <typename ValT, sorting_order OrderT = sorting_order::ascending,
174
- unsigned int BitsPerPass = 4 >
175
- class radix_sorter {
176
-
177
- sycl::span<std::byte> scratch;
178
- uint32_t first_bit = 0 ;
179
- uint32_t last_bit = 0 ;
180
-
181
- static constexpr uint32_t bits = BitsPerPass;
182
- using bitset_t = std::bitset<sizeof (ValT) * CHAR_BIT>;
183
-
184
- public:
185
- template <size_t Extent>
186
- radix_sorter (sycl::span<std::byte, Extent> scratch_,
187
- const bitset_t mask = bitset_t {}.set())
188
- : scratch(scratch_) {
189
- static_assert ((std::is_arithmetic<ValT>::value ||
190
- std::is_same<ValT, sycl::half>::value ||
191
- std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
192
- " radix sort is not usable" );
193
-
194
- for (first_bit = 0 ; first_bit < mask.size () && !mask[first_bit];
195
- ++first_bit)
196
- ;
197
- for (last_bit = first_bit; last_bit < mask.size () && mask[last_bit];
198
- ++last_bit)
199
- ;
200
- }
201
-
202
- template <typename GroupT, typename PtrT>
203
- void operator ()([[maybe_unused]] GroupT g, [[maybe_unused]] PtrT first,
204
- [[maybe_unused]] PtrT last) {
205
- #ifdef __SYCL_DEVICE_ONLY__
206
- sycl::detail::privateDynamicSort</* is_key_value=*/ false ,
207
- OrderT == sorting_order::ascending,
208
- /* empty*/ 1 , BitsPerPass>(
209
- g, first, /* empty*/ first, std::distance (first, last), scratch.data (),
210
- first_bit, last_bit);
211
- #else
212
- throw sycl::exception (
213
- std::error_code (PI_ERROR_INVALID_DEVICE, sycl::sycl_category ()),
214
- " radix_sorter is not supported on host device." );
215
- #endif
216
- }
217
-
218
- template <typename GroupT>
219
- ValT operator ()([[maybe_unused]] GroupT g, [[maybe_unused]] ValT val) {
220
- #ifdef __SYCL_DEVICE_ONLY__
221
- ValT result[]{val};
222
- sycl::detail::privateStaticSort</* is_key_value=*/ false ,
223
- /* is_input_blocked=*/ true ,
224
- /* is_output_blocked=*/ true ,
225
- OrderT == sorting_order::ascending,
226
- /* items_per_work_item=*/ 1 , bits>(
227
- g, result, /* empty*/ result, scratch.data (), first_bit, last_bit);
228
- return result[0 ];
229
- #else
230
- throw sycl::exception (
231
- std::error_code (PI_ERROR_INVALID_DEVICE, sycl::sycl_category ()),
232
- " radix_sorter is not supported on host device." );
233
- #endif
234
- }
235
-
236
- static constexpr size_t memory_required (sycl::memory_scope,
237
- size_t range_size) {
238
- return range_size * sizeof (ValT) +
239
- (1 << bits) * range_size * sizeof (uint32_t ) + alignof (uint32_t );
240
- }
241
-
242
- // memory_helpers
243
- template <int dimensions = 1 >
244
- static constexpr size_t memory_required (sycl::memory_scope,
245
- sycl::range<dimensions> local_range) {
246
- return (std::max)(local_range.size () * sizeof (ValT),
247
- local_range.size () * (1 << bits) * sizeof (uint32_t ));
248
- }
249
- };
250
-
251
- // Default sorters provided by the second version of the extension
252
- // specification.
253
102
namespace default_sorters {
254
103
255
104
template <typename CompareT = std::less<>> class joint_sorter {
@@ -458,7 +307,6 @@ class group_key_value_sorter {
458
307
};
459
308
} // namespace default_sorters
460
309
461
- // Radix sorters provided by the second version of the extension specification.
462
310
namespace radix_sorters {
463
311
464
312
template <typename ValT, sorting_order OrderT = sorting_order::ascending,
0 commit comments