Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

[SYCL][LIT] Add test for radix sorter. #1435

Merged
merged 3 commits into from
Dec 14, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 163 additions & 60 deletions SYCL/GroupAlgorithm/SYCL2020/sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// For comparators {std::less, std::greater}
// For dimensions {1, 2}
// For group {work-group, sub-group}
// For sorters {default_sorter, radix_sorter}
// joint_sort with
// WG size = {16} or {1, 16}
// SG size = {8}
Expand Down Expand Up @@ -55,6 +56,29 @@ class CustomType {
size_t MVal = 0;
};

// Maps a comparator type to the matching radix-sorter instantiation:
// std::less -> ascending order, std::greater -> descending order.
// Only the specializations below are provided; using any other comparator
// leaves the primary template undefined and fails at compile time.
template <class CompT, class T> struct RadixSorterType;

// std::greater<T> sorts descending.
template <class T> struct RadixSorterType<std::greater<T>, T> {
  using Type =
      oneapi_exp::radix_sorter<T, oneapi_exp::sorting_order::descending>;
};

// std::less<T> sorts ascending.
template <class T> struct RadixSorterType<std::less<T>, T> {
  using Type =
      oneapi_exp::radix_sorter<T, oneapi_exp::sorting_order::ascending>;
};

// Dummy specializations for CustomType, which is not supported by the radix
// sorter. They substitute int as the key type only so the `Type` alias stays
// well-formed; the radix code paths are skipped at run time for CustomType
// via `if constexpr (!std::is_same_v<CustomType, T>)` guards.
template <> struct RadixSorterType<std::less<CustomType>, CustomType> {
  using Type =
      oneapi_exp::radix_sorter<int, oneapi_exp::sorting_order::ascending>;
};

template <> struct RadixSorterType<std::greater<CustomType>, CustomType> {
  using Type =
      oneapi_exp::radix_sorter<int, oneapi_exp::sorting_order::descending>;
};

constexpr size_t ReqSubGroupSize = 8;

template <UseGroupT UseGroup, int Dims, class T, class Compare>
Expand All @@ -68,17 +92,25 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,

constexpr size_t NumSubGroups = WGSize / ReqSubGroupSize;

std::size_t LocalMemorySize = 0;
if (UseGroup == UseGroupT::SubGroup)
using RadixSorterT = typename RadixSorterType<Compare, T>::Type;

std::size_t LocalMemorySizeDefault = 0;
std::size_t LocalMemorySizeRadix = 0;
if (UseGroup == UseGroupT::SubGroup) {
// Each sub-group needs a piece of memory for sorting
LocalMemorySize =
LocalMemorySizeDefault =
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
sycl::memory_scope::sub_group, ReqSubGroupSize * ElemsPerWI);
else
LocalMemorySizeRadix = RadixSorterT::memory_required(
sycl::memory_scope::sub_group, ReqSubGroupSize * ElemsPerWI);
} else {
// A single chunk of memory for each work-group
LocalMemorySize =
LocalMemorySizeDefault =
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
sycl::memory_scope::work_group, WGSize * ElemsPerWI);
LocalMemorySizeRadix = RadixSorterT::memory_required(
sycl::memory_scope::sub_group, WGSize * ElemsPerWI);
}

const sycl::nd_range<Dims> NDRange = [&]() {
if constexpr (Dims == 1)
Expand All @@ -92,23 +124,36 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
std::vector<T> DataToSortCase0 = DataToSort;
std::vector<T> DataToSortCase1 = DataToSort;
std::vector<T> DataToSortCase2 = DataToSort;
std::vector<T> DataToSortCase3 = DataToSort;

  // Sort data using 4 different versions of joint_sort API
{
sycl::buffer<T> BufToSort0(DataToSortCase0.data(), DataToSortCase0.size());
sycl::buffer<T> BufToSort1(DataToSortCase1.data(), DataToSortCase1.size());
sycl::buffer<T> BufToSort2(DataToSortCase2.data(), DataToSortCase2.size());
sycl::buffer<T> BufToSort3(DataToSortCase3.data(), DataToSortCase3.size());

Q.submit([&](sycl::handler &CGH) {
auto AccToSort0 = sycl::accessor(BufToSort0, CGH);
auto AccToSort1 = sycl::accessor(BufToSort1, CGH);
auto AccToSort2 = sycl::accessor(BufToSort2, CGH);
auto AccToSort3 = sycl::accessor(BufToSort3, CGH);

// Allocate local memory for all sub-groups in a work-group
const size_t TotalLocalMemSize = UseGroup == UseGroupT::SubGroup
? LocalMemorySize * NumSubGroups
: LocalMemorySize;
sycl::local_accessor<std::byte, 1> Scratch({TotalLocalMemSize}, CGH);
const size_t TotalLocalMemSizeDefault =
UseGroup == UseGroupT::SubGroup
? LocalMemorySizeDefault * NumSubGroups
: LocalMemorySizeDefault;

const size_t TotalLocalMemSizeRadix =
UseGroup == UseGroupT::SubGroup ? LocalMemorySizeRadix * NumSubGroups
: LocalMemorySizeRadix;

sycl::local_accessor<std::byte, 1> ScratchDefault(
{TotalLocalMemSizeDefault}, CGH);

sycl::local_accessor<std::byte, 1> ScratchRadix({TotalLocalMemSizeRadix},
CGH);

CGH.parallel_for<KernelNameJoint<IntWrapper<Dims>,
UseGroupWrapper<UseGroup>, T, Compare>>(
Expand All @@ -130,7 +175,7 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
: WGID;
const size_t LocalPartID =
UseGroup == UseGroupT::SubGroup
? LocalMemorySize * Group.get_group_linear_id()
? LocalMemorySizeDefault * Group.get_group_linear_id()
: 0;

const size_t StartIdx = ChunkSize * PartID;
Expand All @@ -141,19 +186,32 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
if constexpr (std::is_same_v<Compare, std::less<T>>)
oneapi_exp::joint_sort(
oneapi_exp::group_with_scratchpad(
Group,
sycl::span{&Scratch[LocalPartID], LocalMemorySize}),
Group, sycl::span{&ScratchDefault[LocalPartID],
LocalMemorySizeDefault}),
&AccToSort0[StartIdx], &AccToSort0[EndIdx]);

oneapi_exp::joint_sort(
oneapi_exp::group_with_scratchpad(
Group, sycl::span{&Scratch[LocalPartID], LocalMemorySize}),
Group, sycl::span{&ScratchDefault[LocalPartID],
LocalMemorySizeDefault}),
&AccToSort1[StartIdx], &AccToSort1[EndIdx], Comp);

oneapi_exp::joint_sort(
Group, &AccToSort2[StartIdx], &AccToSort2[EndIdx],
oneapi_exp::default_sorter<Compare>(
sycl::span{&Scratch[LocalPartID], LocalMemorySize}));
oneapi_exp::default_sorter<Compare>(sycl::span{
&ScratchDefault[LocalPartID], LocalMemorySizeDefault}));

const size_t LocalPartIDRadix =
UseGroup == UseGroupT::SubGroup
? LocalMemorySizeRadix * Group.get_group_linear_id()
: 0;

// Radix doesn't support custom types
if constexpr (!std::is_same_v<CustomType, T>)
oneapi_exp::joint_sort(
Group, &AccToSort3[StartIdx], &AccToSort3[EndIdx],
RadixSorterT(sycl::span{&ScratchRadix[LocalPartIDRadix],
LocalMemorySizeRadix}));
});
}).wait_and_throw();
}
Expand All @@ -178,6 +236,9 @@ void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,

assert(DataToSortCase1 == DataSorted);
assert(DataToSortCase2 == DataSorted);
// Radix doesn't support custom types
if constexpr (!std::is_same_v<CustomType, T>)
assert(DataToSortCase3 == DataSorted);
}
}

Expand All @@ -197,77 +258,116 @@ void RunSortOVerGroup(sycl::queue &Q, const std::vector<T> &DataToSort,
"Only one and two dimensional kernels are supported");
}();

std::size_t LocalMemorySize = 0;
if (UseGroup == UseGroupT::SubGroup)
using RadixSorterT = typename RadixSorterType<Compare, T>::Type;

std::size_t LocalMemorySizeDefault = 0;
std::size_t LocalMemorySizeRadix = 0;
if (UseGroup == UseGroupT::SubGroup) {
// Each sub-group needs a piece of memory for sorting
LocalMemorySize =
LocalMemorySizeDefault =
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
sycl::memory_scope::sub_group, sycl::range<1>{ReqSubGroupSize});
else

LocalMemorySizeRadix = RadixSorterT::template memory_required(
sycl::memory_scope::sub_group, sycl::range<1>{ReqSubGroupSize});
} else {
// A single chunk of memory for each work-group
LocalMemorySize =
LocalMemorySizeDefault =
oneapi_exp::default_sorter<Compare>::template memory_required<T>(
sycl::memory_scope::work_group, sycl::range<1>{NumOfElements});

LocalMemorySizeRadix = RadixSorterT::template memory_required(
sycl::memory_scope::work_group, sycl::range<1>{NumOfElements});
}

std::vector<T> DataToSortCase0 = DataToSort;
std::vector<T> DataToSortCase1 = DataToSort;
std::vector<T> DataToSortCase2 = DataToSort;
std::vector<T> DataToSortCase3 = DataToSort;

  // Sort data using 4 different versions of sort_over_group API
{
sycl::buffer<T> BufToSort0(DataToSortCase0.data(), DataToSortCase0.size());
sycl::buffer<T> BufToSort1(DataToSortCase1.data(), DataToSortCase1.size());
sycl::buffer<T> BufToSort2(DataToSortCase2.data(), DataToSortCase2.size());
sycl::buffer<T> BufToSort3(DataToSortCase3.data(), DataToSortCase3.size());

Q.submit([&](sycl::handler &CGH) {
auto AccToSort0 = sycl::accessor(BufToSort0, CGH);
auto AccToSort1 = sycl::accessor(BufToSort1, CGH);
auto AccToSort2 = sycl::accessor(BufToSort2, CGH);
auto AccToSort3 = sycl::accessor(BufToSort3, CGH);

// Allocate local memory for all sub-groups in a work-group
const size_t TotalLocalMemSize = UseGroup == UseGroupT::SubGroup
? LocalMemorySize * NumSubGroups
: LocalMemorySize;
sycl::local_accessor<std::byte, 1> Scratch({TotalLocalMemSize}, CGH);
const size_t TotalLocalMemSizeDefault =
UseGroup == UseGroupT::SubGroup
? LocalMemorySizeDefault * NumSubGroups
: LocalMemorySizeDefault;
sycl::local_accessor<std::byte, 1> ScratchDefault(
{TotalLocalMemSizeDefault}, CGH);

const size_t TotalLocalMemSizeRadix =
UseGroup == UseGroupT::SubGroup ? LocalMemorySizeRadix * NumSubGroups
: LocalMemorySizeRadix;

sycl::local_accessor<std::byte, 1> ScratchRadix({TotalLocalMemSizeRadix},
CGH);

CGH.parallel_for<KernelNameOverGroup<
IntWrapper<Dims>, UseGroupWrapper<UseGroup>, T, Compare>>(
NDRange,
[=](sycl::nd_item<Dims> id)
[[intel::reqd_sub_group_size(ReqSubGroupSize)]] {
const size_t GlobalLinearID = id.get_global_linear_id();

auto Group = [&]() {
if constexpr (UseGroup == UseGroupT::SubGroup)
return id.get_sub_group();
else
return id.get_group();
}();

// Each sub-group should use it's own part of the scratch pad
const size_t ScratchShift =
UseGroup == UseGroupT::SubGroup
? id.get_sub_group().get_group_linear_id() *
LocalMemorySize
: 0;
std::byte *ScratchPtr = &Scratch[0] + ScratchShift;

if constexpr (std::is_same_v<Compare, std::less<T>>)
AccToSort0[GlobalLinearID] = oneapi_exp::sort_over_group(
oneapi_exp::group_with_scratchpad(
Group, sycl::span{ScratchPtr, LocalMemorySize}),
AccToSort0[GlobalLinearID]);

AccToSort1[GlobalLinearID] = oneapi_exp::sort_over_group(
oneapi_exp::group_with_scratchpad(
Group, sycl::span{ScratchPtr, LocalMemorySize}),
AccToSort1[GlobalLinearID], Comp);

AccToSort2[GlobalLinearID] = oneapi_exp::sort_over_group(
Group, AccToSort2[GlobalLinearID],
oneapi_exp::default_sorter<Compare>(
sycl::span{ScratchPtr, LocalMemorySize}));
});
NDRange, [=](sycl::nd_item<Dims> id) [[intel::reqd_sub_group_size(
ReqSubGroupSize)]] {
const size_t GlobalLinearID = id.get_global_linear_id();

auto Group = [&]() {
if constexpr (UseGroup == UseGroupT::SubGroup)
return id.get_sub_group();
else
return id.get_group();
}();

// Each sub-group should use it's own part of the scratch pad
const size_t ScratchShiftDefault =
UseGroup == UseGroupT::SubGroup
? id.get_sub_group().get_group_linear_id() *
LocalMemorySizeDefault
: 0;
std::byte *ScratchPtrDefault =
&ScratchDefault[0] + ScratchShiftDefault;

if constexpr (std::is_same_v<Compare, std::less<T>>)
AccToSort0[GlobalLinearID] = oneapi_exp::sort_over_group(
oneapi_exp::group_with_scratchpad(
Group,
sycl::span{ScratchPtrDefault, LocalMemorySizeDefault}),
AccToSort0[GlobalLinearID]);

AccToSort1[GlobalLinearID] = oneapi_exp::sort_over_group(
oneapi_exp::group_with_scratchpad(
Group,
sycl::span{ScratchPtrDefault, LocalMemorySizeDefault}),
AccToSort1[GlobalLinearID], Comp);

AccToSort2[GlobalLinearID] = oneapi_exp::sort_over_group(
Group, AccToSort2[GlobalLinearID],
oneapi_exp::default_sorter<Compare>(
sycl::span{ScratchPtrDefault, LocalMemorySizeDefault}));

// Each sub-group should use it's own part of the scratch pad
const size_t ScratchShiftRadix =
UseGroup == UseGroupT::SubGroup
? id.get_sub_group().get_group_linear_id() *
LocalMemorySizeRadix
: 0;
std::byte *ScratchPtrRadix = &ScratchRadix[0] + ScratchShiftRadix;

// Radix doesn't support custom types
if constexpr (!std::is_same_v<CustomType, T>)
AccToSort3[GlobalLinearID] = oneapi_exp::sort_over_group(
Group, AccToSort3[GlobalLinearID],
RadixSorterT(
sycl::span{ScratchPtrRadix, LocalMemorySizeRadix}));
});
}).wait_and_throw();
}

Expand All @@ -290,6 +390,9 @@ void RunSortOVerGroup(sycl::queue &Q, const std::vector<T> &DataToSort,

assert(DataToSortCase1 == DataSorted);
assert(DataToSortCase2 == DataSorted);
// Radix doesn't support custom types
if constexpr (!std::is_same_v<CustomType, T>)
assert(DataToSortCase3 == DataSorted);
}
}

Expand Down