Skip to content

Commit 55c95b6

Browse files
Revert "[SYCL][Reduction] Hide reducer non-standard members and add identity (#8215)"
This reverts commit 505aa7d.
1 parent: dc6ee4b · commit: 55c95b6

File tree

1 file changed

+30
-93
lines changed

1 file changed

+30
-93
lines changed

sycl/include/sycl/reduction.hpp

Lines changed: 30 additions & 93 deletions
Original file line number | Diff line number | Diff line change
@@ -159,51 +159,6 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
159159
static constexpr size_t extent = Extent;
160160
};
161161

162-
/// Helper class for accessing internal reducer member functions.
163-
template <typename ReducerT> class ReducerAccess {
164-
public:
165-
ReducerAccess(ReducerT &ReducerRef) : MReducerRef(ReducerRef) {}
166-
167-
template <typename ReducerRelayT = ReducerT> auto &getElement(size_t E) {
168-
return MReducerRef.getElement(E);
169-
}
170-
171-
template <typename ReducerRelayT = ReducerT>
172-
enable_if_t<
173-
IsKnownIdentityOp<typename ReducerRelayT::value_type,
174-
typename ReducerRelayT::binary_operation>::value,
175-
typename ReducerRelayT::value_type> constexpr getIdentity() {
176-
return getIdentityStatic();
177-
}
178-
179-
template <typename ReducerRelayT = ReducerT>
180-
enable_if_t<
181-
!IsKnownIdentityOp<typename ReducerRelayT::value_type,
182-
typename ReducerRelayT::binary_operation>::value,
183-
typename ReducerRelayT::value_type>
184-
getIdentity() {
185-
return MReducerRef.identity();
186-
}
187-
188-
// MSVC does not like static overloads of non-static functions, even if they
189-
// are made mutually exclusive through SFINAE. Instead we use a new static
190-
// function to be used when a static function is needed.
191-
template <typename ReducerRelayT = ReducerT>
192-
enable_if_t<
193-
IsKnownIdentityOp<typename ReducerRelayT::value_type,
194-
typename ReducerRelayT::binary_operation>::value,
195-
typename ReducerRelayT::value_type> static constexpr getIdentityStatic() {
196-
return ReducerT::getIdentity();
197-
}
198-
199-
private:
200-
ReducerT &MReducerRef;
201-
};
202-
203-
// Deduction guide to simplify the use of ReducerAccess.
204-
template <typename ReducerT>
205-
ReducerAccess(ReducerT &) -> ReducerAccess<ReducerT>;
206-
207162
/// Use CRTP to avoid redefining shorthand operators in terms of combine
208163
///
209164
/// Also, for many types with known identity the operation 'atomic_combine()'
@@ -283,7 +238,7 @@ template <class Reducer> class combiner {
283238
auto AtomicRef = sycl::atomic_ref<T, memory_order::relaxed,
284239
getMemoryScope<Space>(), Space>(
285240
address_space_cast<Space, access::decorated::no>(ReduVarPtr)[E]);
286-
Functor(std::move(AtomicRef), ReducerAccess{*reducer}.getElement(E));
241+
Functor(std::move(AtomicRef), reducer->getElement(E));
287242
}
288243
}
289244

@@ -400,15 +355,13 @@ class reducer<
400355
return *this;
401356
}
402357

403-
T identity() const { return MIdentity; }
404-
405-
private:
406-
template <typename ReducerT> friend class detail::ReducerAccess;
358+
T getIdentity() const { return MIdentity; }
407359

408360
T &getElement(size_t) { return MValue; }
409361
const T &getElement(size_t) const { return MValue; }
410-
411362
T MValue;
363+
364+
private:
412365
const T MIdentity;
413366
BinaryOperation MBinaryOp;
414367
};
@@ -439,12 +392,7 @@ class reducer<
439392
return *this;
440393
}
441394

442-
T identity() const { return getIdentity(); }
443-
444-
private:
445-
template <typename ReducerT> friend class detail::ReducerAccess;
446-
447-
static constexpr T getIdentity() {
395+
static T getIdentity() {
448396
return detail::known_identity_impl<BinaryOperation, T>::value;
449397
}
450398

@@ -471,8 +419,6 @@ class reducer<T, BinaryOperation, Dims, Extent, View,
471419
}
472420

473421
private:
474-
template <typename ReducerT> friend class detail::ReducerAccess;
475-
476422
T &MElement;
477423
BinaryOperation MBinaryOp;
478424
};
@@ -498,14 +444,11 @@ class reducer<
498444
return {MValue[Index], MBinaryOp};
499445
}
500446

501-
T identity() const { return MIdentity; }
502-
503-
private:
504-
template <typename ReducerT> friend class detail::ReducerAccess;
505-
447+
T getIdentity() const { return MIdentity; }
506448
T &getElement(size_t E) { return MValue[E]; }
507449
const T &getElement(size_t E) const { return MValue[E]; }
508450

451+
private:
509452
marray<T, Extent> MValue;
510453
const T MIdentity;
511454
BinaryOperation MBinaryOp;
@@ -534,18 +477,14 @@ class reducer<
534477
return {MValue[Index], BinaryOperation()};
535478
}
536479

537-
T identity() const { return getIdentity(); }
538-
539-
private:
540-
template <typename ReducerT> friend class detail::ReducerAccess;
541-
542-
static constexpr T getIdentity() {
480+
static T getIdentity() {
543481
return detail::known_identity_impl<BinaryOperation, T>::value;
544482
}
545483

546484
T &getElement(size_t E) { return MValue[E]; }
547485
const T &getElement(size_t E) const { return MValue[E]; }
548486

487+
private:
549488
marray<T, Extent> MValue;
550489
};
551490

@@ -830,7 +769,8 @@ class reduction_impl
830769
// list of known operations does not break the existing programs.
831770
if constexpr (is_known_identity) {
832771
(void)Identity;
833-
return ReducerAccess<reducer_type>::getIdentityStatic();
772+
return reducer_type::getIdentity();
773+
834774
} else {
835775
return Identity;
836776
}
@@ -848,8 +788,8 @@ class reduction_impl
848788
template <typename _self = self,
849789
enable_if_t<_self::is_known_identity> * = nullptr>
850790
reduction_impl(RedOutVar Var, bool InitializeToIdentity = false)
851-
: algo(ReducerAccess<reducer_type>::getIdentityStatic(),
852-
BinaryOperation(), InitializeToIdentity, Var) {
791+
: algo(reducer_type::getIdentity(), BinaryOperation(),
792+
InitializeToIdentity, Var) {
853793
if constexpr (!is_usm)
854794
if (Var.size() != 1)
855795
throw sycl::runtime_error(errc::invalid,
@@ -956,7 +896,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
956896
// Work-group cooperates to initialize multiple reduction variables
957897
auto LID = NDId.get_local_id(0);
958898
for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
959-
GroupSum[E] = ReducerAccess(Reducer).getIdentity();
899+
GroupSum[E] = Reducer.getIdentity();
960900
}
961901
workGroupBarrier();
962902

@@ -969,7 +909,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
969909
workGroupBarrier();
970910
if (LID == 0) {
971911
for (size_t E = 0; E < NElements; ++E) {
972-
ReducerAccess{Reducer}.getElement(E) = GroupSum[E];
912+
Reducer.getElement(E) = GroupSum[E];
973913
}
974914
Reducer.template atomic_combine(&Out[0]);
975915
}
@@ -1019,7 +959,7 @@ struct NDRangeReduction<
1019959
// reduce_over_group is only defined for each T, not for span<T, ...>
1020960
size_t LID = NDId.get_local_id(0);
1021961
for (int E = 0; E < NElements; ++E) {
1022-
auto &RedElem = ReducerAccess{Reducer}.getElement(E);
962+
auto &RedElem = Reducer.getElement(E);
1023963
RedElem = reduce_over_group(Group, RedElem, BOp);
1024964
if (LID == 0) {
1025965
if (NWorkGroups == 1) {
@@ -1030,7 +970,7 @@ struct NDRangeReduction<
1030970
Out[E] = RedElem;
1031971
} else {
1032972
PartialSums[NDId.get_group_linear_id() * NElements + E] =
1033-
ReducerAccess{Reducer}.getElement(E);
973+
Reducer.getElement(E);
1034974
}
1035975
}
1036976
}
@@ -1053,7 +993,7 @@ struct NDRangeReduction<
1053993
// Reduce each result separately
1054994
// TODO: Opportunity to parallelize across elements.
1055995
for (int E = 0; E < NElements; ++E) {
1056-
auto LocalSum = ReducerAccess{Reducer}.getIdentity();
996+
auto LocalSum = Reducer.getIdentity();
1057997
for (size_t I = LID; I < NWorkGroups; I += WGSize)
1058998
LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
1059999
auto Result = reduce_over_group(Group, LocalSum, BOp);
@@ -1143,7 +1083,7 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
11431083
for (int E = 0; E < NElements; ++E) {
11441084

11451085
// Copy the element to local memory to prepare it for tree-reduction.
1146-
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
1086+
LocalReds[LID] = Reducer.getElement(E);
11471087

11481088
doTreeReduction(WGSize, LID, false, Identity, LocalReds, BOp,
11491089
[&]() { workGroupBarrier(); });
@@ -1218,8 +1158,8 @@ struct NDRangeReduction<reduction::strategy::group_reduce_and_atomic_cross_wg> {
12181158

12191159
typename Reduction::binary_operation BOp;
12201160
for (int E = 0; E < NElements; ++E) {
1221-
ReducerAccess{Reducer}.getElement(E) = reduce_over_group(
1222-
NDIt.get_group(), ReducerAccess{Reducer}.getElement(E), BOp);
1161+
Reducer.getElement(E) =
1162+
reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
12231163
}
12241164
if (NDIt.get_local_linear_id() == 0)
12251165
Reducer.atomic_combine(&Out[0]);
@@ -1267,15 +1207,14 @@ struct NDRangeReduction<
12671207
for (int E = 0; E < NElements; ++E) {
12681208

12691209
// Copy the element to local memory to prepare it for tree-reduction.
1270-
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
1210+
LocalReds[LID] = Reducer.getElement(E);
12711211

12721212
typename Reduction::binary_operation BOp;
1273-
doTreeReduction(WGSize, LID, IsPow2WG,
1274-
ReducerAccess{Reducer}.getIdentity(), LocalReds, BOp,
1275-
[&]() { NDIt.barrier(); });
1213+
doTreeReduction(WGSize, LID, IsPow2WG, Reducer.getIdentity(),
1214+
LocalReds, BOp, [&]() { NDIt.barrier(); });
12761215

12771216
if (LID == 0) {
1278-
ReducerAccess{Reducer}.getElement(E) =
1217+
Reducer.getElement(E) =
12791218
IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
12801219
}
12811220

@@ -1343,7 +1282,7 @@ struct NDRangeReduction<
13431282
typename Reduction::binary_operation BOp;
13441283
for (int E = 0; E < NElements; ++E) {
13451284
typename Reduction::result_type PSum;
1346-
PSum = ReducerAccess{Reducer}.getElement(E);
1285+
PSum = Reducer.getElement(E);
13471286
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
13481287
if (NDIt.get_local_linear_id() == 0) {
13491288
if (IsUpdateOfUserVar)
@@ -1407,8 +1346,7 @@ struct NDRangeReduction<
14071346
typename Reduction::result_type PSum =
14081347
(HasUniformWG || (GID < NWorkItems))
14091348
? In[GID * NElements + E]
1410-
: ReducerAccess<typename Reduction::reducer_type>::
1411-
getIdentityStatic();
1349+
: Reduction::reducer_type::getIdentity();
14121350
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
14131351
if (NDIt.get_local_linear_id() == 0) {
14141352
if (IsUpdateOfUserVar)
@@ -1482,7 +1420,7 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
14821420
for (int E = 0; E < NElements; ++E) {
14831421

14841422
// Copy the element to local memory to prepare it for tree-reduction.
1485-
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
1423+
LocalReds[LID] = Reducer.getElement(E);
14861424

14871425
doTreeReduction(WGSize, LID, IsPow2WG, ReduIdentity, LocalReds, BOp,
14881426
[&]() { NDIt.barrier(); });
@@ -1755,8 +1693,7 @@ void reduCGFuncImplScalar(
17551693
size_t WGSize = NDIt.get_local_range().size();
17561694
size_t LID = NDIt.get_local_linear_id();
17571695

1758-
((std::get<Is>(LocalAccsTuple)[LID] =
1759-
ReducerAccess{std::get<Is>(ReducersTuple)}.getElement(0)),
1696+
((std::get<Is>(LocalAccsTuple)[LID] = std::get<Is>(ReducersTuple).MValue),
17601697
...);
17611698

17621699
// For work-groups, which size is not power of two, local accessors have
@@ -1807,7 +1744,7 @@ void reduCGFuncImplArrayHelper(bool Pow2WG, bool IsOneWG, nd_item<Dims> NDIt,
18071744
for (size_t E = 0; E < NElements; ++E) {
18081745

18091746
// Copy the element to local memory to prepare it for tree-reduction.
1810-
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
1747+
LocalReds[LID] = Reducer.getElement(E);
18111748

18121749
doTreeReduction(WGSize, LID, Pow2WG, Identity, LocalReds, BOp,
18131750
[&]() { NDIt.barrier(); });

0 commit comments

Comments
 (0)