Skip to content

[NFCI][SYCL] Remove Reduction::getOutPointer #7184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 50 additions & 71 deletions sycl/include/sycl/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,14 +686,6 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {

RedOutVar &getUserRedVar() { return MRedOut; }

static inline result_type *getOutPointer(result_type *OutPtr) {
return OutPtr;
}
template <class AccessorType>
static inline result_type *getOutPointer(const AccessorType &OutAcc) {
return OutAcc.get_pointer().get();
}

private:
// Array reduction is performed element-wise to avoid stack growth, hence
// 1-dimensional always.
Expand Down Expand Up @@ -895,7 +887,7 @@ bool reduCGFuncForRangeFastAtomics(handler &CGH, KernelType KernelFunc,
for (size_t E = 0; E < NElements; ++E) {
Reducer.getElement(E) = GroupSum[E];
}
Reducer.template atomic_combine(Reduction::getOutPointer(Out));
Reducer.template atomic_combine(&Out[0]);
}
});
return Reduction::is_usm || Redu.initializeToIdentity();
Expand Down Expand Up @@ -947,12 +939,11 @@ bool reduCGFuncForRangeFastReduce(handler &CGH, KernelType KernelFunc,
RedElem = reduce_over_group(Group, RedElem, BOp);
if (LID == 0) {
if (NWorkGroups == 1) {
auto &OutElem = Reduction::getOutPointer(Out)[E];
// Can avoid using partial sum and write the final result
// immediately.
if (IsUpdateOfUserVar)
RedElem = BOp(RedElem, OutElem);
OutElem = RedElem;
RedElem = BOp(RedElem, Out[E]);
Out[E] = RedElem;
} else {
PartialSums[NDId.get_group_linear_id() * NElements + E] =
Reducer.getElement(E);
Expand All @@ -978,16 +969,15 @@ bool reduCGFuncForRangeFastReduce(handler &CGH, KernelType KernelFunc,
// Reduce each result separately
// TODO: Opportunity to parallelize across elements.
for (int E = 0; E < NElements; ++E) {
auto &OutElem = Reduction::getOutPointer(Out)[E];
auto LocalSum = Reducer.getIdentity();
for (size_t I = LID; I < NWorkGroups; I += WGSize)
LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
auto Result = reduce_over_group(Group, LocalSum, BOp);

if (LID == 0) {
if (IsUpdateOfUserVar)
Result = BOp(Result, OutElem);
OutElem = Result;
Result = BOp(Result, Out[E]);
Out[E] = Result;
}
}
}
Expand Down Expand Up @@ -1071,10 +1061,9 @@ bool reduCGFuncForRangeBasic(handler &CGH, KernelType KernelFunc,
if (LID == 0) {
auto V = BOp(LocalReds[0], LocalReds[WGSize]);
if (NWorkGroups == 1 && IsUpdateOfUserVar)
V = BOp(V, Reduction::getOutPointer(Out)[E]);
V = BOp(V, Out[E]);
// if NWorkGroups == 1, then PartialsSum and Out point to same memory.
Reduction::getOutPointer(
PartialSums)[NDId.get_group_linear_id() * NElements + E] = V;
PartialSums[NDId.get_group_linear_id() * NElements + E] = V;
}
}

Expand All @@ -1095,9 +1084,7 @@ bool reduCGFuncForRangeBasic(handler &CGH, KernelType KernelFunc,
for (int E = 0; E < NElements; ++E) {
auto LocalSum = Identity;
for (size_t I = LID; I < NWorkGroups; I += WGSize)
LocalSum =
BOp(LocalSum,
Reduction::getOutPointer(PartialSums)[I * NElements + E]);
LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);

LocalReds[LID] = LocalSum;
if (LID == 0)
Expand All @@ -1116,8 +1103,8 @@ bool reduCGFuncForRangeBasic(handler &CGH, KernelType KernelFunc,
if (LID == 0) {
auto V = BOp(LocalReds[0], LocalReds[WGSize]);
if (IsUpdateOfUserVar)
V = BOp(V, Reduction::getOutPointer(Out)[E]);
Reduction::getOutPointer(Out)[E] = V;
V = BOp(V, Out[E]);
Out[E] = V;
}
}
}
Expand Down Expand Up @@ -1189,7 +1176,7 @@ void reduCGFuncForNDRangeBothFastReduceAndAtomics(handler &CGH,
reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
}
if (NDIt.get_local_linear_id() == 0)
Reducer.atomic_combine(Reduction::getOutPointer(Out));
Reducer.atomic_combine(&Out[0]);
});
}

Expand Down Expand Up @@ -1270,7 +1257,7 @@ void reduCGFuncForNDRangeFastAtomicsOnly(handler &CGH, bool IsPow2WG,
}

if (LID == 0) {
Reducer.atomic_combine(Reduction::getOutPointer(Out));
Reducer.atomic_combine(&Out[0]);
}
});
}
Expand Down Expand Up @@ -1316,8 +1303,8 @@ void reduCGFuncForNDRangeFastReduceOnly(handler &CGH, KernelType KernelFunc,
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
if (NDIt.get_local_linear_id() == 0) {
if (IsUpdateOfUserVar)
PSum = BOp(Reduction::getOutPointer(Out)[E], PSum);
Reduction::getOutPointer(Out)[WGID * NElements + E] = PSum;
PSum = BOp(Out[E], PSum);
Out[WGID * NElements + E] = PSum;
}
}
});
Expand Down Expand Up @@ -1397,8 +1384,8 @@ void reduCGFuncForNDRangeBasic(handler &CGH, bool IsPow2WG,
typename Reduction::result_type PSum =
IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
if (IsUpdateOfUserVar)
PSum = BOp(*(Reduction::getOutPointer(Out)), PSum);
Reduction::getOutPointer(Out)[GrID * NElements + E] = PSum;
PSum = BOp(Out[0], PSum);
Out[GrID * NElements + E] = PSum;
}

// Ensure item 0 is finished with LocalReds before next iteration
Expand Down Expand Up @@ -1448,8 +1435,8 @@ void reduAuxCGFuncFastReduceImpl(handler &CGH, bool UniformWG,
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
if (NDIt.get_local_linear_id() == 0) {
if (IsUpdateOfUserVar)
PSum = BOp(Reduction::getOutPointer(Out)[E], PSum);
Reduction::getOutPointer(Out)[WGID * NElements + E] = PSum;
PSum = BOp(Out[E], PSum);
Out[WGID * NElements + E] = PSum;
}
}
});
Expand Down Expand Up @@ -1525,8 +1512,8 @@ void reduAuxCGFuncNoFastReduceNorAtomicImpl(handler &CGH, bool UniformPow2WG,
typename Reduction::result_type PSum =
UniformPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
if (IsUpdateOfUserVar)
PSum = BOp(*(Reduction::getOutPointer(Out)), PSum);
Reduction::getOutPointer(Out)[GrID * NElements + E] = PSum;
PSum = BOp(Out[0], PSum);
Out[GrID * NElements + E] = PSum;
}

// Ensure item 0 is finished with LocalReds before next iteration
Expand Down Expand Up @@ -1748,24 +1735,20 @@ void writeReduSumsToOutAccs(
// Add the initial value of user's variable to the final result.
if (IsOneWG)
std::tie(std::get<Is>(LocalAccs)[0]...) = std::make_tuple(std::get<Is>(
BOPs)(std::get<Is>(LocalAccs)[0],
IsInitializeToIdentity[Is]
? std::get<Is>(IdentityVals)
: std::tuple_element_t<Is, std::tuple<Reductions...>>::
getOutPointer(std::get<Is>(OutAccs))[0])...);
BOPs)(std::get<Is>(LocalAccs)[0], IsInitializeToIdentity[Is]
? std::get<Is>(IdentityVals)
: std::get<Is>(OutAccs)[0])...);

if (Pow2WG) {
// The partial sums for the work-group are stored in 0-th elements of local
// accessors. Simply write those sums to output accessors.
std::tie(std::tuple_element_t<Is, std::tuple<Reductions...>>::getOutPointer(
std::get<Is>(OutAccs))[OutAccIndex]...) =
std::tie(std::get<Is>(OutAccs)[OutAccIndex]...) =
std::make_tuple(std::get<Is>(LocalAccs)[0]...);
} else {
// Each of local accessors keeps two partial sums: in 0-th and WGsize-th
// elements. Combine them into final partial sums and write to output
// accessors.
std::tie(std::tuple_element_t<Is, std::tuple<Reductions...>>::getOutPointer(
std::get<Is>(OutAccs))[OutAccIndex]...) =
std::tie(std::get<Is>(OutAccs)[OutAccIndex]...) =
std::make_tuple(std::get<Is>(BOPs)(std::get<Is>(LocalAccs)[0],
std::get<Is>(LocalAccs)[WGSize])...);
}
Expand Down Expand Up @@ -1932,23 +1915,21 @@ void reduCGFuncImplArrayHelper(bool Pow2WG, bool IsOneWG, nd_item<Dims> NDIt,
if (LID == 0) {
if (IsOneWG) {
LocalReds[0] =
BOp(LocalReds[0], IsInitializeToIdentity
? Identity
: Reduction::getOutPointer(Out)[E]);
BOp(LocalReds[0], IsInitializeToIdentity ? Identity : Out[E]);
}

size_t GrID = NDIt.get_group_linear_id();
if (Pow2WG) {
// The partial sums for the work-group are stored in 0-th elements of
// local accessors. Simply write those sums to output accessors.
Reduction::getOutPointer(Out)[GrID * NElements + E] = LocalReds[0];
} else {
// Each of local accessors keeps two partial sums: in 0-th and WGsize-th
// elements. Combine them into final partial sums and write to output
// accessors.
Reduction::getOutPointer(Out)[GrID * NElements + E] =
BOp(LocalReds[0], LocalReds[WGSize]);
}
Out[GrID * NElements + E] =
Pow2WG ?
// The partial sums for the work-group are stored in 0-th
// elements of local accessors. Simply write those sums to
// output accessors.
LocalReds[0]
:
// Each of local accessors keeps two partial sums: in 0-th
// and WGsize-th elements. Combine them into final partial
// sums and write to output accessors.
BOp(LocalReds[0], LocalReds[WGSize]);
}

// Ensure item 0 is finished with LocalReds before next iteration
Expand Down Expand Up @@ -2090,7 +2071,7 @@ void reduCGFuncAtomic64(handler &CGH, KernelType KernelFunc,
}

if (NDIt.get_local_linear_id() == 0) {
Reducer.atomic_combine(Reduction::getOutPointer(Out));
Reducer.atomic_combine(&Out[0]);
}
});
}
Expand Down Expand Up @@ -2199,23 +2180,21 @@ void reduAuxCGFuncImplArrayHelper(bool UniformPow2WG, bool IsOneWG,
if (LID == 0) {
if (IsOneWG) {
LocalReds[0] =
BOp(LocalReds[0], IsInitializeToIdentity
? Identity
: Reduction::getOutPointer(Out)[E]);
BOp(LocalReds[0], IsInitializeToIdentity ? Identity : Out[E]);
}

size_t GrID = NDIt.get_group_linear_id();
if (UniformPow2WG) {
// The partial sums for the work-group are stored in 0-th elements of
// local accessors. Simply write those sums to output accessors.
Reduction::getOutPointer(Out)[GrID * NElements + E] = LocalReds[0];
} else {
// Each of local accessors keeps two partial sums: in 0-th and WGsize-th
// elements. Combine them into final partial sums and write to output
// accessors.
Reduction::getOutPointer(Out)[GrID * NElements + E] =
BOp(LocalReds[0], LocalReds[WGSize]);
}
Out[GrID * NElements + E] =
UniformPow2WG ?
// The partial sums for the work-group are stored in
// 0-th elements of local accessors. Simply write those
// sums to output accessors.
LocalReds[0]
:
// Each of local accessors keeps two partial sums: in
// 0-th and WGsize-th elements. Combine them into final
// partial sums and write to output accessors.
BOp(LocalReds[0], LocalReds[WGSize]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uff! Those comments are really messing with the formatter. Might it make sense to move the comments to before the statement and combine them? Say something like:

If it is a uniform work-group with a power-of-two size, ..., otherwise ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree clang-format did something weird here, but, IMO, it's not that bad. I plan more refactoring, will see if can come up with something in one of next PRs.

}

// Ensure item 0 is finished with LocalReds before next iteration
Expand Down