Skip to content

[SYCL] Key/Value sorting with fixed-size private array input #14399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 79 additions & 23 deletions sycl/include/sycl/detail/group_sort_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
#include <climits>

#include <sycl/builtins.hpp>
#include <sycl/detail/key_value_iterator.hpp>
#include <sycl/group_algorithm.hpp>
#include <sycl/group_barrier.hpp>
#include <sycl/sycl_span.hpp>

#include <iterator>
#include <memory>

namespace sycl {
Expand Down Expand Up @@ -52,8 +54,46 @@ static __SYCL_ALWAYS_INLINE T *align_scratch(sycl::span<std::byte> scratch,
scratch_begin = sycl::group_broadcast(g, scratch_begin);
return scratch_begin;
}

// Carves two aligned arrays out of the user-provided scratch span: one of
// number_of_elements KeyTy objects and one of number_of_elements ValueTy
// objects.
//
// Only the group leader aligns the storage and constructs the arrays; the
// resulting pointers are then broadcast so that every work item of the group
// returns the same {keys, values} pair.
//
// Layout: the keys live in the first KeysSize + alignof(KeyTy) bytes of the
// scratch, the values in the ValuesSize + alignof(ValueTy) bytes that follow.
// This function addresses up to the sum of those two sizes past
// scratch.data().
template <typename KeyTy, typename ValueTy, typename Group>
static __SYCL_ALWAYS_INLINE std::pair<KeyTy *, ValueTy *>
align_key_value_scratch(sycl::span<std::byte> scratch, Group g,
                        size_t number_of_elements) {
  const size_t KeysSize = number_of_elements * sizeof(KeyTy);
  const size_t ValuesSize = number_of_elements * sizeof(ValueTy);
  // Each region reserves one extra alignment's worth of bytes for padding.
  const size_t KeysRegionSize = KeysSize + alignof(KeyTy);
  const size_t ValuesRegionSize = ValuesSize + alignof(ValueTy);

  KeyTy *keys_scratch_begin = nullptr;
  ValueTy *values_scratch_begin = nullptr;
  sycl::group_barrier(g);
  if (g.leader()) {
    // std::align updates both the pointer and the space argument in place
    // (space shrinks by the padding consumed). Pass a scratch copy so that the
    // region sizes used for offset computation below stay intact; otherwise
    // the values region would start inside the keys array whenever the keys
    // needed more than alignof(KeyTy) / 2 bytes of padding.
    void *scratch_ptr = scratch.data();
    size_t space = KeysRegionSize;
    scratch_ptr = std::align(alignof(KeyTy), KeysSize, scratch_ptr, space);
    keys_scratch_begin = ::new (scratch_ptr) KeyTy[number_of_elements];

    // The values region begins right after the full keys region.
    scratch_ptr = scratch.data() + KeysRegionSize;
    space = ValuesRegionSize;
    scratch_ptr = std::align(alignof(ValueTy), ValuesSize, scratch_ptr, space);
    values_scratch_begin = ::new (scratch_ptr) ValueTy[number_of_elements];
  }
  // Broadcast leader's pointers (the beginning of each scratch array) to all
  // work items in the group.
  keys_scratch_begin = sycl::group_broadcast(g, keys_scratch_begin);
  values_scratch_begin = sycl::group_broadcast(g, values_scratch_begin);
  return std::make_pair(keys_scratch_begin, values_scratch_begin);
}
#endif

// Exchanges, element by element, the objects referred to by two tuples of
// references. The tuples are accepted as rvalues so that temporaries (for
// example the result of std::tie) can be passed directly.
template <template <typename...> class Tuple, typename... T>
void swap(Tuple<T &...> &&first, Tuple<T &...> &&second) {
  using RefTuple = Tuple<T &...>;
  // Copying a tuple of references yields another tuple bound to the same
  // objects, which gives std::swap the lvalues it needs.
  RefTuple left = first;
  RefTuple right = second;
  std::swap(left, right);
}

// ---- merge sort implementation

// following two functions could be useless if std::[lower|upper]_bound worked
Expand Down Expand Up @@ -83,15 +123,6 @@ size_t upper_bound(Acc acc, const size_t first, const size_t last,
[comp](auto x, auto y) { return !comp(y, x); });
}

// Swap helper usable with ordinary lvalues as well as with two-element
// tuple-like temporaries.

// Ordinary lvalues: defer to std::swap.
template <typename T> void swap_tuples(T &a, T &b) {
  std::swap(a, b);
}

// Two-element tuple-like temporaries: swap the first elements, then the second
// elements.
template <template <typename...> class TupleLike, typename T1, typename T2>
void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
  auto &lhs_first = std::get<0>(a);
  auto &rhs_first = std::get<0>(b);
  std::swap(lhs_first, rhs_first);

  auto &lhs_second = std::get<1>(a);
  auto &rhs_second = std::get<1>(b);
  std::swap(lhs_second, rhs_second);
}

// Trait mapping an iterator type to the element type it designates, as
// reported by std::iterator_traits<Iter>::value_type.
template <typename Iter> struct GetValueType {
using type = typename std::iterator_traits<Iter>::value_type;
};
Expand Down Expand Up @@ -207,18 +238,18 @@ void bubble_sort(Iter first, const size_t begin, const size_t end,
if (begin < end) {
for (size_t i = begin; i < end; ++i) {
// Handle intermediate items
for (size_t idx = i + 1; idx < end; ++idx) {
if (comp(first[idx], first[i])) {
detail::swap_tuples(first[i], first[idx]);
for (size_t idx = begin; idx < begin + (end - 1 - i); ++idx) {
if (comp(first[idx + 1], first[idx])) {
detail::swap(first[idx], first[idx + 1]);
}
}
}
}
}

template <typename Group, typename Iter, typename T, typename Compare>
template <typename Group, typename Iter, typename ScratchIter, typename Compare>
void merge_sort(Group group, Iter first, const size_t n, Compare comp,
T *scratch) {
ScratchIter scratch) {
const size_t idx = group.get_local_linear_id();
const size_t local = group.get_local_range().size();
const size_t chunk = (n - 1) / local + 1;
Expand Down Expand Up @@ -608,15 +639,41 @@ void performRadixIterDynamicSize(

// The iteration of radix sort for known number of elements per work item
template <size_t items_per_work_item, uint32_t radix_bits, bool is_comp_asc,
bool is_key_value_sort, bool is_blocked, typename KeysT,
typename ValsT, typename GroupT>
bool is_key_value_sort, bool is_input_blocked, bool is_output_blocked,
typename KeysT, typename ValsT, typename GroupT>
void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
const uint32_t first_iter,
const uint32_t last_iter, KeysT *keys,
ValsT *vals, const ScratchMemory &memory) {
const uint32_t radix_states = getStatesInBits(radix_bits);
const size_t wgsize = group.get_local_linear_range();
const size_t idx = group.get_local_linear_id();

const ScratchMemory &keys_temp = memory;
const ScratchMemory vals_temp =
memory + wgsize * items_per_work_item * sizeof(KeysT);

// If input is striped, reorder items using scratch memory before sorting,
// this only needs to be done at the first iteration.
if constexpr (!is_input_blocked) {
if (radix_iter == first_iter) {
for (uint32_t i = 0; i < items_per_work_item; ++i) {
size_t shift = i * wgsize + idx;
keys_temp.get<KeysT>(shift) = keys[i];
if constexpr (is_key_value_sort)
vals_temp.get<ValsT>(shift) = vals[i];
}
sycl::group_barrier(group);
for (uint32_t i = 0; i < items_per_work_item; ++i) {
size_t shift = idx * items_per_work_item + i;
keys[i] = keys_temp.get<KeysT>(shift);
if constexpr (is_key_value_sort)
vals[i] = vals_temp.get<ValsT>(shift);
}
sycl::group_barrier(group);
}
}

// 1.1. count per witem: create a private array for storing count values
uint32_t count_arr[items_per_work_item] = {0};
uint32_t ranks[items_per_work_item] = {0};
Expand Down Expand Up @@ -666,9 +723,6 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
sycl::group_barrier(group);

// 3. Reorder
const ScratchMemory &keys_temp = memory;
const ScratchMemory vals_temp =
memory + wgsize * items_per_work_item * sizeof(KeysT);
for (uint32_t i = 0; i < items_per_work_item; ++i) {
keys_temp.get<KeysT>(ranks[i]) = keys[i];
if constexpr (is_key_value_sort)
Expand All @@ -680,7 +734,7 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
// 4. Copy back to input
for (uint32_t i = 0; i < items_per_work_item; ++i) {
size_t shift = idx * items_per_work_item + i;
if constexpr (!is_blocked) {
if constexpr (!is_output_blocked) {
if (radix_iter == last_iter - 1)
shift = i * wgsize + idx;
}
Expand Down Expand Up @@ -728,7 +782,8 @@ void privateDynamicSort(GroupT group, KeysT *keys, ValsT *values,
}
}

template <bool is_key_value_sort, bool is_blocked, bool is_comp_asc,
template <bool is_key_value_sort, bool is_intput_blocked,
bool is_output_blocked, bool is_comp_asc,
size_t items_per_work_item = 1, uint32_t radix_bits = 4,
typename GroupT, typename T, typename U>
void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,
Expand All @@ -739,8 +794,9 @@ void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,

for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
performRadixIterStaticSize<items_per_work_item, radix_bits, is_comp_asc,
is_key_value_sort, is_blocked>(
group, radix_iter, last_iter, keys, values, scratch);
is_key_value_sort, is_intput_blocked,
is_output_blocked>(
group, radix_iter, first_iter, last_iter, keys, values, scratch);
sycl::group_barrier(group);
}
}
Expand Down
94 changes: 94 additions & 0 deletions sycl/include/sycl/detail/key_value_iterator.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
//==------------ key_value_iterator.hpp ------------------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file includes key/value iterator implementation used for group_sort
// algorithms.
//

#pragma once
#include <iterator>
#include <tuple>
#include <utility>

namespace sycl {
inline namespace _V1 {
namespace detail {

// Random-access iterator that walks two parallel arrays (keys and values) in
// lock step. Dereferencing yields a tuple of references to the current key and
// the current value, so a key/value pair can be read, assigned or swapped as a
// single element.
//
// Distances and all comparisons are computed from the key pointers only.
template <typename T1, typename T2> class key_value_iterator {
public:
  using difference_type = std::ptrdiff_t;
  using value_type = std::tuple<T1, T2>;
  using reference = std::tuple<T1 &, T2 &>;
  using pointer = std::tuple<T1 *, T2 *>;
  using iterator_category = std::random_access_iterator_tag;

  // Forward (and therefore random-access) iterators must be
  // default-constructible; a default-constructed iterator holds two null
  // pointers.
  key_value_iterator() = default;

  // Creates an iterator positioned at Keys[0] / Values[0].
  key_value_iterator(T1 *Keys, T2 *Values) : KeyValue{Keys, Values} {}

  // Returns references to the key and the value at the current position.
  reference operator*() const {
    return std::tie(*(std::get<0>(KeyValue)), *(std::get<1>(KeyValue)));
  }

  // Returns references to the key and the value i positions away.
  reference operator[](difference_type i) const { return *(*this + i); }

  // Distance between two iterators, measured on the key pointers.
  difference_type operator-(const key_value_iterator &it) const {
    return std::get<0>(KeyValue) - std::get<0>(it.KeyValue);
  }

  // Advances both underlying pointers by the same amount.
  key_value_iterator &operator+=(difference_type i) {
    std::get<0>(KeyValue) += i;
    std::get<1>(KeyValue) += i;
    return *this;
  }
  key_value_iterator &operator-=(difference_type i) { return *this += -i; }
  key_value_iterator &operator++() { return *this += 1; }
  key_value_iterator &operator--() { return *this -= 1; }

  // Exposes the underlying {key pointer, value pointer} pair.
  std::tuple<T1 *, T2 *> base() const { return KeyValue; }

  key_value_iterator operator++(int) {
    key_value_iterator it(*this);
    ++(*this);
    return it;
  }
  key_value_iterator operator--(int) {
    key_value_iterator it(*this);
    --(*this);
    return it;
  }

  key_value_iterator operator-(difference_type i) const {
    key_value_iterator it(*this);
    return it -= i;
  }
  key_value_iterator operator+(difference_type i) const {
    key_value_iterator it(*this);
    return it += i;
  }
  friend key_value_iterator operator+(difference_type i,
                                      const key_value_iterator &it) {
    return it + i;
  }

  // All comparisons are derived from the key-pointer distance.
  bool operator==(const key_value_iterator &it) const {
    return *this - it == 0;
  }

  bool operator!=(const key_value_iterator &it) const { return !(*this == it); }
  bool operator<(const key_value_iterator &it) const { return *this - it < 0; }
  bool operator>(const key_value_iterator &it) const { return it < *this; }
  bool operator<=(const key_value_iterator &it) const { return !(*this > it); }
  bool operator>=(const key_value_iterator &it) const { return !(*this < it); }

private:
  // Current positions in the key array and in the value array.
  std::tuple<T1 *, T2 *> KeyValue{};
};

// Generic swap for two lvalues of the same type; simply forwards to std::swap.
template <typename T> void swap(T &first, T &second) {
std::swap(first, second);
}

} // namespace detail
} // namespace _V1
} // namespace sycl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add empty line at the EOF

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, will have to move extension doc from proposed to supported in another PR, will make this correction there.

Loading
Loading