Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 0db4626

Browse files
authored
[SYCL][LIT] Add test for radix sorter (#1435)
1 parent a60d6b7 commit 0db4626

File tree

1 file changed

+163
-60
lines changed

1 file changed

+163
-60
lines changed

SYCL/GroupAlgorithm/SYCL2020/sort.cpp

Lines changed: 163 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// For comparators {std::less, std::greater}
1111
// For dimensions {1, 2}
1212
// For group {work-group, sub-group}
13+
// For sorters {default_sorter, radix_sorter}
1314
// joint_sort with
1415
// WG size = {16} or {1, 16}
1516
// SG size = {8}
@@ -55,6 +56,29 @@ class CustomType {
5556
size_t MVal = 0;
5657
};
5758

59+
template <class CompT, class T> struct RadixSorterType;
60+
61+
template <class T> struct RadixSorterType<std::greater<T>, T> {
62+
using Type =
63+
oneapi_exp::radix_sorter<T, oneapi_exp::sorting_order::descending>;
64+
};
65+
66+
template <class T> struct RadixSorterType<std::less<T>, T> {
67+
using Type =
68+
oneapi_exp::radix_sorter<T, oneapi_exp::sorting_order::ascending>;
69+
};
70+
71+
// Dummy specializations for CustomType, which is not supported by the radix sorter
72+
template <> struct RadixSorterType<std::less<CustomType>, CustomType> {
73+
using Type =
74+
oneapi_exp::radix_sorter<int, oneapi_exp::sorting_order::ascending>;
75+
};
76+
77+
template <> struct RadixSorterType<std::greater<CustomType>, CustomType> {
78+
using Type =
79+
oneapi_exp::radix_sorter<int, oneapi_exp::sorting_order::descending>;
80+
};
81+
5882
constexpr size_t ReqSubGroupSize = 8;
5983

6084
template <UseGroupT UseGroup, int Dims, class T, class Compare>
@@ -68,17 +92,25 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
6892

6993
constexpr size_t NumSubGroups = WGSize / ReqSubGroupSize;
7094

71-
std::size_t LocalMemorySize = 0;
72-
if (UseGroup == UseGroupT::SubGroup)
95+
using RadixSorterT = typename RadixSorterType<Compare, T>::Type;
96+
97+
std::size_t LocalMemorySizeDefault = 0;
98+
std::size_t LocalMemorySizeRadix = 0;
99+
if (UseGroup == UseGroupT::SubGroup) {
73100
// Each sub-group needs a piece of memory for sorting
74-
LocalMemorySize =
101+
LocalMemorySizeDefault =
75102
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
76103
sycl::memory_scope::sub_group, ReqSubGroupSize * ElemsPerWI);
77-
else
104+
LocalMemorySizeRadix = RadixSorterT::memory_required(
105+
sycl::memory_scope::sub_group, ReqSubGroupSize * ElemsPerWI);
106+
} else {
78107
// A single chunk of memory for each work-group
79-
LocalMemorySize =
108+
LocalMemorySizeDefault =
80109
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
81110
sycl::memory_scope::work_group, WGSize * ElemsPerWI);
111+
LocalMemorySizeRadix = RadixSorterT::memory_required(
112+
sycl::memory_scope::sub_group, WGSize * ElemsPerWI);
113+
}
82114

83115
const sycl::nd_range<Dims> NDRange = [&]() {
84116
if constexpr (Dims == 1)
@@ -92,23 +124,36 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
92124
std::vector<T> DataToSortCase0 = DataToSort;
93125
std::vector<T> DataToSortCase1 = DataToSort;
94126
std::vector<T> DataToSortCase2 = DataToSort;
127+
std::vector<T> DataToSortCase3 = DataToSort;
95128

96129
// Sort data using 4 different versions of joint_sort API (the last one uses radix_sorter and is skipped for CustomType)
97130
{
98131
sycl::buffer<T> BufToSort0(DataToSortCase0.data(), DataToSortCase0.size());
99132
sycl::buffer<T> BufToSort1(DataToSortCase1.data(), DataToSortCase1.size());
100133
sycl::buffer<T> BufToSort2(DataToSortCase2.data(), DataToSortCase2.size());
134+
sycl::buffer<T> BufToSort3(DataToSortCase3.data(), DataToSortCase3.size());
101135

102136
Q.submit([&](sycl::handler &CGH) {
103137
auto AccToSort0 = sycl::accessor(BufToSort0, CGH);
104138
auto AccToSort1 = sycl::accessor(BufToSort1, CGH);
105139
auto AccToSort2 = sycl::accessor(BufToSort2, CGH);
140+
auto AccToSort3 = sycl::accessor(BufToSort3, CGH);
106141

107142
// Allocate local memory for all sub-groups in a work-group
108-
const size_t TotalLocalMemSize = UseGroup == UseGroupT::SubGroup
109-
? LocalMemorySize * NumSubGroups
110-
: LocalMemorySize;
111-
sycl::local_accessor<std::byte, 1> Scratch({TotalLocalMemSize}, CGH);
143+
const size_t TotalLocalMemSizeDefault =
144+
UseGroup == UseGroupT::SubGroup
145+
? LocalMemorySizeDefault * NumSubGroups
146+
: LocalMemorySizeDefault;
147+
148+
const size_t TotalLocalMemSizeRadix =
149+
UseGroup == UseGroupT::SubGroup ? LocalMemorySizeRadix * NumSubGroups
150+
: LocalMemorySizeRadix;
151+
152+
sycl::local_accessor<std::byte, 1> ScratchDefault(
153+
{TotalLocalMemSizeDefault}, CGH);
154+
155+
sycl::local_accessor<std::byte, 1> ScratchRadix({TotalLocalMemSizeRadix},
156+
CGH);
112157

113158
CGH.parallel_for<KernelNameJoint<IntWrapper<Dims>,
114159
UseGroupWrapper<UseGroup>, T, Compare>>(
@@ -130,7 +175,7 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
130175
: WGID;
131176
const size_t LocalPartID =
132177
UseGroup == UseGroupT::SubGroup
133-
? LocalMemorySize * Group.get_group_linear_id()
178+
? LocalMemorySizeDefault * Group.get_group_linear_id()
134179
: 0;
135180

136181
const size_t StartIdx = ChunkSize * PartID;
@@ -141,19 +186,32 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
141186
if constexpr (std::is_same_v<Compare, std::less<T>>)
142187
oneapi_exp::joint_sort(
143188
oneapi_exp::group_with_scratchpad(
144-
Group,
145-
sycl::span{&Scratch[LocalPartID], LocalMemorySize}),
189+
Group, sycl::span{&ScratchDefault[LocalPartID],
190+
LocalMemorySizeDefault}),
146191
&AccToSort0[StartIdx], &AccToSort0[EndIdx]);
147192

148193
oneapi_exp::joint_sort(
149194
oneapi_exp::group_with_scratchpad(
150-
Group, sycl::span{&Scratch[LocalPartID], LocalMemorySize}),
195+
Group, sycl::span{&ScratchDefault[LocalPartID],
196+
LocalMemorySizeDefault}),
151197
&AccToSort1[StartIdx], &AccToSort1[EndIdx], Comp);
152198

153199
oneapi_exp::joint_sort(
154200
Group, &AccToSort2[StartIdx], &AccToSort2[EndIdx],
155-
oneapi_exp::default_sorter<Compare>(
156-
sycl::span{&Scratch[LocalPartID], LocalMemorySize}));
201+
oneapi_exp::default_sorter<Compare>(sycl::span{
202+
&ScratchDefault[LocalPartID], LocalMemorySizeDefault}));
203+
204+
const size_t LocalPartIDRadix =
205+
UseGroup == UseGroupT::SubGroup
206+
? LocalMemorySizeRadix * Group.get_group_linear_id()
207+
: 0;
208+
209+
// Radix doesn't support custom types
210+
if constexpr (!std::is_same_v<CustomType, T>)
211+
oneapi_exp::joint_sort(
212+
Group, &AccToSort3[StartIdx], &AccToSort3[EndIdx],
213+
RadixSorterT(sycl::span{&ScratchRadix[LocalPartIDRadix],
214+
LocalMemorySizeRadix}));
157215
});
158216
}).wait_and_throw();
159217
}
@@ -178,6 +236,9 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
178236

179237
assert(DataToSortCase1 == DataSorted);
180238
assert(DataToSortCase2 == DataSorted);
239+
// Radix doesn't support custom types
240+
if constexpr (!std::is_same_v<CustomType, T>)
241+
assert(DataToSortCase3 == DataSorted);
181242
}
182243
}
183244

@@ -197,77 +258,116 @@ void RunSortOVerGroup(sycl::queue &Q, const std::vector<T> &DataToSort,
197258
"Only one and two dimensional kernels are supported");
198259
}();
199260

200-
std::size_t LocalMemorySize = 0;
201-
if (UseGroup == UseGroupT::SubGroup)
261+
using RadixSorterT = typename RadixSorterType<Compare, T>::Type;
262+
263+
std::size_t LocalMemorySizeDefault = 0;
264+
std::size_t LocalMemorySizeRadix = 0;
265+
if (UseGroup == UseGroupT::SubGroup) {
202266
// Each sub-group needs a piece of memory for sorting
203-
LocalMemorySize =
267+
LocalMemorySizeDefault =
204268
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
205269
sycl::memory_scope::sub_group, sycl::range<1>{ReqSubGroupSize});
206-
else
270+
271+
LocalMemorySizeRadix = RadixSorterT::template memory_required(
272+
sycl::memory_scope::sub_group, sycl::range<1>{ReqSubGroupSize});
273+
} else {
207274
// A single chunk of memory for each work-group
208-
LocalMemorySize =
275+
LocalMemorySizeDefault =
209276
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
210277
sycl::memory_scope::work_group, sycl::range<1>{NumOfElements});
211278

279+
LocalMemorySizeRadix = RadixSorterT::template memory_required(
280+
sycl::memory_scope::work_group, sycl::range<1>{NumOfElements});
281+
}
282+
212283
std::vector<T> DataToSortCase0 = DataToSort;
213284
std::vector<T> DataToSortCase1 = DataToSort;
214285
std::vector<T> DataToSortCase2 = DataToSort;
286+
std::vector<T> DataToSortCase3 = DataToSort;
215287

216288
// Sort data using 4 different versions of sort_over_group API (the last one uses radix_sorter and is skipped for CustomType)
217289
{
218290
sycl::buffer<T> BufToSort0(DataToSortCase0.data(), DataToSortCase0.size());
219291
sycl::buffer<T> BufToSort1(DataToSortCase1.data(), DataToSortCase1.size());
220292
sycl::buffer<T> BufToSort2(DataToSortCase2.data(), DataToSortCase2.size());
293+
sycl::buffer<T> BufToSort3(DataToSortCase3.data(), DataToSortCase3.size());
221294

222295
Q.submit([&](sycl::handler &CGH) {
223296
auto AccToSort0 = sycl::accessor(BufToSort0, CGH);
224297
auto AccToSort1 = sycl::accessor(BufToSort1, CGH);
225298
auto AccToSort2 = sycl::accessor(BufToSort2, CGH);
299+
auto AccToSort3 = sycl::accessor(BufToSort3, CGH);
226300

227301
// Allocate local memory for all sub-groups in a work-group
228-
const size_t TotalLocalMemSize = UseGroup == UseGroupT::SubGroup
229-
? LocalMemorySize * NumSubGroups
230-
: LocalMemorySize;
231-
sycl::local_accessor<std::byte, 1> Scratch({TotalLocalMemSize}, CGH);
302+
const size_t TotalLocalMemSizeDefault =
303+
UseGroup == UseGroupT::SubGroup
304+
? LocalMemorySizeDefault * NumSubGroups
305+
: LocalMemorySizeDefault;
306+
sycl::local_accessor<std::byte, 1> ScratchDefault(
307+
{TotalLocalMemSizeDefault}, CGH);
308+
309+
const size_t TotalLocalMemSizeRadix =
310+
UseGroup == UseGroupT::SubGroup ? LocalMemorySizeRadix * NumSubGroups
311+
: LocalMemorySizeRadix;
312+
313+
sycl::local_accessor<std::byte, 1> ScratchRadix({TotalLocalMemSizeRadix},
314+
CGH);
232315

233316
CGH.parallel_for<KernelNameOverGroup<
234317
IntWrapper<Dims>, UseGroupWrapper<UseGroup>, T, Compare>>(
235-
NDRange,
236-
[=](sycl::nd_item<Dims> id)
237-
[[intel::reqd_sub_group_size(ReqSubGroupSize)]] {
238-
const size_t GlobalLinearID = id.get_global_linear_id();
239-
240-
auto Group = [&]() {
241-
if constexpr (UseGroup == UseGroupT::SubGroup)
242-
return id.get_sub_group();
243-
else
244-
return id.get_group();
245-
}();
246-
247-
// Each sub-group should use it's own part of the scratch pad
248-
const size_t ScratchShift =
249-
UseGroup == UseGroupT::SubGroup
250-
? id.get_sub_group().get_group_linear_id() *
251-
LocalMemorySize
252-
: 0;
253-
std::byte *ScratchPtr = &Scratch[0] + ScratchShift;
254-
255-
if constexpr (std::is_same_v<Compare, std::less<T>>)
256-
AccToSort0[GlobalLinearID] = oneapi_exp::sort_over_group(
257-
oneapi_exp::group_with_scratchpad(
258-
Group, sycl::span{ScratchPtr, LocalMemorySize}),
259-
AccToSort0[GlobalLinearID]);
260-
261-
AccToSort1[GlobalLinearID] = oneapi_exp::sort_over_group(
262-
oneapi_exp::group_with_scratchpad(
263-
Group, sycl::span{ScratchPtr, LocalMemorySize}),
264-
AccToSort1[GlobalLinearID], Comp);
265-
266-
AccToSort2[GlobalLinearID] = oneapi_exp::sort_over_group(
267-
Group, AccToSort2[GlobalLinearID],
268-
oneapi_exp::default_sorter<Compare>(
269-
sycl::span{ScratchPtr, LocalMemorySize}));
270-
});
318+
NDRange, [=](sycl::nd_item<Dims> id) [[intel::reqd_sub_group_size(
319+
ReqSubGroupSize)]] {
320+
const size_t GlobalLinearID = id.get_global_linear_id();
321+
322+
auto Group = [&]() {
323+
if constexpr (UseGroup == UseGroupT::SubGroup)
324+
return id.get_sub_group();
325+
else
326+
return id.get_group();
327+
}();
328+
329+
// Each sub-group should use its own part of the scratch pad
330+
const size_t ScratchShiftDefault =
331+
UseGroup == UseGroupT::SubGroup
332+
? id.get_sub_group().get_group_linear_id() *
333+
LocalMemorySizeDefault
334+
: 0;
335+
std::byte *ScratchPtrDefault =
336+
&ScratchDefault[0] + ScratchShiftDefault;
337+
338+
if constexpr (std::is_same_v<Compare, std::less<T>>)
339+
AccToSort0[GlobalLinearID] = oneapi_exp::sort_over_group(
340+
oneapi_exp::group_with_scratchpad(
341+
Group,
342+
sycl::span{ScratchPtrDefault, LocalMemorySizeDefault}),
343+
AccToSort0[GlobalLinearID]);
344+
345+
AccToSort1[GlobalLinearID] = oneapi_exp::sort_over_group(
346+
oneapi_exp::group_with_scratchpad(
347+
Group,
348+
sycl::span{ScratchPtrDefault, LocalMemorySizeDefault}),
349+
AccToSort1[GlobalLinearID], Comp);
350+
351+
AccToSort2[GlobalLinearID] = oneapi_exp::sort_over_group(
352+
Group, AccToSort2[GlobalLinearID],
353+
oneapi_exp::default_sorter<Compare>(
354+
sycl::span{ScratchPtrDefault, LocalMemorySizeDefault}));
355+
356+
// Each sub-group should use its own part of the scratch pad
357+
const size_t ScratchShiftRadix =
358+
UseGroup == UseGroupT::SubGroup
359+
? id.get_sub_group().get_group_linear_id() *
360+
LocalMemorySizeRadix
361+
: 0;
362+
std::byte *ScratchPtrRadix = &ScratchRadix[0] + ScratchShiftRadix;
363+
364+
// Radix doesn't support custom types
365+
if constexpr (!std::is_same_v<CustomType, T>)
366+
AccToSort3[GlobalLinearID] = oneapi_exp::sort_over_group(
367+
Group, AccToSort3[GlobalLinearID],
368+
RadixSorterT(
369+
sycl::span{ScratchPtrRadix, LocalMemorySizeRadix}));
370+
});
271371
}).wait_and_throw();
272372
}
273373

@@ -290,6 +390,9 @@ void RunSortOVerGroup(sycl::queue &Q, const std::vector<T> &DataToSort,
290390

291391
assert(DataToSortCase1 == DataSorted);
292392
assert(DataToSortCase2 == DataSorted);
393+
// Radix doesn't support custom types
394+
if constexpr (!std::is_same_v<CustomType, T>)
395+
assert(DataToSortCase3 == DataSorted);
293396
}
294397
}
295398

0 commit comments

Comments
 (0)