Skip to content

Commit 673ce2c

Browse files
authored
[SYCL] Cleanup SYCL sorters implementation (#13962)
1 parent 1fa2ac8 commit 673ce2c

File tree

3 files changed

+45
-60
lines changed

3 files changed

+45
-60
lines changed

sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp

Lines changed: 38 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -51,46 +51,43 @@ template <typename Group, size_t Extent> class group_with_scratchpad {
5151
// ---- sorters
5252
template <typename Compare = std::less<>> class default_sorter {
5353
Compare comp;
54-
std::byte *scratch;
55-
size_t scratch_size;
54+
sycl::span<std::byte> scratch;
5655

5756
public:
5857
template <size_t Extent>
5958
default_sorter(sycl::span<std::byte, Extent> scratch_,
6059
Compare comp_ = Compare())
61-
: comp(comp_), scratch(scratch_.data()), scratch_size(scratch_.size()) {}
60+
: comp(comp_), scratch(scratch_) {}
6261

6362
template <typename Group, typename Ptr>
64-
void operator()(Group g, Ptr first, Ptr last) {
63+
void operator()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
64+
[[maybe_unused]] Ptr last) {
6565
#ifdef __SYCL_DEVICE_ONLY__
66-
using T = typename sycl::detail::GetValueType<Ptr>::type;
67-
if (scratch_size >= memory_required<T>(Group::fence_scope, last - first))
68-
sycl::detail::merge_sort(g, first, last - first, comp, scratch);
69-
// TODO: it's better to add else branch
66+
// Per extension specification if scratch size is less than the value
67+
// returned by memory_required then behavior is undefined, so we don't check
68+
// that the scratch size satisfies the requirement.
69+
sycl::detail::merge_sort(g, first, last - first, comp, scratch.data());
7070
#else
71-
(void)g;
72-
(void)first;
73-
(void)last;
7471
throw sycl::exception(
7572
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
7673
"default_sorter constructor is not supported on host device.");
7774
#endif
7875
}
7976

80-
template <typename Group, typename T> T operator()(Group g, T val) {
77+
template <typename Group, typename T>
78+
T operator()([[maybe_unused]] Group g, T val) {
8179
#ifdef __SYCL_DEVICE_ONLY__
80+
// Per extension specification if scratch size is less than the value
81+
// returned by memory_required then behavior is undefined, so we don't check
82+
// that the scratch size satisfies the requirement.
8283
auto range_size = g.get_local_range().size();
83-
if (scratch_size >= memory_required<T>(Group::fence_scope, range_size)) {
84-
size_t local_id = g.get_local_linear_id();
85-
T *temp = reinterpret_cast<T *>(scratch);
86-
::new (temp + local_id) T(val);
87-
sycl::detail::merge_sort(g, temp, range_size, comp,
88-
scratch + range_size * sizeof(T));
89-
val = temp[local_id];
90-
}
91-
// TODO: it's better to add else branch
84+
size_t local_id = g.get_local_linear_id();
85+
T *temp = reinterpret_cast<T *>(scratch.data());
86+
::new (temp + local_id) T(val);
87+
sycl::detail::merge_sort(g, temp, range_size, comp,
88+
scratch.data() + range_size * sizeof(T));
89+
val = temp[local_id];
9290
#else
93-
(void)g;
9491
throw sycl::exception(
9592
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
9693
"default_sorter operator() is not supported on host device.");
@@ -129,62 +126,56 @@ template <typename ValT, sorting_order OrderT = sorting_order::ascending,
129126
unsigned int BitsPerPass = 4>
130127
class radix_sorter {
131128

132-
std::byte *scratch = nullptr;
129+
sycl::span<std::byte> scratch;
133130
uint32_t first_bit = 0;
134131
uint32_t last_bit = 0;
135-
size_t scratch_size = 0;
136132

137133
static constexpr uint32_t bits = BitsPerPass;
134+
using bitset_t = std::bitset<sizeof(ValT) * CHAR_BIT>;
138135

139136
public:
140137
template <size_t Extent>
141138
radix_sorter(sycl::span<std::byte, Extent> scratch_,
142-
const std::bitset<sizeof(ValT) *CHAR_BIT> mask =
143-
std::bitset<sizeof(ValT) * CHAR_BIT>(
144-
(std::numeric_limits<unsigned long long>::max)()))
145-
: scratch(scratch_.data()), scratch_size(scratch_.size()) {
139+
const bitset_t mask = bitset_t{}.set())
140+
: scratch(scratch_) {
146141
static_assert((std::is_arithmetic<ValT>::value ||
147142
std::is_same<ValT, sycl::half>::value ||
148143
std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
149144
"radix sort is not usable");
150145

151-
first_bit = 0;
152-
while (first_bit < mask.size() && !mask[first_bit])
153-
++first_bit;
154-
155-
last_bit = first_bit;
156-
while (last_bit < mask.size() && mask[last_bit])
157-
++last_bit;
146+
for (first_bit = 0; first_bit < mask.size() && !mask[first_bit];
147+
++first_bit)
148+
;
149+
for (last_bit = first_bit; last_bit < mask.size() && mask[last_bit];
150+
++last_bit)
151+
;
158152
}
159153

160154
template <typename GroupT, typename PtrT>
161-
void operator()(GroupT g, PtrT first, PtrT last) {
162-
(void)g;
163-
(void)first;
164-
(void)last;
155+
void operator()([[maybe_unused]] GroupT g, [[maybe_unused]] PtrT first,
156+
[[maybe_unused]] PtrT last) {
165157
#ifdef __SYCL_DEVICE_ONLY__
166158
sycl::detail::privateDynamicSort</*is_key_value=*/false,
167159
OrderT == sorting_order::ascending,
168160
/*empty*/ 1, BitsPerPass>(
169-
g, first, /*empty*/ first, (last - first) > 0 ? (last - first) : 0,
170-
scratch, first_bit, last_bit);
161+
g, first, /*empty*/ first, last - first, scratch.data(), first_bit,
162+
last_bit);
171163
#else
172164
throw sycl::exception(
173165
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
174166
"radix_sorter is not supported on host device.");
175167
#endif
176168
}
177169

178-
template <typename GroupT> ValT operator()(GroupT g, ValT val) {
179-
(void)g;
180-
(void)val;
170+
template <typename GroupT>
171+
ValT operator()([[maybe_unused]] GroupT g, [[maybe_unused]] ValT val) {
181172
#ifdef __SYCL_DEVICE_ONLY__
182173
ValT result[]{val};
183174
sycl::detail::privateStaticSort</*is_key_value=*/false,
184175
/*is_blocked=*/true,
185176
OrderT == sorting_order::ascending,
186177
/*items_per_work_item=*/1, bits>(
187-
g, result, /*empty*/ result, scratch, first_bit, last_bit);
178+
g, result, /*empty*/ result, scratch.data(), first_bit, last_bit);
188179
return result[0];
189180
#else
190181
throw sycl::exception(
@@ -193,20 +184,16 @@ class radix_sorter {
193184
#endif
194185
}
195186

196-
static constexpr size_t memory_required(sycl::memory_scope scope,
187+
static constexpr size_t memory_required(sycl::memory_scope,
197188
size_t range_size) {
198-
// Scope is not important so far
199-
(void)scope;
200189
return range_size * sizeof(ValT) +
201190
(1 << bits) * range_size * sizeof(uint32_t) + alignof(uint32_t);
202191
}
203192

204193
// memory_helpers
205194
template <int dimensions = 1>
206-
static constexpr size_t memory_required(sycl::memory_scope scope,
195+
static constexpr size_t memory_required(sycl::memory_scope,
207196
sycl::range<dimensions> local_range) {
208-
// Scope is not important so far
209-
(void)scope;
210197
return (std::max)(local_range.size() * sizeof(ValT),
211198
local_range.size() * (1 << bits) * sizeof(uint32_t));
212199
}

sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,11 @@ struct is_sorter : decltype(is_sorter_impl<Sorter, Group, ValOrPtr>::test(0)) {
7373
// ---- sort_over_group
7474
template <typename Group, typename T, typename Sorter>
7575
std::enable_if_t<detail::is_sorter<Sorter, Group, T>::value, T>
76-
sort_over_group(Group group, T value, Sorter sorter) {
76+
sort_over_group([[maybe_unused]] Group group, [[maybe_unused]] T value,
77+
[[maybe_unused]] Sorter sorter) {
7778
#ifdef __SYCL_DEVICE_ONLY__
7879
return sorter(group, value);
7980
#else
80-
(void)group;
81-
(void)value;
82-
(void)sorter;
8381
throw sycl::exception(
8482
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
8583
"Group algorithms are not supported on host device.");
@@ -106,14 +104,11 @@ sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
106104
// ---- joint_sort
107105
template <typename Group, typename Iter, typename Sorter>
108106
std::enable_if_t<detail::is_sorter<Sorter, Group, Iter>::value, void>
109-
joint_sort(Group group, Iter first, Iter last, Sorter sorter) {
107+
joint_sort([[maybe_unused]] Group group, [[maybe_unused]] Iter first,
108+
[[maybe_unused]] Iter last, [[maybe_unused]] Sorter sorter) {
110109
#ifdef __SYCL_DEVICE_ONLY__
111110
sorter(group, first, last);
112111
#else
113-
(void)group;
114-
(void)first;
115-
(void)last;
116-
(void)sorter;
117112
throw sycl::exception(
118113
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
119114
"Group algorithms are not supported on host device.");

sycl/test-e2e/GroupAlgorithm/SYCL2020/sort.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
187187
const size_t EndIdx =
188188
std::min(ChunkSize * (PartID + 1), NumOfElements);
189189

190+
if (EndIdx <= StartIdx)
191+
return;
192+
190193
// This version of API always sorts in ascending order
191194
if constexpr (std::is_same_v<Compare, std::less<T>>)
192195
oneapi_exp::joint_sort(

0 commit comments

Comments
 (0)