@@ -51,46 +51,43 @@ template <typename Group, size_t Extent> class group_with_scratchpad {
51
51
// ---- sorters
52
52
template <typename Compare = std::less<>> class default_sorter {
53
53
Compare comp;
54
- std::byte *scratch;
55
- size_t scratch_size;
54
+ sycl::span<std::byte> scratch;
56
55
57
56
public:
58
57
template <size_t Extent>
59
58
default_sorter (sycl::span<std::byte, Extent> scratch_,
60
59
Compare comp_ = Compare())
61
- : comp(comp_), scratch(scratch_.data()), scratch_size(scratch_.size() ) {}
60
+ : comp(comp_), scratch(scratch_) {}
62
61
63
62
template <typename Group, typename Ptr>
64
- void operator ()(Group g, Ptr first, Ptr last) {
63
+ void operator ()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
64
+ [[maybe_unused]] Ptr last) {
65
65
#ifdef __SYCL_DEVICE_ONLY__
66
- using T = typename sycl::detail::GetValueType<Ptr>::type;
67
- if (scratch_size >= memory_required<T>(Group::fence_scope, last - first))
68
- sycl::detail::merge_sort (g, first, last - first, comp, scratch);
69
- // TODO: it's better to add else branch
66
+ // Per extension specification if scratch size is less than the value
67
+ // returned by memory_required then behavior is undefined, so we don't check
68
+ // that the scratch size satisfies the requirement.
69
+ sycl::detail::merge_sort (g, first, last - first, comp, scratch. data ());
70
70
#else
71
- (void )g;
72
- (void )first;
73
- (void )last;
74
71
throw sycl::exception (
75
72
std::error_code (PI_ERROR_INVALID_DEVICE, sycl::sycl_category ()),
76
73
" default_sorter constructor is not supported on host device." );
77
74
#endif
78
75
}
79
76
80
- template <typename Group, typename T> T operator ()(Group g, T val) {
77
+ template <typename Group, typename T>
78
+ T operator ()([[maybe_unused]] Group g, T val) {
81
79
#ifdef __SYCL_DEVICE_ONLY__
80
+ // Per extension specification if scratch size is less than the value
81
+ // returned by memory_required then behavior is undefined, so we don't check
82
+ // that the scratch size satisfies the requirement.
82
83
auto range_size = g.get_local_range ().size ();
83
- if (scratch_size >= memory_required<T>(Group::fence_scope, range_size)) {
84
- size_t local_id = g.get_local_linear_id ();
85
- T *temp = reinterpret_cast <T *>(scratch);
86
- ::new (temp + local_id) T (val);
87
- sycl::detail::merge_sort (g, temp, range_size, comp,
88
- scratch + range_size * sizeof (T));
89
- val = temp[local_id];
90
- }
91
- // TODO: it's better to add else branch
84
+ size_t local_id = g.get_local_linear_id ();
85
+ T *temp = reinterpret_cast <T *>(scratch.data ());
86
+ ::new (temp + local_id) T (val);
87
+ sycl::detail::merge_sort (g, temp, range_size, comp,
88
+ scratch.data () + range_size * sizeof (T));
89
+ val = temp[local_id];
92
90
#else
93
- (void )g;
94
91
throw sycl::exception (
95
92
std::error_code (PI_ERROR_INVALID_DEVICE, sycl::sycl_category ()),
96
93
" default_sorter operator() is not supported on host device." );
@@ -129,62 +126,56 @@ template <typename ValT, sorting_order OrderT = sorting_order::ascending,
129
126
unsigned int BitsPerPass = 4 >
130
127
class radix_sorter {
131
128
132
- std::byte * scratch = nullptr ;
129
+ sycl::span< std::byte> scratch;
133
130
uint32_t first_bit = 0 ;
134
131
uint32_t last_bit = 0 ;
135
- size_t scratch_size = 0 ;
136
132
137
133
static constexpr uint32_t bits = BitsPerPass;
134
+ using bitset_t = std::bitset<sizeof (ValT) * CHAR_BIT>;
138
135
139
136
public:
140
137
template <size_t Extent>
141
138
radix_sorter (sycl::span<std::byte, Extent> scratch_,
142
- const std::bitset<sizeof (ValT) *CHAR_BIT> mask =
143
- std::bitset<sizeof (ValT) * CHAR_BIT>(
144
- (std::numeric_limits<unsigned long long >::max)()))
145
- : scratch(scratch_.data()), scratch_size(scratch_.size()) {
139
+ const bitset_t mask = bitset_t {}.set())
140
+ : scratch(scratch_) {
146
141
static_assert ((std::is_arithmetic<ValT>::value ||
147
142
std::is_same<ValT, sycl::half>::value ||
148
143
std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
149
144
" radix sort is not usable" );
150
145
151
- first_bit = 0 ;
152
- while (first_bit < mask.size () && !mask[first_bit])
153
- ++first_bit;
154
-
155
- last_bit = first_bit;
156
- while (last_bit < mask.size () && mask[last_bit])
157
- ++last_bit;
146
+ for (first_bit = 0 ; first_bit < mask.size () && !mask[first_bit];
147
+ ++first_bit)
148
+ ;
149
+ for (last_bit = first_bit; last_bit < mask.size () && mask[last_bit];
150
+ ++last_bit)
151
+ ;
158
152
}
159
153
160
154
template <typename GroupT, typename PtrT>
161
- void operator ()(GroupT g, PtrT first, PtrT last) {
162
- (void )g;
163
- (void )first;
164
- (void )last;
155
+ void operator ()([[maybe_unused]] GroupT g, [[maybe_unused]] PtrT first,
156
+ [[maybe_unused]] PtrT last) {
165
157
#ifdef __SYCL_DEVICE_ONLY__
166
158
sycl::detail::privateDynamicSort</* is_key_value=*/ false ,
167
159
OrderT == sorting_order::ascending,
168
160
/* empty*/ 1 , BitsPerPass>(
169
- g, first, /* empty*/ first, ( last - first) > 0 ? (last - first) : 0 ,
170
- scratch, first_bit, last_bit);
161
+ g, first, /* empty*/ first, last - first, scratch. data (), first_bit ,
162
+ last_bit);
171
163
#else
172
164
throw sycl::exception (
173
165
std::error_code (PI_ERROR_INVALID_DEVICE, sycl::sycl_category ()),
174
166
" radix_sorter is not supported on host device." );
175
167
#endif
176
168
}
177
169
178
- template <typename GroupT> ValT operator ()(GroupT g, ValT val) {
179
- (void )g;
180
- (void )val;
170
+ template <typename GroupT>
171
+ ValT operator ()([[maybe_unused]] GroupT g, [[maybe_unused]] ValT val) {
181
172
#ifdef __SYCL_DEVICE_ONLY__
182
173
ValT result[]{val};
183
174
sycl::detail::privateStaticSort</* is_key_value=*/ false ,
184
175
/* is_blocked=*/ true ,
185
176
OrderT == sorting_order::ascending,
186
177
/* items_per_work_item=*/ 1 , bits>(
187
- g, result, /* empty*/ result, scratch, first_bit, last_bit);
178
+ g, result, /* empty*/ result, scratch. data () , first_bit, last_bit);
188
179
return result[0 ];
189
180
#else
190
181
throw sycl::exception (
@@ -193,20 +184,16 @@ class radix_sorter {
193
184
#endif
194
185
}
195
186
196
- static constexpr size_t memory_required (sycl::memory_scope scope ,
187
+ static constexpr size_t memory_required (sycl::memory_scope,
197
188
size_t range_size) {
198
- // Scope is not important so far
199
- (void )scope;
200
189
return range_size * sizeof (ValT) +
201
190
(1 << bits) * range_size * sizeof (uint32_t ) + alignof (uint32_t );
202
191
}
203
192
204
193
// memory_helpers
205
194
template <int dimensions = 1 >
206
- static constexpr size_t memory_required (sycl::memory_scope scope ,
195
+ static constexpr size_t memory_required (sycl::memory_scope,
207
196
sycl::range<dimensions> local_range) {
208
- // Scope is not important so far
209
- (void )scope;
210
197
return (std::max)(local_range.size () * sizeof (ValT),
211
198
local_range.size () * (1 << bits) * sizeof (uint32_t ));
212
199
}
0 commit comments