Skip to content

Commit 0360e6a

Browse files
authored
[SYCL] Follow up fixes for group_sort extension (#14591)
* Change `memory_required` queries to return the exact size of local memory which needs to be allocated for `group_sort` algorithm, currently it is a bit more than required. * Fix a bug in `align_key_value_scratch` where I didn't take into account that `std::align` changes value of `KeysScratchSpace`. To find scratch of the values I need to use original value of this variable. * Fix mistake in the test where keys/values were mixed up. This is more of a cosmetic change. RunSortKeyValueOverGroup functions accepts two vectors - the first is considered keys, and the second is considered values. Verification is done in the same function (with the same assumption that the first vector is keys and the second is values). So it doesn't actually matter in which order aforementioned arguments are provided, test still serves it purpose.
1 parent 7d7ab2f commit 0360e6a

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

sycl/include/sycl/detail/group_sort_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ align_key_value_scratch(sycl::span<std::byte> scratch, Group g,
7272
scratch_ptr =
7373
std::align(alignof(KeyTy), KeysSize, scratch_ptr, KeysScratchSpace);
7474
keys_scratch_begin = ::new (scratch_ptr) KeyTy[number_of_elements];
75-
scratch_ptr = scratch.data() + KeysScratchSpace;
75+
scratch_ptr = scratch.data() + KeysSize + alignof(KeyTy);
7676
scratch_ptr = std::align(alignof(ValueTy), ValuesSize, scratch_ptr,
7777
ValuesScratchSpace);
7878
values_scratch_begin = ::new (scratch_ptr) ValueTy[number_of_elements];

sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,13 @@ class group_sorter {
341341
#endif
342342
}
343343

344-
static std::size_t memory_required(sycl::memory_scope scope,
344+
static std::size_t memory_required([[maybe_unused]] sycl::memory_scope scope,
345345
size_t range_size) {
346-
return 2 * joint_sorter<>::template memory_required<T>(
347-
scope, range_size * ElementsPerWorkItem);
346+
// We need a space (in bytes) for the buffer of output values and the
347+
// temporary buffer. Where number of elements in each buffer is range_size
348+
// (group size) multiplied by elements per work item. Also we have to align
349+
// these two buffers, so need an additional space of size alignof(T).
350+
return 2 * range_size * ElementsPerWorkItem * sizeof(T) + alignof(T);
348351
}
349352
};
350353

@@ -449,11 +452,18 @@ class group_key_value_sorter {
449452
#endif
450453
}
451454

452-
static std::size_t memory_required(sycl::memory_scope scope,
455+
static std::size_t memory_required([[maybe_unused]] sycl::memory_scope scope,
453456
std::size_t range_size) {
454-
return group_sorter<std::tuple<KeyTy, ValueTy>, CompareT,
455-
ElementsPerWorkItem>::memory_required(scope,
456-
range_size);
457+
// We need a space (in bytes) for the following buffers:
458+
// 1. Output buffer for keys and temporary buffer for keys.
459+
// 2. Output buffer for values and temporary buffer for values.
460+
// Where number of elements in each buffer is range_size (group size)
461+
// multiplied by elements per work item. We have to align buffers of keys
462+
// and buffers of values, so need an additional space equal to maximum
463+
// between alignment requirements of types KeyTy and ValueTy.
464+
return 2 * range_size * ElementsPerWorkItem *
465+
(sizeof(KeyTy) + sizeof(ValueTy)) +
466+
(std::max)(alignof(KeyTy), alignof(ValueTy));
457467
}
458468
};
459469
} // namespace default_sorters

sycl/test-e2e/GroupAlgorithm/SYCL2020/group_sort/key_value_sort.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,9 @@ void RunOverType(sycl::queue &Q, size_t DataSize) {
269269
auto RunOnDataAndComp = [&](const std::vector<KeyTy> &Keys,
270270
const std::vector<ValueTy> &Data,
271271
const auto &Comparator) {
272-
RunSortKeyValueOverGroup<UseGroupT::WorkGroup, 1>(Q, Data, Keys,
272+
RunSortKeyValueOverGroup<UseGroupT::WorkGroup, 1>(Q, Keys, Data,
273273
Comparator);
274-
RunSortKeyValueOverGroup<UseGroupT::WorkGroup, 2>(Q, Data, Keys,
274+
RunSortKeyValueOverGroup<UseGroupT::WorkGroup, 2>(Q, Keys, Data,
275275
Comparator);
276276

277277
if (Q.get_backend() == sycl::backend::ext_oneapi_cuda ||
@@ -280,8 +280,8 @@ void RunOverType(sycl::queue &Q, size_t DataSize) {
280280
return;
281281
}
282282

283-
RunSortKeyValueOverGroup<UseGroupT::SubGroup, 1>(Q, Data, Keys, Comparator);
284-
RunSortKeyValueOverGroup<UseGroupT::SubGroup, 2>(Q, Data, Keys, Comparator);
283+
RunSortKeyValueOverGroup<UseGroupT::SubGroup, 1>(Q, Keys, Data, Comparator);
284+
RunSortKeyValueOverGroup<UseGroupT::SubGroup, 2>(Q, Keys, Data, Comparator);
285285
};
286286

287287
RunOnDataAndComp(KeysRandom, DataRandom, std::less<KeyTy>{});

0 commit comments

Comments
 (0)