@@ -159,51 +159,6 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
159
159
static constexpr size_t extent = Extent;
160
160
};
161
161
162
- // / Helper class for accessing internal reducer member functions.
163
- template <typename ReducerT> class ReducerAccess {
164
- public:
165
- ReducerAccess (ReducerT &ReducerRef) : MReducerRef(ReducerRef) {}
166
-
167
- template <typename ReducerRelayT = ReducerT> auto &getElement (size_t E) {
168
- return MReducerRef.getElement (E);
169
- }
170
-
171
- template <typename ReducerRelayT = ReducerT>
172
- enable_if_t <
173
- IsKnownIdentityOp<typename ReducerRelayT::value_type,
174
- typename ReducerRelayT::binary_operation>::value,
175
- typename ReducerRelayT::value_type> constexpr getIdentity () {
176
- return getIdentityStatic ();
177
- }
178
-
179
- template <typename ReducerRelayT = ReducerT>
180
- enable_if_t <
181
- !IsKnownIdentityOp<typename ReducerRelayT::value_type,
182
- typename ReducerRelayT::binary_operation>::value,
183
- typename ReducerRelayT::value_type>
184
- getIdentity () {
185
- return MReducerRef.identity ();
186
- }
187
-
188
- // MSVC does not like static overloads of non-static functions, even if they
189
- // are made mutually exclusive through SFINAE. Instead we use a new static
190
- // function to be used when a static function is needed.
191
- template <typename ReducerRelayT = ReducerT>
192
- enable_if_t <
193
- IsKnownIdentityOp<typename ReducerRelayT::value_type,
194
- typename ReducerRelayT::binary_operation>::value,
195
- typename ReducerRelayT::value_type> static constexpr getIdentityStatic () {
196
- return ReducerT::getIdentity ();
197
- }
198
-
199
- private:
200
- ReducerT &MReducerRef;
201
- };
202
-
203
- // Deduction guide to simplify the use of ReducerAccess.
204
- template <typename ReducerT>
205
- ReducerAccess (ReducerT &) -> ReducerAccess<ReducerT>;
206
-
207
162
// / Use CRTP to avoid redefining shorthand operators in terms of combine
208
163
// /
209
164
// / Also, for many types with known identity the operation 'atomic_combine()'
@@ -283,7 +238,7 @@ template <class Reducer> class combiner {
283
238
auto AtomicRef = sycl::atomic_ref<T, memory_order::relaxed,
284
239
getMemoryScope<Space>(), Space>(
285
240
address_space_cast<Space, access::decorated::no>(ReduVarPtr)[E]);
286
- Functor (std::move (AtomicRef), ReducerAccess{* reducer}. getElement (E));
241
+ Functor (std::move (AtomicRef), reducer-> getElement (E));
287
242
}
288
243
}
289
244
@@ -400,15 +355,13 @@ class reducer<
400
355
return *this ;
401
356
}
402
357
403
- T identity () const { return MIdentity; }
404
-
405
- private:
406
- template <typename ReducerT> friend class detail ::ReducerAccess;
358
+ T getIdentity () const { return MIdentity; }
407
359
408
360
T &getElement (size_t ) { return MValue; }
409
361
const T &getElement (size_t ) const { return MValue; }
410
-
411
362
T MValue;
363
+
364
+ private:
412
365
const T MIdentity;
413
366
BinaryOperation MBinaryOp;
414
367
};
@@ -439,12 +392,7 @@ class reducer<
439
392
return *this ;
440
393
}
441
394
442
- T identity () const { return getIdentity (); }
443
-
444
- private:
445
- template <typename ReducerT> friend class detail ::ReducerAccess;
446
-
447
- static constexpr T getIdentity () {
395
+ static T getIdentity () {
448
396
return detail::known_identity_impl<BinaryOperation, T>::value;
449
397
}
450
398
@@ -471,8 +419,6 @@ class reducer<T, BinaryOperation, Dims, Extent, View,
471
419
}
472
420
473
421
private:
474
- template <typename ReducerT> friend class detail ::ReducerAccess;
475
-
476
422
T &MElement;
477
423
BinaryOperation MBinaryOp;
478
424
};
@@ -498,14 +444,11 @@ class reducer<
498
444
return {MValue[Index], MBinaryOp};
499
445
}
500
446
501
- T identity () const { return MIdentity; }
502
-
503
- private:
504
- template <typename ReducerT> friend class detail ::ReducerAccess;
505
-
447
+ T getIdentity () const { return MIdentity; }
506
448
T &getElement (size_t E) { return MValue[E]; }
507
449
const T &getElement (size_t E) const { return MValue[E]; }
508
450
451
+ private:
509
452
marray<T, Extent> MValue;
510
453
const T MIdentity;
511
454
BinaryOperation MBinaryOp;
@@ -534,18 +477,14 @@ class reducer<
534
477
return {MValue[Index], BinaryOperation ()};
535
478
}
536
479
537
- T identity () const { return getIdentity (); }
538
-
539
- private:
540
- template <typename ReducerT> friend class detail ::ReducerAccess;
541
-
542
- static constexpr T getIdentity () {
480
+ static T getIdentity () {
543
481
return detail::known_identity_impl<BinaryOperation, T>::value;
544
482
}
545
483
546
484
T &getElement (size_t E) { return MValue[E]; }
547
485
const T &getElement (size_t E) const { return MValue[E]; }
548
486
487
+ private:
549
488
marray<T, Extent> MValue;
550
489
};
551
490
@@ -830,7 +769,8 @@ class reduction_impl
830
769
// list of known operations does not break the existing programs.
831
770
if constexpr (is_known_identity) {
832
771
(void )Identity;
833
- return ReducerAccess<reducer_type>::getIdentityStatic ();
772
+ return reducer_type::getIdentity ();
773
+
834
774
} else {
835
775
return Identity;
836
776
}
@@ -848,8 +788,8 @@ class reduction_impl
848
788
template <typename _self = self,
849
789
enable_if_t <_self::is_known_identity> * = nullptr >
850
790
reduction_impl (RedOutVar Var, bool InitializeToIdentity = false )
851
- : algo(ReducerAccess< reducer_type>::getIdentityStatic (),
852
- BinaryOperation (), InitializeToIdentity, Var) {
791
+ : algo(reducer_type::getIdentity(), BinaryOperation (),
792
+ InitializeToIdentity, Var) {
853
793
if constexpr (!is_usm)
854
794
if (Var.size () != 1 )
855
795
throw sycl::runtime_error (errc::invalid,
@@ -956,7 +896,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
956
896
// Work-group cooperates to initialize multiple reduction variables
957
897
auto LID = NDId.get_local_id (0 );
958
898
for (size_t E = LID; E < NElements; E += NDId.get_local_range (0 )) {
959
- GroupSum[E] = ReducerAccess ( Reducer) .getIdentity ();
899
+ GroupSum[E] = Reducer.getIdentity ();
960
900
}
961
901
workGroupBarrier ();
962
902
@@ -969,7 +909,7 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
969
909
workGroupBarrier ();
970
910
if (LID == 0 ) {
971
911
for (size_t E = 0 ; E < NElements; ++E) {
972
- ReducerAccess{ Reducer} .getElement (E) = GroupSum[E];
912
+ Reducer.getElement (E) = GroupSum[E];
973
913
}
974
914
Reducer.template atomic_combine (&Out[0 ]);
975
915
}
@@ -1019,7 +959,7 @@ struct NDRangeReduction<
1019
959
// reduce_over_group is only defined for each T, not for span<T, ...>
1020
960
size_t LID = NDId.get_local_id (0 );
1021
961
for (int E = 0 ; E < NElements; ++E) {
1022
- auto &RedElem = ReducerAccess{ Reducer} .getElement (E);
962
+ auto &RedElem = Reducer.getElement (E);
1023
963
RedElem = reduce_over_group (Group, RedElem, BOp);
1024
964
if (LID == 0 ) {
1025
965
if (NWorkGroups == 1 ) {
@@ -1030,7 +970,7 @@ struct NDRangeReduction<
1030
970
Out[E] = RedElem;
1031
971
} else {
1032
972
PartialSums[NDId.get_group_linear_id () * NElements + E] =
1033
- ReducerAccess{ Reducer} .getElement (E);
973
+ Reducer.getElement (E);
1034
974
}
1035
975
}
1036
976
}
@@ -1053,7 +993,7 @@ struct NDRangeReduction<
1053
993
// Reduce each result separately
1054
994
// TODO: Opportunity to parallelize across elements.
1055
995
for (int E = 0 ; E < NElements; ++E) {
1056
- auto LocalSum = ReducerAccess{ Reducer} .getIdentity ();
996
+ auto LocalSum = Reducer.getIdentity ();
1057
997
for (size_t I = LID; I < NWorkGroups; I += WGSize)
1058
998
LocalSum = BOp (LocalSum, PartialSums[I * NElements + E]);
1059
999
auto Result = reduce_over_group (Group, LocalSum, BOp);
@@ -1143,7 +1083,7 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
1143
1083
for (int E = 0 ; E < NElements; ++E) {
1144
1084
1145
1085
// Copy the element to local memory to prepare it for tree-reduction.
1146
- LocalReds[LID] = ReducerAccess{ Reducer} .getElement (E);
1086
+ LocalReds[LID] = Reducer.getElement (E);
1147
1087
1148
1088
doTreeReduction (WGSize, LID, false , Identity, LocalReds, BOp,
1149
1089
[&]() { workGroupBarrier (); });
@@ -1218,8 +1158,8 @@ struct NDRangeReduction<reduction::strategy::group_reduce_and_atomic_cross_wg> {
1218
1158
1219
1159
typename Reduction::binary_operation BOp;
1220
1160
for (int E = 0 ; E < NElements; ++E) {
1221
- ReducerAccess{ Reducer} .getElement (E) = reduce_over_group (
1222
- NDIt.get_group (), ReducerAccess{ Reducer} .getElement (E), BOp);
1161
+ Reducer.getElement (E) =
1162
+ reduce_over_group ( NDIt.get_group (), Reducer.getElement (E), BOp);
1223
1163
}
1224
1164
if (NDIt.get_local_linear_id () == 0 )
1225
1165
Reducer.atomic_combine (&Out[0 ]);
@@ -1267,15 +1207,14 @@ struct NDRangeReduction<
1267
1207
for (int E = 0 ; E < NElements; ++E) {
1268
1208
1269
1209
// Copy the element to local memory to prepare it for tree-reduction.
1270
- LocalReds[LID] = ReducerAccess{ Reducer} .getElement (E);
1210
+ LocalReds[LID] = Reducer.getElement (E);
1271
1211
1272
1212
typename Reduction::binary_operation BOp;
1273
- doTreeReduction (WGSize, LID, IsPow2WG,
1274
- ReducerAccess{Reducer}.getIdentity (), LocalReds, BOp,
1275
- [&]() { NDIt.barrier (); });
1213
+ doTreeReduction (WGSize, LID, IsPow2WG, Reducer.getIdentity (),
1214
+ LocalReds, BOp, [&]() { NDIt.barrier (); });
1276
1215
1277
1216
if (LID == 0 ) {
1278
- ReducerAccess{ Reducer} .getElement (E) =
1217
+ Reducer.getElement (E) =
1279
1218
IsPow2WG ? LocalReds[0 ] : BOp (LocalReds[0 ], LocalReds[WGSize]);
1280
1219
}
1281
1220
@@ -1343,7 +1282,7 @@ struct NDRangeReduction<
1343
1282
typename Reduction::binary_operation BOp;
1344
1283
for (int E = 0 ; E < NElements; ++E) {
1345
1284
typename Reduction::result_type PSum;
1346
- PSum = ReducerAccess{ Reducer} .getElement (E);
1285
+ PSum = Reducer.getElement (E);
1347
1286
PSum = reduce_over_group (NDIt.get_group (), PSum, BOp);
1348
1287
if (NDIt.get_local_linear_id () == 0 ) {
1349
1288
if (IsUpdateOfUserVar)
@@ -1407,8 +1346,7 @@ struct NDRangeReduction<
1407
1346
typename Reduction::result_type PSum =
1408
1347
(HasUniformWG || (GID < NWorkItems))
1409
1348
? In[GID * NElements + E]
1410
- : ReducerAccess<typename Reduction::reducer_type>::
1411
- getIdentityStatic ();
1349
+ : Reduction::reducer_type::getIdentity ();
1412
1350
PSum = reduce_over_group (NDIt.get_group (), PSum, BOp);
1413
1351
if (NDIt.get_local_linear_id () == 0 ) {
1414
1352
if (IsUpdateOfUserVar)
@@ -1482,7 +1420,7 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
1482
1420
for (int E = 0 ; E < NElements; ++E) {
1483
1421
1484
1422
// Copy the element to local memory to prepare it for tree-reduction.
1485
- LocalReds[LID] = ReducerAccess{ Reducer} .getElement (E);
1423
+ LocalReds[LID] = Reducer.getElement (E);
1486
1424
1487
1425
doTreeReduction (WGSize, LID, IsPow2WG, ReduIdentity, LocalReds, BOp,
1488
1426
[&]() { NDIt.barrier (); });
@@ -1755,8 +1693,7 @@ void reduCGFuncImplScalar(
1755
1693
size_t WGSize = NDIt.get_local_range ().size ();
1756
1694
size_t LID = NDIt.get_local_linear_id ();
1757
1695
1758
- ((std::get<Is>(LocalAccsTuple)[LID] =
1759
- ReducerAccess{std::get<Is>(ReducersTuple)}.getElement (0 )),
1696
+ ((std::get<Is>(LocalAccsTuple)[LID] = std::get<Is>(ReducersTuple).MValue ),
1760
1697
...);
1761
1698
1762
1699
// For work-groups, which size is not power of two, local accessors have
@@ -1807,7 +1744,7 @@ void reduCGFuncImplArrayHelper(bool Pow2WG, bool IsOneWG, nd_item<Dims> NDIt,
1807
1744
for (size_t E = 0 ; E < NElements; ++E) {
1808
1745
1809
1746
// Copy the element to local memory to prepare it for tree-reduction.
1810
- LocalReds[LID] = ReducerAccess{ Reducer} .getElement (E);
1747
+ LocalReds[LID] = Reducer.getElement (E);
1811
1748
1812
1749
doTreeReduction (WGSize, LID, Pow2WG, Identity, LocalReds, BOp,
1813
1750
[&]() { NDIt.barrier (); });
0 commit comments