1
+ // ==------------ group_sort_impl.hpp ---------------------------------------==//
2
+ //
3
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
+ // See https://llvm.org/LICENSE.txt for license information.
5
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
+ //
7
+ // ===----------------------------------------------------------------------===//
8
+ // This file includes some functions for group sorting algorithm implementations
9
+ //
10
+
11
+ #pragma once
12
+
13
+ #if __cplusplus >= 201703L
14
+ #include < CL/sycl/detail/helpers.hpp>
15
+
16
+ #ifdef __SYCL_DEVICE_ONLY__
17
+
18
+ __SYCL_INLINE_NAMESPACE (cl) {
19
+ namespace sycl {
20
+ namespace detail {
21
+
22
+ // ---- merge sort implementation
23
+
24
+ // following two functions could be useless if std::[lower|upper]_bound worked
25
+ // well
26
+ template <typename Acc, typename Value, typename Compare>
27
+ std::size_t lower_bound (Acc acc, std::size_t first, std::size_t last,
28
+ const Value &value, Compare comp) {
29
+ std::size_t n = last - first;
30
+ std::size_t cur = n;
31
+ std::size_t it;
32
+ while (n > 0 ) {
33
+ it = first;
34
+ cur = n / 2 ;
35
+ it += cur;
36
+ if (comp (acc[it], value)) {
37
+ n -= cur + 1 , first = ++it;
38
+ } else
39
+ n = cur;
40
+ }
41
+ return first;
42
+ }
43
+
44
+ template <typename Acc, typename Value, typename Compare>
45
+ std::size_t upper_bound (Acc acc, const std::size_t first,
46
+ const std::size_t last, const Value &value,
47
+ Compare comp) {
48
+ return detail::lower_bound (acc, first, last, value,
49
+ [comp](auto x, auto y) { return !comp (y, x); });
50
+ }
51
+
52
+ // swap for all data types including tuple-like types
53
+ template <typename T> void swap_tuples (T &a, T &b) { std::swap (a, b); }
54
+
55
+ template <template <typename ...> class TupleLike , typename T1, typename T2>
56
+ void swap_tuples (TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
57
+ std::swap (std::get<0 >(a), std::get<0 >(b));
58
+ std::swap (std::get<1 >(a), std::get<1 >(b));
59
+ }
60
+
61
+ template <typename Iter> struct GetValueType {
62
+ using type = typename std::iterator_traits<Iter>::value_type;
63
+ };
64
+
65
+ template <typename ElementType, access::address_space Space>
66
+ struct GetValueType <sycl::multi_ptr<ElementType, Space>> {
67
+ using type = ElementType;
68
+ };
69
+
70
+ // since we couldn't assign data to raw memory, it's better to use placement
71
+ // for first assignment
72
+ template <typename Acc, typename T>
73
+ void set_value (Acc ptr, const std::size_t idx, const T &val, bool is_first) {
74
+ if (is_first) {
75
+ ::new (ptr + idx) T (val);
76
+ } else {
77
+ ptr[idx] = val;
78
+ }
79
+ }
80
+
81
+ template <typename InAcc, typename OutAcc, typename Compare>
82
+ void merge (const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
83
+ const std::size_t start_1, const std::size_t end_1,
84
+ const std::size_t end_2, const std::size_t start_out, Compare comp,
85
+ const std::size_t chunk, bool is_first) {
86
+ const std::size_t start_2 = end_1;
87
+ // Borders of the sequences to merge within this call
88
+ const std::size_t local_start_1 =
89
+ sycl::min (static_cast <std::size_t >(offset + start_1), end_1);
90
+ const std::size_t local_end_1 =
91
+ sycl::min (static_cast <std::size_t >(local_start_1 + chunk), end_1);
92
+ const std::size_t local_start_2 =
93
+ sycl::min (static_cast <std::size_t >(offset + start_2), end_2);
94
+ const std::size_t local_end_2 =
95
+ sycl::min (static_cast <std::size_t >(local_start_2 + chunk), end_2);
96
+
97
+ const std::size_t local_size_1 = local_end_1 - local_start_1;
98
+ const std::size_t local_size_2 = local_end_2 - local_start_2;
99
+
100
+ // TODO: process cases where all elements of 1st sequence > 2nd, 2nd > 1st
101
+ // to improve performance
102
+
103
+ // Process 1st sequence
104
+ if (local_start_1 < local_end_1) {
105
+ // Reduce the range for searching within the 2nd sequence and handle bound
106
+ // items find left border in 2nd sequence
107
+ const auto local_l_item_1 = in_acc1[local_start_1];
108
+ std::size_t l_search_bound_2 =
109
+ detail::lower_bound (in_acc1, start_2, end_2, local_l_item_1, comp);
110
+ const std::size_t l_shift_1 = local_start_1 - start_1;
111
+ const std::size_t l_shift_2 = l_search_bound_2 - start_2;
112
+
113
+ set_value (out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
114
+ is_first);
115
+
116
+ std::size_t r_search_bound_2{};
117
+ // find right border in 2nd sequence
118
+ if (local_size_1 > 1 ) {
119
+ const auto local_r_item_1 = in_acc1[local_end_1 - 1 ];
120
+ r_search_bound_2 = detail::lower_bound (in_acc1, l_search_bound_2, end_2,
121
+ local_r_item_1, comp);
122
+ const auto r_shift_1 = local_end_1 - 1 - start_1;
123
+ const auto r_shift_2 = r_search_bound_2 - start_2;
124
+
125
+ set_value (out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_1,
126
+ is_first);
127
+ }
128
+
129
+ // Handle intermediate items
130
+ for (std::size_t idx = local_start_1 + 1 ; idx < local_end_1 - 1 ; ++idx) {
131
+ const auto intermediate_item_1 = in_acc1[idx];
132
+ // we shouldn't seek in whole 2nd sequence. Just for the part where the
133
+ // 1st sequence should be
134
+ l_search_bound_2 =
135
+ detail::lower_bound (in_acc1, l_search_bound_2, r_search_bound_2,
136
+ intermediate_item_1, comp);
137
+ const std::size_t shift_1 = idx - start_1;
138
+ const std::size_t shift_2 = l_search_bound_2 - start_2;
139
+
140
+ set_value (out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
141
+ is_first);
142
+ }
143
+ }
144
+ // Process 2nd sequence
145
+ if (local_start_2 < local_end_2) {
146
+ // Reduce the range for searching within the 1st sequence and handle bound
147
+ // items find left border in 1st sequence
148
+ const auto local_l_item_2 = in_acc1[local_start_2];
149
+ std::size_t l_search_bound_1 =
150
+ detail::upper_bound (in_acc1, start_1, end_1, local_l_item_2, comp);
151
+ const std::size_t l_shift_1 = l_search_bound_1 - start_1;
152
+ const std::size_t l_shift_2 = local_start_2 - start_2;
153
+
154
+ set_value (out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
155
+ is_first);
156
+
157
+ std::size_t r_search_bound_1{};
158
+ // find right border in 1st sequence
159
+ if (local_size_2 > 1 ) {
160
+ const auto local_r_item_2 = in_acc1[local_end_2 - 1 ];
161
+ r_search_bound_1 = detail::upper_bound (in_acc1, l_search_bound_1, end_1,
162
+ local_r_item_2, comp);
163
+ const std::size_t r_shift_1 = r_search_bound_1 - start_1;
164
+ const std::size_t r_shift_2 = local_end_2 - 1 - start_2;
165
+
166
+ set_value (out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
167
+ is_first);
168
+ }
169
+
170
+ // Handle intermediate items
171
+ for (auto idx = local_start_2 + 1 ; idx < local_end_2 - 1 ; ++idx) {
172
+ const auto intermediate_item_2 = in_acc1[idx];
173
+ // we shouldn't seek in whole 1st sequence. Just for the part where the
174
+ // 2nd sequence should be
175
+ l_search_bound_1 =
176
+ detail::upper_bound (in_acc1, l_search_bound_1, r_search_bound_1,
177
+ intermediate_item_2, comp);
178
+ const std::size_t shift_1 = l_search_bound_1 - start_1;
179
+ const std::size_t shift_2 = idx - start_2;
180
+
181
+ set_value (out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
182
+ is_first);
183
+ }
184
+ }
185
+ }
186
+
187
+ template <typename Iter, typename Compare>
188
+ void bubble_sort (Iter first, const std::size_t begin, const std::size_t end,
189
+ Compare comp) {
190
+ if (begin < end) {
191
+ for (std::size_t i = begin; i < end; ++i) {
192
+ // Handle intermediate items
193
+ for (std::size_t idx = i + 1 ; idx < end; ++idx) {
194
+ if (comp (first[idx], first[i])) {
195
+ detail::swap_tuples (first[i], first[idx]);
196
+ }
197
+ }
198
+ }
199
+ }
200
+ }
201
+
202
+ template <typename Group, typename Iter, typename Compare>
203
+ void merge_sort (Group group, Iter first, const std::size_t n, Compare comp,
204
+ std::byte *scratch) {
205
+ using T = typename GetValueType<Iter>::type;
206
+ auto id = sycl::detail::Builder::getNDItem<Group::dimensions>();
207
+ const std::size_t idx = id.get_local_linear_id ();
208
+ const std::size_t local = group.get_local_range ().size ();
209
+ const std::size_t chunk = (n - 1 ) / local + 1 ;
210
+
211
+ // we need to sort within work item first
212
+ bubble_sort (first, idx * chunk, sycl::min ((idx + 1 ) * chunk, n), comp);
213
+ id.barrier ();
214
+
215
+ T *temp = reinterpret_cast <T *>(scratch);
216
+ bool data_in_temp = false ;
217
+ bool is_first = true ;
218
+ std::size_t sorted_size = 1 ;
219
+ while (sorted_size * chunk < n) {
220
+ const std::size_t start_1 =
221
+ sycl::min (2 * sorted_size * chunk * (idx / sorted_size), n);
222
+ const std::size_t end_1 = sycl::min (start_1 + sorted_size * chunk, n);
223
+ const std::size_t end_2 = sycl::min (end_1 + sorted_size * chunk, n);
224
+ const std::size_t offset = chunk * (idx % sorted_size);
225
+
226
+ if (!data_in_temp) {
227
+ merge (offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
228
+ is_first);
229
+ } else {
230
+ merge (offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk,
231
+ /* is_first*/ false );
232
+ }
233
+ id.barrier ();
234
+
235
+ data_in_temp = !data_in_temp;
236
+ sorted_size *= 2 ;
237
+ if (is_first)
238
+ is_first = false ;
239
+ }
240
+
241
+ // copy back if data is in a temporary storage
242
+ if (data_in_temp) {
243
+ for (std::size_t i = 0 ; i < chunk; ++i) {
244
+ if (idx * chunk + i < n) {
245
+ first[idx * chunk + i] = temp[idx * chunk + i];
246
+ }
247
+ }
248
+ id.barrier ();
249
+ }
250
+ }
251
+
252
+ } // namespace detail
253
+ } // namespace sycl
254
+ } // __SYCL_INLINE_NAMESPACE(cl)
255
+ #endif
256
+ #endif // __cplusplus >=201703L
0 commit comments