Skip to content

Commit 505aa7d

Browse files
[SYCL][Reduction] Hide reducer non-standard members and add identity (#8215)
This commit hides the members in reducer that are not mentioned in the SYCL 2020 specification and introduces the identity member function. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 680c1b3 commit 505aa7d

File tree

1 file changed

+93
-30
lines changed

1 file changed

+93
-30
lines changed

sycl/include/sycl/reduction.hpp

Lines changed: 93 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,51 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
159159
static constexpr size_t extent = Extent;
160160
};
161161

162+
/// Helper class for accessing internal reducer member functions.
163+
template <typename ReducerT> class ReducerAccess {
164+
public:
165+
ReducerAccess(ReducerT &ReducerRef) : MReducerRef(ReducerRef) {}
166+
167+
template <typename ReducerRelayT = ReducerT> auto &getElement(size_t E) {
168+
return MReducerRef.getElement(E);
169+
}
170+
171+
template <typename ReducerRelayT = ReducerT>
172+
enable_if_t<
173+
IsKnownIdentityOp<typename ReducerRelayT::value_type,
174+
typename ReducerRelayT::binary_operation>::value,
175+
typename ReducerRelayT::value_type> constexpr getIdentity() {
176+
return getIdentityStatic();
177+
}
178+
179+
template <typename ReducerRelayT = ReducerT>
180+
enable_if_t<
181+
!IsKnownIdentityOp<typename ReducerRelayT::value_type,
182+
typename ReducerRelayT::binary_operation>::value,
183+
typename ReducerRelayT::value_type>
184+
getIdentity() {
185+
return MReducerRef.identity();
186+
}
187+
188+
// MSVC does not like static overloads of non-static functions, even if they
189+
// are made mutually exclusive through SFINAE. Instead we use a new static
190+
// function to be used when a static function is needed.
191+
template <typename ReducerRelayT = ReducerT>
192+
enable_if_t<
193+
IsKnownIdentityOp<typename ReducerRelayT::value_type,
194+
typename ReducerRelayT::binary_operation>::value,
195+
typename ReducerRelayT::value_type> static constexpr getIdentityStatic() {
196+
return ReducerT::getIdentity();
197+
}
198+
199+
private:
200+
ReducerT &MReducerRef;
201+
};
202+
203+
// Deduction guide to simplify the use of ReducerAccess.
204+
template <typename ReducerT>
205+
ReducerAccess(ReducerT &) -> ReducerAccess<ReducerT>;
206+
162207
/// Use CRTP to avoid redefining shorthand operators in terms of combine
163208
///
164209
/// Also, for many types with known identity the operation 'atomic_combine()'
@@ -238,7 +283,7 @@ template <class Reducer> class combiner {
238283
auto AtomicRef = sycl::atomic_ref<T, memory_order::relaxed,
239284
getMemoryScope<Space>(), Space>(
240285
address_space_cast<Space, access::decorated::no>(ReduVarPtr)[E]);
241-
Functor(std::move(AtomicRef), reducer->getElement(E));
286+
Functor(std::move(AtomicRef), ReducerAccess{*reducer}.getElement(E));
242287
}
243288
}
244289

@@ -355,13 +400,15 @@ class reducer<
355400
return *this;
356401
}
357402

358-
T getIdentity() const { return MIdentity; }
403+
T identity() const { return MIdentity; }
404+
405+
private:
406+
template <typename ReducerT> friend class detail::ReducerAccess;
359407

360408
T &getElement(size_t) { return MValue; }
361409
const T &getElement(size_t) const { return MValue; }
362-
T MValue;
363410

364-
private:
411+
T MValue;
365412
const T MIdentity;
366413
BinaryOperation MBinaryOp;
367414
};
@@ -392,7 +439,12 @@ class reducer<
392439
return *this;
393440
}
394441

395-
static T getIdentity() {
442+
T identity() const { return getIdentity(); }
443+
444+
private:
445+
template <typename ReducerT> friend class detail::ReducerAccess;
446+
447+
static constexpr T getIdentity() {
396448
return detail::known_identity_impl<BinaryOperation, T>::value;
397449
}
398450

@@ -419,6 +471,8 @@ class reducer<T, BinaryOperation, Dims, Extent, View,
419471
}
420472

421473
private:
474+
template <typename ReducerT> friend class detail::ReducerAccess;
475+
422476
T &MElement;
423477
BinaryOperation MBinaryOp;
424478
};
@@ -444,11 +498,14 @@ class reducer<
444498
return {MValue[Index], MBinaryOp};
445499
}
446500

447-
T getIdentity() const { return MIdentity; }
501+
T identity() const { return MIdentity; }
502+
503+
private:
504+
template <typename ReducerT> friend class detail::ReducerAccess;
505+
448506
T &getElement(size_t E) { return MValue[E]; }
449507
const T &getElement(size_t E) const { return MValue[E]; }
450508

451-
private:
452509
marray<T, Extent> MValue;
453510
const T MIdentity;
454511
BinaryOperation MBinaryOp;
@@ -477,14 +534,18 @@ class reducer<
477534
return {MValue[Index], BinaryOperation()};
478535
}
479536

480-
static T getIdentity() {
537+
T identity() const { return getIdentity(); }
538+
539+
private:
540+
template <typename ReducerT> friend class detail::ReducerAccess;
541+
542+
static constexpr T getIdentity() {
481543
return detail::known_identity_impl<BinaryOperation, T>::value;
482544
}
483545

484546
T &getElement(size_t E) { return MValue[E]; }
485547
const T &getElement(size_t E) const { return MValue[E]; }
486548

487-
private:
488549
marray<T, Extent> MValue;
489550
};
490551

@@ -769,8 +830,7 @@ class reduction_impl
769830
// list of known operations does not break the existing programs.
770831
if constexpr (is_known_identity) {
771832
(void)Identity;
772-
return reducer_type::getIdentity();
773-
833+
return ReducerAccess<reducer_type>::getIdentityStatic();
774834
} else {
775835
return Identity;
776836
}
@@ -788,8 +848,8 @@ class reduction_impl
788848
template <typename _self = self,
789849
enable_if_t<_self::is_known_identity> * = nullptr>
790850
reduction_impl(RedOutVar Var, bool InitializeToIdentity = false)
791-
: algo(reducer_type::getIdentity(), BinaryOperation(),
792-
InitializeToIdentity, Var) {
851+
: algo(ReducerAccess<reducer_type>::getIdentityStatic(),
852+
BinaryOperation(), InitializeToIdentity, Var) {
793853
if constexpr (!is_usm)
794854
if (Var.size() != 1)
795855
throw sycl::runtime_error(errc::invalid,
@@ -896,7 +956,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
896956
// Work-group cooperates to initialize multiple reduction variables
897957
auto LID = NDId.get_local_id(0);
898958
for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
899-
GroupSum[E] = Reducer.getIdentity();
959+
GroupSum[E] = ReducerAccess(Reducer).getIdentity();
900960
}
901961
workGroupBarrier();
902962

@@ -909,7 +969,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
909969
workGroupBarrier();
910970
if (LID == 0) {
911971
for (size_t E = 0; E < NElements; ++E) {
912-
Reducer.getElement(E) = GroupSum[E];
972+
ReducerAccess{Reducer}.getElement(E) = GroupSum[E];
913973
}
914974
Reducer.template atomic_combine(&Out[0]);
915975
}
@@ -959,7 +1019,7 @@ struct NDRangeReduction<
9591019
// reduce_over_group is only defined for each T, not for span<T, ...>
9601020
size_t LID = NDId.get_local_id(0);
9611021
for (int E = 0; E < NElements; ++E) {
962-
auto &RedElem = Reducer.getElement(E);
1022+
auto &RedElem = ReducerAccess{Reducer}.getElement(E);
9631023
RedElem = reduce_over_group(Group, RedElem, BOp);
9641024
if (LID == 0) {
9651025
if (NWorkGroups == 1) {
@@ -970,7 +1030,7 @@ struct NDRangeReduction<
9701030
Out[E] = RedElem;
9711031
} else {
9721032
PartialSums[NDId.get_group_linear_id() * NElements + E] =
973-
Reducer.getElement(E);
1033+
ReducerAccess{Reducer}.getElement(E);
9741034
}
9751035
}
9761036
}
@@ -993,7 +1053,7 @@ struct NDRangeReduction<
9931053
// Reduce each result separately
9941054
// TODO: Opportunity to parallelize across elements.
9951055
for (int E = 0; E < NElements; ++E) {
996-
auto LocalSum = Reducer.getIdentity();
1056+
auto LocalSum = ReducerAccess{Reducer}.getIdentity();
9971057
for (size_t I = LID; I < NWorkGroups; I += WGSize)
9981058
LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
9991059
auto Result = reduce_over_group(Group, LocalSum, BOp);
@@ -1083,7 +1143,7 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
10831143
for (int E = 0; E < NElements; ++E) {
10841144

10851145
// Copy the element to local memory to prepare it for tree-reduction.
1086-
LocalReds[LID] = Reducer.getElement(E);
1146+
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
10871147

10881148
doTreeReduction(WGSize, LID, false, Identity, LocalReds, BOp,
10891149
[&]() { workGroupBarrier(); });
@@ -1158,8 +1218,8 @@ struct NDRangeReduction<reduction::strategy::group_reduce_and_atomic_cross_wg> {
11581218

11591219
typename Reduction::binary_operation BOp;
11601220
for (int E = 0; E < NElements; ++E) {
1161-
Reducer.getElement(E) =
1162-
reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
1221+
ReducerAccess{Reducer}.getElement(E) = reduce_over_group(
1222+
NDIt.get_group(), ReducerAccess{Reducer}.getElement(E), BOp);
11631223
}
11641224
if (NDIt.get_local_linear_id() == 0)
11651225
Reducer.atomic_combine(&Out[0]);
@@ -1207,14 +1267,15 @@ struct NDRangeReduction<
12071267
for (int E = 0; E < NElements; ++E) {
12081268

12091269
// Copy the element to local memory to prepare it for tree-reduction.
1210-
LocalReds[LID] = Reducer.getElement(E);
1270+
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
12111271

12121272
typename Reduction::binary_operation BOp;
1213-
doTreeReduction(WGSize, LID, IsPow2WG, Reducer.getIdentity(),
1214-
LocalReds, BOp, [&]() { NDIt.barrier(); });
1273+
doTreeReduction(WGSize, LID, IsPow2WG,
1274+
ReducerAccess{Reducer}.getIdentity(), LocalReds, BOp,
1275+
[&]() { NDIt.barrier(); });
12151276

12161277
if (LID == 0) {
1217-
Reducer.getElement(E) =
1278+
ReducerAccess{Reducer}.getElement(E) =
12181279
IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
12191280
}
12201281

@@ -1282,7 +1343,7 @@ struct NDRangeReduction<
12821343
typename Reduction::binary_operation BOp;
12831344
for (int E = 0; E < NElements; ++E) {
12841345
typename Reduction::result_type PSum;
1285-
PSum = Reducer.getElement(E);
1346+
PSum = ReducerAccess{Reducer}.getElement(E);
12861347
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
12871348
if (NDIt.get_local_linear_id() == 0) {
12881349
if (IsUpdateOfUserVar)
@@ -1346,7 +1407,8 @@ struct NDRangeReduction<
13461407
typename Reduction::result_type PSum =
13471408
(HasUniformWG || (GID < NWorkItems))
13481409
? In[GID * NElements + E]
1349-
: Reduction::reducer_type::getIdentity();
1410+
: ReducerAccess<typename Reduction::reducer_type>::
1411+
getIdentityStatic();
13501412
PSum = reduce_over_group(NDIt.get_group(), PSum, BOp);
13511413
if (NDIt.get_local_linear_id() == 0) {
13521414
if (IsUpdateOfUserVar)
@@ -1420,7 +1482,7 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
14201482
for (int E = 0; E < NElements; ++E) {
14211483

14221484
// Copy the element to local memory to prepare it for tree-reduction.
1423-
LocalReds[LID] = Reducer.getElement(E);
1485+
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
14241486

14251487
doTreeReduction(WGSize, LID, IsPow2WG, ReduIdentity, LocalReds, BOp,
14261488
[&]() { NDIt.barrier(); });
@@ -1693,7 +1755,8 @@ void reduCGFuncImplScalar(
16931755
size_t WGSize = NDIt.get_local_range().size();
16941756
size_t LID = NDIt.get_local_linear_id();
16951757

1696-
((std::get<Is>(LocalAccsTuple)[LID] = std::get<Is>(ReducersTuple).MValue),
1758+
((std::get<Is>(LocalAccsTuple)[LID] =
1759+
ReducerAccess{std::get<Is>(ReducersTuple)}.getElement(0)),
16971760
...);
16981761

16991762
// For work-groups, which size is not power of two, local accessors have
@@ -1744,7 +1807,7 @@ void reduCGFuncImplArrayHelper(bool Pow2WG, bool IsOneWG, nd_item<Dims> NDIt,
17441807
for (size_t E = 0; E < NElements; ++E) {
17451808

17461809
// Copy the element to local memory to prepare it for tree-reduction.
1747-
LocalReds[LID] = Reducer.getElement(E);
1810+
LocalReds[LID] = ReducerAccess{Reducer}.getElement(E);
17481811

17491812
doTreeReduction(WGSize, LID, Pow2WG, Identity, LocalReds, BOp,
17501813
[&]() { NDIt.barrier(); });

0 commit comments

Comments
 (0)