@@ -159,6 +159,51 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
159
159
static constexpr size_t extent = Extent;
160
160
};
161
161
162
+ // / Helper class for accessing internal reducer member functions.
163
+ template <typename ReducerT> class ReducerAccess {
164
+ public:
165
+ ReducerAccess (ReducerT &ReducerRef) : MReducerRef(ReducerRef) {}
166
+
167
+ template <typename ReducerRelayT = ReducerT> auto &getElement (size_t E) {
168
+ return MReducerRef.getElement (E);
169
+ }
170
+
171
+ template <typename ReducerRelayT = ReducerT>
172
+ enable_if_t <
173
+ IsKnownIdentityOp<typename ReducerRelayT::value_type,
174
+ typename ReducerRelayT::binary_operation>::value,
175
+ typename ReducerRelayT::value_type> constexpr getIdentity () {
176
+ return getIdentityStatic ();
177
+ }
178
+
179
+ template <typename ReducerRelayT = ReducerT>
180
+ enable_if_t <
181
+ !IsKnownIdentityOp<typename ReducerRelayT::value_type,
182
+ typename ReducerRelayT::binary_operation>::value,
183
+ typename ReducerRelayT::value_type>
184
+ getIdentity () {
185
+ return MReducerRef.identity ();
186
+ }
187
+
188
+ // MSVC does not like static overloads of non-static functions, even if they
189
+ // are made mutually exclusive through SFINAE. Instead we use a new static
190
+ // function to be used when a static function is needed.
191
+ template <typename ReducerRelayT = ReducerT>
192
+ enable_if_t <
193
+ IsKnownIdentityOp<typename ReducerRelayT::value_type,
194
+ typename ReducerRelayT::binary_operation>::value,
195
+ typename ReducerRelayT::value_type> static constexpr getIdentityStatic () {
196
+ return ReducerT::getIdentity ();
197
+ }
198
+
199
+ private:
200
+ ReducerT &MReducerRef;
201
+ };
202
+
203
+ // Deduction guide to simplify the use of ReducerAccess.
204
+ template <typename ReducerT>
205
+ ReducerAccess (ReducerT &) -> ReducerAccess<ReducerT>;
206
+
162
207
/// Use CRTP to avoid redefining shorthand operators in terms of combine
163
208
///
164
209
/// Also, for many types with known identity the operation 'atomic_combine()'
@@ -238,7 +283,7 @@ template <class Reducer> class combiner {
238
283
auto AtomicRef = sycl::atomic_ref<T, memory_order::relaxed,
239
284
getMemoryScope<Space>(), Space>(
240
285
address_space_cast<Space, access::decorated::no>(ReduVarPtr)[E]);
241
- Functor (std::move (AtomicRef), reducer-> getElement (E));
286
+ Functor (std::move (AtomicRef), ReducerAccess{* reducer}. getElement (E));
242
287
}
243
288
}
244
289
@@ -355,13 +400,15 @@ class reducer<
355
400
return *this ;
356
401
}
357
402
358
- T getIdentity () const { return MIdentity; }
403
+ T identity () const { return MIdentity; }
404
+
405
+ private:
406
+ template <typename ReducerT> friend class detail ::ReducerAccess;
359
407
360
408
T &getElement (size_t ) { return MValue; }
361
409
const T &getElement (size_t ) const { return MValue; }
362
- T MValue;
363
410
364
- private:
411
+ T MValue;
365
412
const T MIdentity;
366
413
BinaryOperation MBinaryOp;
367
414
};
@@ -392,7 +439,12 @@ class reducer<
392
439
return *this ;
393
440
}
394
441
395
- static T getIdentity () {
442
+ T identity () const { return getIdentity (); }
443
+
444
+ private:
445
+ template <typename ReducerT> friend class detail ::ReducerAccess;
446
+
447
+ static constexpr T getIdentity () {
396
448
return detail::known_identity_impl<BinaryOperation, T>::value;
397
449
}
398
450
@@ -419,6 +471,8 @@ class reducer<T, BinaryOperation, Dims, Extent, View,
419
471
}
420
472
421
473
private:
474
+ template <typename ReducerT> friend class detail ::ReducerAccess;
475
+
422
476
T &MElement;
423
477
BinaryOperation MBinaryOp;
424
478
};
@@ -444,11 +498,14 @@ class reducer<
444
498
return {MValue[Index], MBinaryOp};
445
499
}
446
500
447
- T getIdentity () const { return MIdentity; }
501
+ T identity () const { return MIdentity; }
502
+
503
+ private:
504
+ template <typename ReducerT> friend class detail ::ReducerAccess;
505
+
448
506
T &getElement (size_t E) { return MValue[E]; }
449
507
const T &getElement (size_t E) const { return MValue[E]; }
450
508
451
- private:
452
509
marray<T, Extent> MValue;
453
510
const T MIdentity;
454
511
BinaryOperation MBinaryOp;
@@ -477,14 +534,18 @@ class reducer<
477
534
return {MValue[Index], BinaryOperation ()};
478
535
}
479
536
480
- static T getIdentity () {
537
+ T identity () const { return getIdentity (); }
538
+
539
+ private:
540
+ template <typename ReducerT> friend class detail ::ReducerAccess;
541
+
542
+ static constexpr T getIdentity () {
481
543
return detail::known_identity_impl<BinaryOperation, T>::value;
482
544
}
483
545
484
546
T &getElement (size_t E) { return MValue[E]; }
485
547
const T &getElement (size_t E) const { return MValue[E]; }
486
548
487
- private:
488
549
marray<T, Extent> MValue;
489
550
};
490
551
@@ -769,8 +830,7 @@ class reduction_impl
769
830
// list of known operations does not break the existing programs.
770
831
if constexpr (is_known_identity) {
771
832
(void )Identity;
772
- return reducer_type::getIdentity ();
773
-
833
+ return ReducerAccess<reducer_type>::getIdentityStatic ();
774
834
} else {
775
835
return Identity;
776
836
}
@@ -788,8 +848,8 @@ class reduction_impl
788
848
template <typename _self = self,
789
849
enable_if_t <_self::is_known_identity> * = nullptr >
790
850
reduction_impl (RedOutVar Var, bool InitializeToIdentity = false )
791
- : algo(reducer_type::getIdentity(), BinaryOperation (),
792
- InitializeToIdentity, Var) {
851
+ : algo(ReducerAccess< reducer_type>::getIdentityStatic (),
852
+ BinaryOperation (), InitializeToIdentity, Var) {
793
853
if constexpr (!is_usm)
794
854
if (Var.size () != 1 )
795
855
throw sycl::runtime_error (errc::invalid,
@@ -896,7 +956,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
896
956
// Work-group cooperates to initialize multiple reduction variables
897
957
auto LID = NDId.get_local_id (0 );
898
958
for (size_t E = LID; E < NElements; E += NDId.get_local_range (0 )) {
899
- GroupSum[E] = Reducer.getIdentity ();
959
+ GroupSum[E] = ReducerAccess ( Reducer) .getIdentity ();
900
960
}
901
961
workGroupBarrier ();
902
962
@@ -909,7 +969,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
909
969
workGroupBarrier ();
910
970
if (LID == 0 ) {
911
971
for (size_t E = 0 ; E < NElements; ++E) {
912
- Reducer.getElement (E) = GroupSum[E];
972
+ ReducerAccess{ Reducer} .getElement (E) = GroupSum[E];
913
973
}
914
974
Reducer.template atomic_combine (&Out[0 ]);
915
975
}
@@ -959,7 +1019,7 @@ struct NDRangeReduction<
959
1019
// reduce_over_group is only defined for each T, not for span<T, ...>
960
1020
size_t LID = NDId.get_local_id (0 );
961
1021
for (int E = 0 ; E < NElements; ++E) {
962
- auto &RedElem = Reducer.getElement (E);
1022
+ auto &RedElem = ReducerAccess{ Reducer} .getElement (E);
963
1023
RedElem = reduce_over_group (Group, RedElem, BOp);
964
1024
if (LID == 0 ) {
965
1025
if (NWorkGroups == 1 ) {
@@ -970,7 +1030,7 @@ struct NDRangeReduction<
970
1030
Out[E] = RedElem;
971
1031
} else {
972
1032
PartialSums[NDId.get_group_linear_id () * NElements + E] =
973
- Reducer.getElement (E);
1033
+ ReducerAccess{ Reducer} .getElement (E);
974
1034
}
975
1035
}
976
1036
}
@@ -993,7 +1053,7 @@ struct NDRangeReduction<
993
1053
// Reduce each result separately
994
1054
// TODO: Opportunity to parallelize across elements.
995
1055
for (int E = 0 ; E < NElements; ++E) {
996
- auto LocalSum = Reducer.getIdentity ();
1056
+ auto LocalSum = ReducerAccess{ Reducer} .getIdentity ();
997
1057
for (size_t I = LID; I < NWorkGroups; I += WGSize)
998
1058
LocalSum = BOp (LocalSum, PartialSums[I * NElements + E]);
999
1059
auto Result = reduce_over_group (Group, LocalSum, BOp);
@@ -1083,7 +1143,7 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
1083
1143
for (int E = 0 ; E < NElements; ++E) {
1084
1144
1085
1145
// Copy the element to local memory to prepare it for tree-reduction.
1086
- LocalReds[LID] = Reducer.getElement (E);
1146
+ LocalReds[LID] = ReducerAccess{ Reducer} .getElement (E);
1087
1147
1088
1148
doTreeReduction (WGSize, LID, false , Identity, LocalReds, BOp,
1089
1149
[&]() { workGroupBarrier (); });
@@ -1158,8 +1218,8 @@ struct NDRangeReduction<reduction::strategy::group_reduce_and_atomic_cross_wg> {
1158
1218
1159
1219
typename Reduction::binary_operation BOp;
1160
1220
for (int E = 0 ; E < NElements; ++E) {
1161
- Reducer.getElement (E) =
1162
- reduce_over_group ( NDIt.get_group (), Reducer.getElement (E), BOp);
1221
+ ReducerAccess{ Reducer} .getElement (E) = reduce_over_group (
1222
+ NDIt.get_group (), ReducerAccess{ Reducer} .getElement (E), BOp);
1163
1223
}
1164
1224
if (NDIt.get_local_linear_id () == 0 )
1165
1225
Reducer.atomic_combine (&Out[0 ]);
@@ -1207,14 +1267,15 @@ struct NDRangeReduction<
1207
1267
for (int E = 0 ; E < NElements; ++E) {
1208
1268
1209
1269
// Copy the element to local memory to prepare it for tree-reduction.
1210
- LocalReds[LID] = Reducer.getElement (E);
1270
+ LocalReds[LID] = ReducerAccess{ Reducer} .getElement (E);
1211
1271
1212
1272
typename Reduction::binary_operation BOp;
1213
- doTreeReduction (WGSize, LID, IsPow2WG, Reducer.getIdentity (),
1214
- LocalReds, BOp, [&]() { NDIt.barrier (); });
1273
+ doTreeReduction (WGSize, LID, IsPow2WG,
1274
+ ReducerAccess{Reducer}.getIdentity (), LocalReds, BOp,
1275
+ [&]() { NDIt.barrier (); });
1215
1276
1216
1277
if (LID == 0 ) {
1217
- Reducer.getElement (E) =
1278
+ ReducerAccess{ Reducer} .getElement (E) =
1218
1279
IsPow2WG ? LocalReds[0 ] : BOp (LocalReds[0 ], LocalReds[WGSize]);
1219
1280
}
1220
1281
@@ -1282,7 +1343,7 @@ struct NDRangeReduction<
1282
1343
typename Reduction::binary_operation BOp;
1283
1344
for (int E = 0 ; E < NElements; ++E) {
1284
1345
typename Reduction::result_type PSum;
1285
- PSum = Reducer.getElement (E);
1346
+ PSum = ReducerAccess{ Reducer} .getElement (E);
1286
1347
PSum = reduce_over_group (NDIt.get_group (), PSum, BOp);
1287
1348
if (NDIt.get_local_linear_id () == 0 ) {
1288
1349
if (IsUpdateOfUserVar)
@@ -1346,7 +1407,8 @@ struct NDRangeReduction<
1346
1407
typename Reduction::result_type PSum =
1347
1408
(HasUniformWG || (GID < NWorkItems))
1348
1409
? In[GID * NElements + E]
1349
- : Reduction::reducer_type::getIdentity ();
1410
+ : ReducerAccess<typename Reduction::reducer_type>::
1411
+ getIdentityStatic ();
1350
1412
PSum = reduce_over_group (NDIt.get_group (), PSum, BOp);
1351
1413
if (NDIt.get_local_linear_id () == 0 ) {
1352
1414
if (IsUpdateOfUserVar)
@@ -1420,7 +1482,7 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
1420
1482
for (int E = 0 ; E < NElements; ++E) {
1421
1483
1422
1484
// Copy the element to local memory to prepare it for tree-reduction.
1423
- LocalReds[LID] = Reducer.getElement (E);
1485
+ LocalReds[LID] = ReducerAccess{ Reducer} .getElement (E);
1424
1486
1425
1487
doTreeReduction (WGSize, LID, IsPow2WG, ReduIdentity, LocalReds, BOp,
1426
1488
[&]() { NDIt.barrier (); });
@@ -1693,7 +1755,8 @@ void reduCGFuncImplScalar(
1693
1755
size_t WGSize = NDIt.get_local_range ().size ();
1694
1756
size_t LID = NDIt.get_local_linear_id ();
1695
1757
1696
- ((std::get<Is>(LocalAccsTuple)[LID] = std::get<Is>(ReducersTuple).MValue ),
1758
+ ((std::get<Is>(LocalAccsTuple)[LID] =
1759
+ ReducerAccess{std::get<Is>(ReducersTuple)}.getElement (0 )),
1697
1760
...);
1698
1761
1699
1762
// For work-groups, which size is not power of two, local accessors have
@@ -1744,7 +1807,7 @@ void reduCGFuncImplArrayHelper(bool Pow2WG, bool IsOneWG, nd_item<Dims> NDIt,
1744
1807
for (size_t E = 0 ; E < NElements; ++E) {
1745
1808
1746
1809
// Copy the element to local memory to prepare it for tree-reduction.
1747
- LocalReds[LID] = Reducer.getElement (E);
1810
+ LocalReds[LID] = ReducerAccess{ Reducer} .getElement (E);
1748
1811
1749
1812
doTreeReduction (WGSize, LID, Pow2WG, Identity, LocalReds, BOp,
1750
1813
[&]() { NDIt.barrier (); });
0 commit comments