10
10
// For comparators {std::less, std::greater}
11
11
// For dimensions {1, 2}
12
12
// For group {work-group, sub-group}
13
+ // For sorters {default_sorter, radix_sorter}
13
14
// joint_sort with
14
15
// WG size = {16} or {1, 16}
15
16
// SG size = {8}
@@ -55,6 +56,29 @@ class CustomType {
55
56
size_t MVal = 0 ;
56
57
};
57
58
59
+ template <class CompT , class T > struct RadixSorterType ;
60
+
61
+ template <class T > struct RadixSorterType <std::greater<T>, T> {
62
+ using Type =
63
+ oneapi_exp::radix_sorter<T, oneapi_exp::sorting_order::descending>;
64
+ };
65
+
66
+ template <class T > struct RadixSorterType <std::less<T>, T> {
67
+ using Type =
68
+ oneapi_exp::radix_sorter<T, oneapi_exp::sorting_order::ascending>;
69
+ };
70
+
71
+ // Dummy overloads for CustomType which is not supported by radix sorter
72
+ template <> struct RadixSorterType <std::less<CustomType>, CustomType> {
73
+ using Type =
74
+ oneapi_exp::radix_sorter<int , oneapi_exp::sorting_order::ascending>;
75
+ };
76
+
77
+ template <> struct RadixSorterType <std::greater<CustomType>, CustomType> {
78
+ using Type =
79
+ oneapi_exp::radix_sorter<int , oneapi_exp::sorting_order::descending>;
80
+ };
81
+
58
82
constexpr size_t ReqSubGroupSize = 8 ;
59
83
60
84
template <UseGroupT UseGroup, int Dims, class T , class Compare >
@@ -68,17 +92,25 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
68
92
69
93
constexpr size_t NumSubGroups = WGSize / ReqSubGroupSize;
70
94
71
- std::size_t LocalMemorySize = 0 ;
72
- if (UseGroup == UseGroupT::SubGroup)
95
+ using RadixSorterT = typename RadixSorterType<Compare, T>::Type;
96
+
97
+ std::size_t LocalMemorySizeDefault = 0 ;
98
+ std::size_t LocalMemorySizeRadix = 0 ;
99
+ if (UseGroup == UseGroupT::SubGroup) {
73
100
// Each sub-group needs a piece of memory for sorting
74
- LocalMemorySize =
101
+ LocalMemorySizeDefault =
75
102
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
76
103
sycl::memory_scope::sub_group, ReqSubGroupSize * ElemsPerWI);
77
- else
104
+ LocalMemorySizeRadix = RadixSorterT::memory_required (
105
+ sycl::memory_scope::sub_group, ReqSubGroupSize * ElemsPerWI);
106
+ } else {
78
107
// A single chunk of memory for each work-group
79
- LocalMemorySize =
108
+ LocalMemorySizeDefault =
80
109
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
81
110
sycl::memory_scope::work_group, WGSize * ElemsPerWI);
111
+ LocalMemorySizeRadix = RadixSorterT::memory_required (
112
+ sycl::memory_scope::sub_group, WGSize * ElemsPerWI);
113
+ }
82
114
83
115
const sycl::nd_range<Dims> NDRange = [&]() {
84
116
if constexpr (Dims == 1 )
@@ -92,23 +124,36 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
92
124
std::vector<T> DataToSortCase0 = DataToSort;
93
125
std::vector<T> DataToSortCase1 = DataToSort;
94
126
std::vector<T> DataToSortCase2 = DataToSort;
127
+ std::vector<T> DataToSortCase3 = DataToSort;
95
128
96
129
// Sort data using 3 different versions of joint_sort API
97
130
{
98
131
sycl::buffer<T> BufToSort0 (DataToSortCase0.data (), DataToSortCase0.size ());
99
132
sycl::buffer<T> BufToSort1 (DataToSortCase1.data (), DataToSortCase1.size ());
100
133
sycl::buffer<T> BufToSort2 (DataToSortCase2.data (), DataToSortCase2.size ());
134
+ sycl::buffer<T> BufToSort3 (DataToSortCase3.data (), DataToSortCase3.size ());
101
135
102
136
Q.submit ([&](sycl::handler &CGH) {
103
137
auto AccToSort0 = sycl::accessor (BufToSort0, CGH);
104
138
auto AccToSort1 = sycl::accessor (BufToSort1, CGH);
105
139
auto AccToSort2 = sycl::accessor (BufToSort2, CGH);
140
+ auto AccToSort3 = sycl::accessor (BufToSort3, CGH);
106
141
107
142
// Allocate local memory for all sub-groups in a work-group
108
- const size_t TotalLocalMemSize = UseGroup == UseGroupT::SubGroup
109
- ? LocalMemorySize * NumSubGroups
110
- : LocalMemorySize;
111
- sycl::local_accessor<std::byte, 1 > Scratch ({TotalLocalMemSize}, CGH);
143
+ const size_t TotalLocalMemSizeDefault =
144
+ UseGroup == UseGroupT::SubGroup
145
+ ? LocalMemorySizeDefault * NumSubGroups
146
+ : LocalMemorySizeDefault;
147
+
148
+ const size_t TotalLocalMemSizeRadix =
149
+ UseGroup == UseGroupT::SubGroup ? LocalMemorySizeRadix * NumSubGroups
150
+ : LocalMemorySizeRadix;
151
+
152
+ sycl::local_accessor<std::byte, 1 > ScratchDefault (
153
+ {TotalLocalMemSizeDefault}, CGH);
154
+
155
+ sycl::local_accessor<std::byte, 1 > ScratchRadix ({TotalLocalMemSizeRadix},
156
+ CGH);
112
157
113
158
CGH.parallel_for <KernelNameJoint<IntWrapper<Dims>,
114
159
UseGroupWrapper<UseGroup>, T, Compare>>(
@@ -130,7 +175,7 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
130
175
: WGID;
131
176
const size_t LocalPartID =
132
177
UseGroup == UseGroupT::SubGroup
133
- ? LocalMemorySize * Group.get_group_linear_id ()
178
+ ? LocalMemorySizeDefault * Group.get_group_linear_id ()
134
179
: 0 ;
135
180
136
181
const size_t StartIdx = ChunkSize * PartID;
@@ -141,19 +186,32 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
141
186
if constexpr (std::is_same_v<Compare, std::less<T>>)
142
187
oneapi_exp::joint_sort (
143
188
oneapi_exp::group_with_scratchpad (
144
- Group,
145
- sycl::span{&Scratch[LocalPartID], LocalMemorySize }),
189
+ Group, sycl::span{&ScratchDefault[LocalPartID],
190
+ LocalMemorySizeDefault }),
146
191
&AccToSort0[StartIdx], &AccToSort0[EndIdx]);
147
192
148
193
oneapi_exp::joint_sort (
149
194
oneapi_exp::group_with_scratchpad (
150
- Group, sycl::span{&Scratch[LocalPartID], LocalMemorySize}),
195
+ Group, sycl::span{&ScratchDefault[LocalPartID],
196
+ LocalMemorySizeDefault}),
151
197
&AccToSort1[StartIdx], &AccToSort1[EndIdx], Comp);
152
198
153
199
oneapi_exp::joint_sort (
154
200
Group, &AccToSort2[StartIdx], &AccToSort2[EndIdx],
155
- oneapi_exp::default_sorter<Compare>(
156
- sycl::span{&Scratch[LocalPartID], LocalMemorySize}));
201
+ oneapi_exp::default_sorter<Compare>(sycl::span{
202
+ &ScratchDefault[LocalPartID], LocalMemorySizeDefault}));
203
+
204
+ const size_t LocalPartIDRadix =
205
+ UseGroup == UseGroupT::SubGroup
206
+ ? LocalMemorySizeRadix * Group.get_group_linear_id ()
207
+ : 0 ;
208
+
209
+ // Radix doesn't support custom types
210
+ if constexpr (!std::is_same_v<CustomType, T>)
211
+ oneapi_exp::joint_sort (
212
+ Group, &AccToSort3[StartIdx], &AccToSort3[EndIdx],
213
+ RadixSorterT (sycl::span{&ScratchRadix[LocalPartIDRadix],
214
+ LocalMemorySizeRadix}));
157
215
});
158
216
}).wait_and_throw ();
159
217
}
@@ -178,6 +236,9 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
178
236
179
237
assert (DataToSortCase1 == DataSorted);
180
238
assert (DataToSortCase2 == DataSorted);
239
+ // Radix doesn't support custom types
240
+ if constexpr (!std::is_same_v<CustomType, T>)
241
+ assert (DataToSortCase3 == DataSorted);
181
242
}
182
243
}
183
244
@@ -197,77 +258,116 @@ void RunSortOVerGroup(sycl::queue &Q, const std::vector<T> &DataToSort,
197
258
" Only one and two dimensional kernels are supported" );
198
259
}();
199
260
200
- std::size_t LocalMemorySize = 0 ;
201
- if (UseGroup == UseGroupT::SubGroup)
261
+ using RadixSorterT = typename RadixSorterType<Compare, T>::Type;
262
+
263
+ std::size_t LocalMemorySizeDefault = 0 ;
264
+ std::size_t LocalMemorySizeRadix = 0 ;
265
+ if (UseGroup == UseGroupT::SubGroup) {
202
266
// Each sub-group needs a piece of memory for sorting
203
- LocalMemorySize =
267
+ LocalMemorySizeDefault =
204
268
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
205
269
sycl::memory_scope::sub_group, sycl::range<1 >{ReqSubGroupSize});
206
- else
270
+
271
+ LocalMemorySizeRadix = RadixSorterT::template memory_required (
272
+ sycl::memory_scope::sub_group, sycl::range<1 >{ReqSubGroupSize});
273
+ } else {
207
274
// A single chunk of memory for each work-group
208
- LocalMemorySize =
275
+ LocalMemorySizeDefault =
209
276
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
210
277
sycl::memory_scope::work_group, sycl::range<1 >{NumOfElements});
211
278
279
+ LocalMemorySizeRadix = RadixSorterT::template memory_required (
280
+ sycl::memory_scope::work_group, sycl::range<1 >{NumOfElements});
281
+ }
282
+
212
283
std::vector<T> DataToSortCase0 = DataToSort;
213
284
std::vector<T> DataToSortCase1 = DataToSort;
214
285
std::vector<T> DataToSortCase2 = DataToSort;
286
+ std::vector<T> DataToSortCase3 = DataToSort;
215
287
216
288
// Sort data using 3 different versions of sort_over_group API
217
289
{
218
290
sycl::buffer<T> BufToSort0 (DataToSortCase0.data (), DataToSortCase0.size ());
219
291
sycl::buffer<T> BufToSort1 (DataToSortCase1.data (), DataToSortCase1.size ());
220
292
sycl::buffer<T> BufToSort2 (DataToSortCase2.data (), DataToSortCase2.size ());
293
+ sycl::buffer<T> BufToSort3 (DataToSortCase3.data (), DataToSortCase3.size ());
221
294
222
295
Q.submit ([&](sycl::handler &CGH) {
223
296
auto AccToSort0 = sycl::accessor (BufToSort0, CGH);
224
297
auto AccToSort1 = sycl::accessor (BufToSort1, CGH);
225
298
auto AccToSort2 = sycl::accessor (BufToSort2, CGH);
299
+ auto AccToSort3 = sycl::accessor (BufToSort3, CGH);
226
300
227
301
// Allocate local memory for all sub-groups in a work-group
228
- const size_t TotalLocalMemSize = UseGroup == UseGroupT::SubGroup
229
- ? LocalMemorySize * NumSubGroups
230
- : LocalMemorySize;
231
- sycl::local_accessor<std::byte, 1 > Scratch ({TotalLocalMemSize}, CGH);
302
+ const size_t TotalLocalMemSizeDefault =
303
+ UseGroup == UseGroupT::SubGroup
304
+ ? LocalMemorySizeDefault * NumSubGroups
305
+ : LocalMemorySizeDefault;
306
+ sycl::local_accessor<std::byte, 1 > ScratchDefault (
307
+ {TotalLocalMemSizeDefault}, CGH);
308
+
309
+ const size_t TotalLocalMemSizeRadix =
310
+ UseGroup == UseGroupT::SubGroup ? LocalMemorySizeRadix * NumSubGroups
311
+ : LocalMemorySizeRadix;
312
+
313
+ sycl::local_accessor<std::byte, 1 > ScratchRadix ({TotalLocalMemSizeRadix},
314
+ CGH);
232
315
233
316
CGH.parallel_for <KernelNameOverGroup<
234
317
IntWrapper<Dims>, UseGroupWrapper<UseGroup>, T, Compare>>(
235
- NDRange,
236
- [=](sycl::nd_item<Dims> id)
237
- [[intel::reqd_sub_group_size (ReqSubGroupSize)]] {
238
- const size_t GlobalLinearID = id.get_global_linear_id ();
239
-
240
- auto Group = [&]() {
241
- if constexpr (UseGroup == UseGroupT::SubGroup)
242
- return id.get_sub_group ();
243
- else
244
- return id.get_group ();
245
- }();
246
-
247
- // Each sub-group should use it's own part of the scratch pad
248
- const size_t ScratchShift =
249
- UseGroup == UseGroupT::SubGroup
250
- ? id.get_sub_group ().get_group_linear_id () *
251
- LocalMemorySize
252
- : 0 ;
253
- std::byte *ScratchPtr = &Scratch[0 ] + ScratchShift;
254
-
255
- if constexpr (std::is_same_v<Compare, std::less<T>>)
256
- AccToSort0[GlobalLinearID] = oneapi_exp::sort_over_group (
257
- oneapi_exp::group_with_scratchpad (
258
- Group, sycl::span{ScratchPtr, LocalMemorySize}),
259
- AccToSort0[GlobalLinearID]);
260
-
261
- AccToSort1[GlobalLinearID] = oneapi_exp::sort_over_group (
262
- oneapi_exp::group_with_scratchpad (
263
- Group, sycl::span{ScratchPtr, LocalMemorySize}),
264
- AccToSort1[GlobalLinearID], Comp);
265
-
266
- AccToSort2[GlobalLinearID] = oneapi_exp::sort_over_group (
267
- Group, AccToSort2[GlobalLinearID],
268
- oneapi_exp::default_sorter<Compare>(
269
- sycl::span{ScratchPtr, LocalMemorySize}));
270
- });
318
+ NDRange, [=](sycl::nd_item<Dims> id) [[intel::reqd_sub_group_size (
319
+ ReqSubGroupSize)]] {
320
+ const size_t GlobalLinearID = id.get_global_linear_id ();
321
+
322
+ auto Group = [&]() {
323
+ if constexpr (UseGroup == UseGroupT::SubGroup)
324
+ return id.get_sub_group ();
325
+ else
326
+ return id.get_group ();
327
+ }();
328
+
329
+ // Each sub-group should use it's own part of the scratch pad
330
+ const size_t ScratchShiftDefault =
331
+ UseGroup == UseGroupT::SubGroup
332
+ ? id.get_sub_group ().get_group_linear_id () *
333
+ LocalMemorySizeDefault
334
+ : 0 ;
335
+ std::byte *ScratchPtrDefault =
336
+ &ScratchDefault[0 ] + ScratchShiftDefault;
337
+
338
+ if constexpr (std::is_same_v<Compare, std::less<T>>)
339
+ AccToSort0[GlobalLinearID] = oneapi_exp::sort_over_group (
340
+ oneapi_exp::group_with_scratchpad (
341
+ Group,
342
+ sycl::span{ScratchPtrDefault, LocalMemorySizeDefault}),
343
+ AccToSort0[GlobalLinearID]);
344
+
345
+ AccToSort1[GlobalLinearID] = oneapi_exp::sort_over_group (
346
+ oneapi_exp::group_with_scratchpad (
347
+ Group,
348
+ sycl::span{ScratchPtrDefault, LocalMemorySizeDefault}),
349
+ AccToSort1[GlobalLinearID], Comp);
350
+
351
+ AccToSort2[GlobalLinearID] = oneapi_exp::sort_over_group (
352
+ Group, AccToSort2[GlobalLinearID],
353
+ oneapi_exp::default_sorter<Compare>(
354
+ sycl::span{ScratchPtrDefault, LocalMemorySizeDefault}));
355
+
356
+ // Each sub-group should use it's own part of the scratch pad
357
+ const size_t ScratchShiftRadix =
358
+ UseGroup == UseGroupT::SubGroup
359
+ ? id.get_sub_group ().get_group_linear_id () *
360
+ LocalMemorySizeRadix
361
+ : 0 ;
362
+ std::byte *ScratchPtrRadix = &ScratchRadix[0 ] + ScratchShiftRadix;
363
+
364
+ // Radix doesn't support custom types
365
+ if constexpr (!std::is_same_v<CustomType, T>)
366
+ AccToSort3[GlobalLinearID] = oneapi_exp::sort_over_group (
367
+ Group, AccToSort3[GlobalLinearID],
368
+ RadixSorterT (
369
+ sycl::span{ScratchPtrRadix, LocalMemorySizeRadix}));
370
+ });
271
371
}).wait_and_throw ();
272
372
}
273
373
@@ -290,6 +390,9 @@ void RunSortOVerGroup(sycl::queue &Q, const std::vector<T> &DataToSort,
290
390
291
391
assert (DataToSortCase1 == DataSorted);
292
392
assert (DataToSortCase2 == DataSorted);
393
+ // Radix doesn't support custom types
394
+ if constexpr (!std::is_same_v<CustomType, T>)
395
+ assert (DataToSortCase3 == DataSorted);
293
396
}
294
397
}
295
398
0 commit comments