Skip to content

Commit 932ae56

Browse files
authored
[SYCL][Group algorithms] Add group sorting algorithms implementation (#4439)
This PR introduces the implementation of the Group Sort extension. The PR includes: the feature macro; the joint_sort and sort_over_group functions; default_sorter; group_with_scratchpad. The algorithms are quite general: they should work with custom data types, custom comparators, and custom sorters. The PR doesn't include: radix_sorter; an optimized specialization for arithmetic types. Tests are here: intel/llvm-test-suite#438 Signed-off-by: Fedorov, Andrey [email protected]
1 parent a635873 commit 932ae56

File tree

5 files changed

+515
-12
lines changed

5 files changed

+515
-12
lines changed

sycl/doc/extensions/GroupAlgorithms/SYCL_INTEL_group_sort.asciidoc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ namespace sycl::ext::oneapi::experimental {
146146
class default_sorter {
147147
public:
148148
template<std::size_t Extent>
149-
default_sorter(sycl::span<uint8_t, Extent> scratch, Compare comp = Compare());
149+
default_sorter(sycl::span<std::byte, Extent> scratch, Compare comp = Compare());
150150
151151
template<typename Group, typename Ptr>
152152
void operator()(Group g, Ptr first, Ptr last);
@@ -167,7 +167,7 @@ namespace sycl::ext::oneapi::experimental {
167167
class radix_sorter {
168168
public:
169169
template<std::size_t Extent>
170-
radix_sorter(sycl::span<uint8_t, Extent> scratch,
170+
radix_sorter(sycl::span<std::byte, Extent> scratch,
171171
const std::bitset<sizeof(T) * CHAR_BIT> mask =
172172
std::bitset<sizeof(T) * CHAR_BIT> (std::numeric_limits<unsigned long long>::max()));
173173
@@ -215,7 +215,7 @@ Table 4. Constructors of the `default_sorter` class.
215215
|Constructor|Description
216216

217217
|`template<std::size_t Extent>
218-
default_sorter(sycl::span<uint8_t, Extent> scratch, Compare comp = Compare())`
218+
default_sorter(sycl::span<std::byte, Extent> scratch, Compare comp = Compare())`
219219
|Creates the `default_sorter` object using `comp`.
220220
Additional memory for the algorithm is provided using `scratch`.
221221
If `scratch.size()` is less than the value returned by
@@ -264,7 +264,7 @@ Table 6. Constructors of the `radix_sorter` class.
264264
|Constructor|Description
265265

266266
|`template<std::size_t Extent>
267-
radix_sorter(sycl::span<uint8_t, Extent> scratch, const std::bitset<sizeof(T) * CHAR_BIT> mask = std::bitset<sizeof(T) * CHAR_BIT>
267+
radix_sorter(sycl::span<std::byte, Extent> scratch, const std::bitset<sizeof(T) * CHAR_BIT> mask = std::bitset<sizeof(T) * CHAR_BIT>
268268
(std::numeric_limits<unsigned long long>::max()))`
269269
|Creates the `radix_sorter` object to sort values considering only bits
270270
that correspond to 1 in `mask`.
@@ -350,16 +350,16 @@ namespace sycl::ext::oneapi::experimental {
350350
class group_with_scratchpad
351351
{
352352
public:
353-
group_with_scratchpad(Group group, sycl::span<uint8_t, Extent> scratch);
353+
group_with_scratchpad(Group group, sycl::span<std::byte, Extent> scratch);
354354
Group get_group() const;
355355
356-
sycl::span<uint8_t, Extent>
356+
sycl::span<std::byte, Extent>
357357
get_memory() const;
358358
};
359359
360360
// Deduction guides
361361
template<typename Group, std::size_t Extent>
362-
group_with_scratchpad(Group, sycl::span<uint8_t, Extent>)
362+
group_with_scratchpad(Group, sycl::span<std::byte, Extent>)
363363
-> group_with_scratchpad<Group, Extent>;
364364
365365
}
@@ -372,7 +372,7 @@ Table 9. Constructors of the `group_with_scratchpad` class.
372372
|===
373373
|Constructor|Description
374374

375-
|`group_with_scratchpad(Group group, sycl::span<uint8_t, Extent> scratch)`
375+
|`group_with_scratchpad(Group group, sycl::span<std::byte, Extent> scratch)`
376376
|Creates the `group_with_scratchpad` object using `group` and `scratch`.
377377
`sycl::is_group_v<std::decay_t<Group>>` must be true.
378378
`scratch.size()` must not be less than the value returned by the `memory_required` method
@@ -388,7 +388,7 @@ Table 10. Member functions of the `group_with_scratchpad` class.
388388
|`Group get_group() const`
389389
|Returns the `Group` class object that is handled by the `group_with_scratchpad` object.
390390

391-
|`sycl::span<uint8_t, Extent>
391+
|`sycl::span<std::byte, Extent>
392392
get_memory() const`
393393
|Returns `sycl::span` that represents an additional memory
394394
that is handled by the `group_with_scratchpad` object.
@@ -508,7 +508,7 @@ size_t temp_memory_size =
508508
509509
q.submit([&](sycl::handler& h) {
510510
auto acc = sycl::accessor(buf, h);
511-
auto scratch = sycl::local_accessor<uint8_t, 1>( {temp_memory_size}, h );
511+
auto scratch = sycl::local_accessor<std::byte, 1>( {temp_memory_size}, h );
512512
513513
h.parallel_for(
514514
sycl::nd_range<1>{ /*global_size = */ {256}, /*local_size = */ {256} },
@@ -546,7 +546,7 @@ size_t temp_memory_size =
546546
547547
q.submit([&](sycl::handler& h) {
548548
auto acc = sycl::accessor(buf, h);
549-
auto scratch = sycl::local_accessor<uint8_t, 1>( {temp_memory_size}, h);
549+
auto scratch = sycl::local_accessor<std::byte, 1>( {temp_memory_size}, h);
550550
551551
h.parallel_for(
552552
sycl::nd_range<1>{ local_range, local_range },
@@ -583,7 +583,7 @@ size_t temp_memory_size =
583583
q.submit([&](sycl::handler& h) {
584584
auto keys_acc = sycl::accessor(keys_buf, h);
585585
auto vals_acc = sycl::accessor(vals_buf, h);
586-
auto scratch = sycl::local_accessor<uint8_t, 1>( {temp_memory_size}, h);
586+
auto scratch = sycl::local_accessor<std::byte, 1>( {temp_memory_size}, h);
587587
588588
h.parallel_for(
589589
sycl::nd_range<1>{ /*global_size = */ {1024}, /*local_size = */ {256} },
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
//==------------ group_sort_impl.hpp ---------------------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// This file includes some functions for group sorting algorithm implementations
9+
//
10+
11+
#pragma once
12+
13+
#if __cplusplus >= 201703L
14+
#include <CL/sycl/detail/helpers.hpp>
15+
16+
#ifdef __SYCL_DEVICE_ONLY__
17+
18+
__SYCL_INLINE_NAMESPACE(cl) {
19+
namespace sycl {
20+
namespace detail {
21+
22+
// ---- merge sort implementation
23+
24+
// following two functions could be useless if std::[lower|upper]_bound worked
25+
// well
26+
// Binary search over the index range [first, last) of `acc`.
//
// Returns the smallest index i in [first, last) for which
// comp(acc[i], value) is false, or `last` when comp is true for every
// element of the range. The result is only meaningful when the range is
// partitioned with respect to comp(acc[i], value), e.g. sorted by `comp`.
//
//   acc   - any object indexable with operator[] (it is taken by value)
//   first - index of the first element of the searched range
//   last  - index one past the last element of the searched range
//   value - the value being searched for
//   comp  - callable invoked as comp(element, value)
template <typename Acc, typename Value, typename Compare>
std::size_t lower_bound(Acc acc, std::size_t first, std::size_t last,
                        const Value &value, Compare comp) {
  // Number of candidate positions still to be examined, starting at `first`.
  std::size_t remaining = last - first;
  while (remaining > 0) {
    const std::size_t half = remaining / 2;
    const std::size_t mid = first + half;
    if (comp(acc[mid], value)) {
      // acc[mid] goes before `value`: continue in the part right of `mid`.
      first = mid + 1;
      remaining -= half + 1;
    } else {
      // acc[mid] does not go before `value`: continue in the part left of
      // `mid`; `mid` itself stays the best answer so far.
      remaining = half;
    }
  }
  return first;
}
43+
44+
template <typename Acc, typename Value, typename Compare>
45+
std::size_t upper_bound(Acc acc, const std::size_t first,
46+
const std::size_t last, const Value &value,
47+
Compare comp) {
48+
return detail::lower_bound(acc, first, last, value,
49+
[comp](auto x, auto y) { return !comp(y, x); });
50+
}
51+
52+
// Swap for all data types including tuple-like types.
//
// Generic overload: exchanges two lvalues of the same type. The unqualified
// call lets argument-dependent lookup pick a swap provided by the type
// itself; std::swap is the fallback.
template <typename T> void swap_tuples(T &a, T &b) {
  using std::swap;
  swap(a, b);
}

// Overload for temporaries of a two-element tuple-like type (for example a
// tuple of references produced on the fly): the elements obtained through
// std::get<0> and std::get<1> are exchanged one by one.
template <template <typename...> class TupleLike, typename T1, typename T2>
void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
  using std::swap;
  swap(std::get<0>(a), std::get<0>(b));
  swap(std::get<1>(a), std::get<1>(b));
}
60+
61+
// Maps an iterator-like type to the type of the elements it refers to.
// The primary template simply forwards to std::iterator_traits.
template <typename Iterator> struct GetValueType {
  using type = typename std::iterator_traits<Iterator>::value_type;
};
64+
65+
template <typename ElementType, access::address_space Space>
66+
struct GetValueType<sycl::multi_ptr<ElementType, Space>> {
67+
using type = ElementType;
68+
};
69+
70+
// Stores `val` at position `idx` of `ptr`.
//
// The destination may still be raw memory in which no object has been
// created yet; a plain assignment is not possible there, so the very first
// store (`is_first == true`) copy-constructs a T in place with placement new.
// Every later store (`is_first == false`) uses ordinary assignment.
template <typename Acc, typename T>
void set_value(Acc ptr, const std::size_t idx, const T &val, bool is_first) {
  if (!is_first) {
    ptr[idx] = val;
    return;
  }
  ::new (ptr + idx) T(val);
}
80+
81+
template <typename InAcc, typename OutAcc, typename Compare>
82+
void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
83+
const std::size_t start_1, const std::size_t end_1,
84+
const std::size_t end_2, const std::size_t start_out, Compare comp,
85+
const std::size_t chunk, bool is_first) {
86+
const std::size_t start_2 = end_1;
87+
// Borders of the sequences to merge within this call
88+
const std::size_t local_start_1 =
89+
sycl::min(static_cast<std::size_t>(offset + start_1), end_1);
90+
const std::size_t local_end_1 =
91+
sycl::min(static_cast<std::size_t>(local_start_1 + chunk), end_1);
92+
const std::size_t local_start_2 =
93+
sycl::min(static_cast<std::size_t>(offset + start_2), end_2);
94+
const std::size_t local_end_2 =
95+
sycl::min(static_cast<std::size_t>(local_start_2 + chunk), end_2);
96+
97+
const std::size_t local_size_1 = local_end_1 - local_start_1;
98+
const std::size_t local_size_2 = local_end_2 - local_start_2;
99+
100+
// TODO: process cases where all elements of 1st sequence > 2nd, 2nd > 1st
101+
// to improve performance
102+
103+
// Process 1st sequence
104+
if (local_start_1 < local_end_1) {
105+
// Reduce the range for searching within the 2nd sequence and handle bound
106+
// items find left border in 2nd sequence
107+
const auto local_l_item_1 = in_acc1[local_start_1];
108+
std::size_t l_search_bound_2 =
109+
detail::lower_bound(in_acc1, start_2, end_2, local_l_item_1, comp);
110+
const std::size_t l_shift_1 = local_start_1 - start_1;
111+
const std::size_t l_shift_2 = l_search_bound_2 - start_2;
112+
113+
set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
114+
is_first);
115+
116+
std::size_t r_search_bound_2{};
117+
// find right border in 2nd sequence
118+
if (local_size_1 > 1) {
119+
const auto local_r_item_1 = in_acc1[local_end_1 - 1];
120+
r_search_bound_2 = detail::lower_bound(in_acc1, l_search_bound_2, end_2,
121+
local_r_item_1, comp);
122+
const auto r_shift_1 = local_end_1 - 1 - start_1;
123+
const auto r_shift_2 = r_search_bound_2 - start_2;
124+
125+
set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_1,
126+
is_first);
127+
}
128+
129+
// Handle intermediate items
130+
for (std::size_t idx = local_start_1 + 1; idx < local_end_1 - 1; ++idx) {
131+
const auto intermediate_item_1 = in_acc1[idx];
132+
// we shouldn't seek in whole 2nd sequence. Just for the part where the
133+
// 1st sequence should be
134+
l_search_bound_2 =
135+
detail::lower_bound(in_acc1, l_search_bound_2, r_search_bound_2,
136+
intermediate_item_1, comp);
137+
const std::size_t shift_1 = idx - start_1;
138+
const std::size_t shift_2 = l_search_bound_2 - start_2;
139+
140+
set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
141+
is_first);
142+
}
143+
}
144+
// Process 2nd sequence
145+
if (local_start_2 < local_end_2) {
146+
// Reduce the range for searching within the 1st sequence and handle bound
147+
// items find left border in 1st sequence
148+
const auto local_l_item_2 = in_acc1[local_start_2];
149+
std::size_t l_search_bound_1 =
150+
detail::upper_bound(in_acc1, start_1, end_1, local_l_item_2, comp);
151+
const std::size_t l_shift_1 = l_search_bound_1 - start_1;
152+
const std::size_t l_shift_2 = local_start_2 - start_2;
153+
154+
set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
155+
is_first);
156+
157+
std::size_t r_search_bound_1{};
158+
// find right border in 1st sequence
159+
if (local_size_2 > 1) {
160+
const auto local_r_item_2 = in_acc1[local_end_2 - 1];
161+
r_search_bound_1 = detail::upper_bound(in_acc1, l_search_bound_1, end_1,
162+
local_r_item_2, comp);
163+
const std::size_t r_shift_1 = r_search_bound_1 - start_1;
164+
const std::size_t r_shift_2 = local_end_2 - 1 - start_2;
165+
166+
set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
167+
is_first);
168+
}
169+
170+
// Handle intermediate items
171+
for (auto idx = local_start_2 + 1; idx < local_end_2 - 1; ++idx) {
172+
const auto intermediate_item_2 = in_acc1[idx];
173+
// we shouldn't seek in whole 1st sequence. Just for the part where the
174+
// 2nd sequence should be
175+
l_search_bound_1 =
176+
detail::upper_bound(in_acc1, l_search_bound_1, r_search_bound_1,
177+
intermediate_item_2, comp);
178+
const std::size_t shift_1 = l_search_bound_1 - start_1;
179+
const std::size_t shift_2 = idx - start_2;
180+
181+
set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
182+
is_first);
183+
}
184+
}
185+
}
186+
187+
template <typename Iter, typename Compare>
188+
void bubble_sort(Iter first, const std::size_t begin, const std::size_t end,
189+
Compare comp) {
190+
if (begin < end) {
191+
for (std::size_t i = begin; i < end; ++i) {
192+
// Handle intermediate items
193+
for (std::size_t idx = i + 1; idx < end; ++idx) {
194+
if (comp(first[idx], first[i])) {
195+
detail::swap_tuples(first[i], first[idx]);
196+
}
197+
}
198+
}
199+
}
200+
}
201+
202+
template <typename Group, typename Iter, typename Compare>
203+
void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
204+
std::byte *scratch) {
205+
using T = typename GetValueType<Iter>::type;
206+
auto id = sycl::detail::Builder::getNDItem<Group::dimensions>();
207+
const std::size_t idx = id.get_local_linear_id();
208+
const std::size_t local = group.get_local_range().size();
209+
const std::size_t chunk = (n - 1) / local + 1;
210+
211+
// we need to sort within work item first
212+
bubble_sort(first, idx * chunk, sycl::min((idx + 1) * chunk, n), comp);
213+
id.barrier();
214+
215+
T *temp = reinterpret_cast<T *>(scratch);
216+
bool data_in_temp = false;
217+
bool is_first = true;
218+
std::size_t sorted_size = 1;
219+
while (sorted_size * chunk < n) {
220+
const std::size_t start_1 =
221+
sycl::min(2 * sorted_size * chunk * (idx / sorted_size), n);
222+
const std::size_t end_1 = sycl::min(start_1 + sorted_size * chunk, n);
223+
const std::size_t end_2 = sycl::min(end_1 + sorted_size * chunk, n);
224+
const std::size_t offset = chunk * (idx % sorted_size);
225+
226+
if (!data_in_temp) {
227+
merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
228+
is_first);
229+
} else {
230+
merge(offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk,
231+
/*is_first*/ false);
232+
}
233+
id.barrier();
234+
235+
data_in_temp = !data_in_temp;
236+
sorted_size *= 2;
237+
if (is_first)
238+
is_first = false;
239+
}
240+
241+
// copy back if data is in a temporary storage
242+
if (data_in_temp) {
243+
for (std::size_t i = 0; i < chunk; ++i) {
244+
if (idx * chunk + i < n) {
245+
first[idx * chunk + i] = temp[idx * chunk + i];
246+
}
247+
}
248+
id.barrier();
249+
}
250+
}
251+
252+
} // namespace detail
253+
} // namespace sycl
254+
} // __SYCL_INLINE_NAMESPACE(cl)
255+
#endif
256+
#endif // __cplusplus >=201703L

sycl/include/CL/sycl/group_algorithm.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <CL/sycl/nd_item.hpp>
1919
#include <CL/sycl/sub_group.hpp>
2020
#include <sycl/ext/oneapi/functional.hpp>
21+
#include <sycl/ext/oneapi/group_sort.hpp>
2122

2223
__SYCL_INLINE_NAMESPACE(cl) {
2324
namespace sycl {

0 commit comments

Comments
 (0)