Skip to content

[SYCL][NFC] Use reducer-access helper function instead of deduction guide #8411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions sycl/include/sycl/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,11 @@ template <typename ReducerT> class ReducerAccess {
ReducerT &MReducerRef;
};

// Deduction guide to simplify the use of ReducerAccess: constructing a
// ReducerAccess from an lvalue of type ReducerT without spelling out the
// template argument deduces ReducerAccess<ReducerT>.
template <typename ReducerT>
ReducerAccess(ReducerT &) -> ReducerAccess<ReducerT>;
// Helper function to simplify the use of ReducerAccess. This avoids the need
// for potentially unsupported deduction guides.
template <typename ReducerT> auto getReducerAccess(ReducerT &Reducer) {
return ReducerAccess<ReducerT>{Reducer};
}

/// Use CRTP to avoid redefining shorthand operators in terms of combine
///
Expand Down Expand Up @@ -283,7 +285,7 @@ template <class Reducer> class combiner {
auto AtomicRef = sycl::atomic_ref<T, memory_order::relaxed,
getMemoryScope<Space>(), Space>(
address_space_cast<Space, access::decorated::no>(ReduVarPtr)[E]);
Functor(std::move(AtomicRef), ReducerAccess{*reducer}.getElement(E));
Functor(std::move(AtomicRef), getReducerAccess(*reducer).getElement(E));
}
}

Expand Down Expand Up @@ -956,7 +958,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
// Work-group cooperates to initialize multiple reduction variables
auto LID = NDId.get_local_id(0);
for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
GroupSum[E] = ReducerAccess(Reducer).getIdentity();
GroupSum[E] = getReducerAccess(Reducer).getIdentity();
}
workGroupBarrier();

Expand All @@ -969,7 +971,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
workGroupBarrier();
if (LID == 0) {
for (size_t E = 0; E < NElements; ++E) {
ReducerAccess{Reducer}.getElement(E) = GroupSum[E];
getReducerAccess(Reducer).getElement(E) = GroupSum[E];
}
Reducer.template atomic_combine(&Out[0]);
}
Expand Down Expand Up @@ -1019,7 +1021,7 @@ struct NDRangeReduction<
// reduce_over_group is only defined for each T, not for span<T, ...>
size_t LID = NDId.get_local_id(0);
for (int E = 0; E < NElements; ++E) {
auto &RedElem = ReducerAccess{Reducer}.getElement(E);
auto &RedElem = getReducerAccess(Reducer).getElement(E);
RedElem = reduce_over_group(Group, RedElem, BOp);
if (LID == 0) {
if (NWorkGroups == 1) {
Expand All @@ -1030,7 +1032,7 @@ struct NDRangeReduction<
Out[E] = RedElem;
} else {
PartialSums[NDId.get_group_linear_id() * NElements + E] =
ReducerAccess{Reducer}.getElement(E);
getReducerAccess(Reducer).getElement(E);
}
}
}
Expand All @@ -1053,7 +1055,7 @@ struct NDRangeReduction<
// Reduce each result separately
// TODO: Opportunity to parallelize across elements.
for (int E = 0; E < NElements; ++E) {
auto LocalSum = ReducerAccess{Reducer}.getIdentity();
auto LocalSum = getReducerAccess(Reducer).getIdentity();
for (size_t I = LID; I < NWorkGroups; I += WGSize)
LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
auto Result = reduce_over_group(Group, LocalSum, BOp);
Expand Down Expand Up @@ -1143,7 +1145,7 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
LocalReds[LID] = getReducerAccess(Reducer).getElement(E);

doTreeReduction(WGSize, LID, false, Identity, LocalReds, BOp,
[&]() { workGroupBarrier(); });
Expand Down Expand Up @@ -1218,8 +1220,8 @@ struct NDRangeReduction<reduction::strategy::group_reduce_and_atomic_cross_wg> {

typename Reduction::binary_operation BOp;
for (int E = 0; E < NElements; ++E) {
ReducerAccess{Reducer}.getElement(E) = reduce_over_group(
NDIt.get_group(), ReducerAccess{Reducer}.getElement(E), BOp);
getReducerAccess(Reducer).getElement(E) = reduce_over_group(
NDIt.get_group(), getReducerAccess(Reducer).getElement(E), BOp);
}
if (NDIt.get_local_linear_id() == 0)
Reducer.atomic_combine(&Out[0]);
Expand Down Expand Up @@ -1267,15 +1269,15 @@ struct NDRangeReduction<
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
LocalReds[LID] = getReducerAccess(Reducer).getElement(E);

typename Reduction::binary_operation BOp;
doTreeReduction(WGSize, LID, IsPow2WG,
ReducerAccess{Reducer}.getIdentity(), LocalReds, BOp,
[&]() { NDIt.barrier(); });
getReducerAccess(Reducer).getIdentity(), LocalReds,
BOp, [&]() { NDIt.barrier(); });

if (LID == 0) {
ReducerAccess{Reducer}.getElement(E) =
getReducerAccess(Reducer).getElement(E) =
IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
}

Expand Down Expand Up @@ -1343,7 +1345,7 @@ struct NDRangeReduction<
typename Reduction::binary_operation BOp;
for (int E = 0; E < NElements; ++E) {
typename Reduction::result_type PSum;
PSum = ReducerAccess{Reducer}.getElement(E);
PSum = getReducerAccess(Reducer).getElement(E);
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
if (NDIt.get_local_linear_id() == 0) {
if (IsUpdateOfUserVar)
Expand Down Expand Up @@ -1482,7 +1484,7 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
for (int E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
LocalReds[LID] = getReducerAccess(Reducer).getElement(E);

doTreeReduction(WGSize, LID, IsPow2WG, ReduIdentity, LocalReds, BOp,
[&]() { NDIt.barrier(); });
Expand Down Expand Up @@ -1756,7 +1758,7 @@ void reduCGFuncImplScalar(
size_t LID = NDIt.get_local_linear_id();

((std::get<Is>(LocalAccsTuple)[LID] =
ReducerAccess{std::get<Is>(ReducersTuple)}.getElement(0)),
getReducerAccess(std::get<Is>(ReducersTuple)).getElement(0)),
...);

// For work-groups, which size is not power of two, local accessors have
Expand Down Expand Up @@ -1807,7 +1809,7 @@ void reduCGFuncImplArrayHelper(bool Pow2WG, bool IsOneWG, nd_item<Dims> NDIt,
for (size_t E = 0; E < NElements; ++E) {

// Copy the element to local memory to prepare it for tree-reduction.
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
LocalReds[LID] = getReducerAccess(Reducer).getElement(E);

doTreeReduction(WGSize, LID, Pow2WG, Identity, LocalReds, BOp,
[&]() { NDIt.barrier(); });
Expand Down