@@ -147,18 +147,17 @@ using IsKnownIdentityOp =
147
147
/// Accumulates the partial results produced by one work-item of a reduction.
///
/// This primary template keeps its own copy of the user's binary operation
/// rather than default-constructing one for every combine() call. As a result
/// \c BinaryOperation only has to be copy-constructible, which allows lambdas
/// and functors carrying state.
///
/// \tparam T type of the reduced value.
/// \tparam BinaryOperation callable invoked as BOp(T, T); its result is
///         assigned back to a T.
/// \tparam Subst SFINAE hook that lets partial specializations select
///         themselves (e.g. for operations with a known identity).
template <typename T, class BinaryOperation, typename Subst = void>
class reducer {
public:
  /// \param Identity value the accumulator starts from; also what
  ///        getIdentity() reports.
  /// \param BOp operation applied by combine(); copied into the reducer.
  reducer(const T &Identity, BinaryOperation BOp)
      : MValue(Identity), MIdentity(Identity), MBinaryOp(BOp) {}

  /// Folds \p Partial into the accumulator: MValue = BOp(MValue, Partial).
  void combine(const T &Partial) { MValue = MBinaryOp(MValue, Partial); }

  /// Returns the identity value supplied at construction.
  T getIdentity() const { return MIdentity; }

  /// Running accumulated value; starts as the identity and is updated by
  /// combine().
  T MValue;

private:
  /// Identity value captured at construction; never modified afterwards.
  const T MIdentity;
  /// Copy of the operation passed to the constructor.
  BinaryOperation MBinaryOp;
};
163
162
164
163
/// Specialization of the generic class 'reducer'. It is used for reductions
@@ -183,7 +182,7 @@ class reducer<T, BinaryOperation,
183
182
enable_if_t <IsKnownIdentityOp<T, BinaryOperation>::value>> {
184
183
public:
185
184
reducer () : MValue(getIdentity()) {}
186
- reducer (const T &) : MValue(getIdentity()) {}
185
+ reducer (const T &, BinaryOperation ) : MValue(getIdentity()) {}
187
186
188
187
void combine (const T &Partial) {
189
188
BinaryOperation BOp;
@@ -405,7 +404,7 @@ class reduction_impl {
405
404
template <
406
405
typename _T = T, class _BinaryOperation = BinaryOperation,
407
406
enable_if_t <IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr >
408
- reduction_impl (accessor_type &Acc, const T &Identity)
407
+ reduction_impl (accessor_type &Acc, const T &Identity, BinaryOperation )
409
408
: MAcc(shared_ptr_class<accessor_type>(shared_ptr_class<accessor_type>{},
410
409
&Acc)),
411
410
MIdentity (getIdentity()) {
@@ -431,10 +430,10 @@ class reduction_impl {
431
430
template <
432
431
typename _T = T, class _BinaryOperation = BinaryOperation,
433
432
enable_if_t <!IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr >
434
- reduction_impl (accessor_type &Acc, const T &Identity)
433
+ reduction_impl (accessor_type &Acc, const T &Identity, BinaryOperation BOp )
435
434
: MAcc(shared_ptr_class<accessor_type>(shared_ptr_class<accessor_type>{},
436
435
&Acc)),
437
- MIdentity (Identity) {
436
+ MIdentity (Identity), MBinaryOp(BOp) {
438
437
assert (Acc.get_count () == 1 &&
439
438
" Only scalar/1-element reductions are supported now." );
440
439
}
@@ -456,7 +455,7 @@ class reduction_impl {
456
455
template <
457
456
typename _T = T, class _BinaryOperation = BinaryOperation,
458
457
enable_if_t <IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr >
459
- reduction_impl (T *VarPtr, const T &Identity)
458
+ reduction_impl (T *VarPtr, const T &Identity, BinaryOperation )
460
459
: MIdentity(Identity), MUSMPointer(VarPtr) {
461
460
// For now the implementation ignores the identity value given by user
462
461
// when the implementation knows the identity.
@@ -478,8 +477,8 @@ class reduction_impl {
478
477
template <
479
478
typename _T = T, class _BinaryOperation = BinaryOperation,
480
479
enable_if_t <!IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr >
481
- reduction_impl (T *VarPtr, const T &Identity)
482
- : MIdentity(Identity), MUSMPointer(VarPtr) {}
480
+ reduction_impl (T *VarPtr, const T &Identity, BinaryOperation BOp )
481
+ : MIdentity(Identity), MUSMPointer(VarPtr), MBinaryOp(BOp) {}
483
482
484
483
/// Associates reduction accessor with the given handler and saves reduction
485
484
/// buffer so that it is alive until the command group finishes the work.
@@ -563,6 +562,9 @@ class reduction_impl {
563
562
return OutPtr;
564
563
}
565
564
565
+ // / Returns the binary operation associated with the reduction.
566
+ BinaryOperation getBinaryOperation () const { return MBinaryOp; }
567
+
566
568
private:
567
569
/// Identity of the BinaryOperation.
568
570
/// The result of BinaryOperation(X, MIdentity) is equal to X for any X.
@@ -576,6 +578,8 @@ class reduction_impl {
576
578
/// USM pointer referencing the memory to where the result of the reduction
577
579
/// must be written. Applicable/used only for USM reductions.
578
580
T *MUSMPointer = nullptr ;
581
+
582
+ BinaryOperation MBinaryOp;
579
583
};
580
584
581
585
/// These are the forward declarations for the classes that help to create
@@ -794,9 +798,10 @@ reduCGFuncImpl(handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
794
798
typename Reduction::result_type ReduIdentity = Redu.getIdentity ();
795
799
using Name = typename get_reduction_main_kernel_name_t <
796
800
KernelName, KernelType, Reduction::is_usm, UniformPow2WG, OutputT>::name;
801
+ auto BOp = Redu.getBinaryOperation ();
797
802
CGH.parallel_for <Name>(Range, [=](nd_item<Dims> NDIt) {
798
803
// Call user's functions. Reducer.MValue gets initialized there.
799
- typename Reduction::reducer_type Reducer (ReduIdentity);
804
+ typename Reduction::reducer_type Reducer (ReduIdentity, BOp );
800
805
KernelFunc (NDIt, Reducer);
801
806
802
807
size_t WGSize = NDIt.get_local_range ().size ();
@@ -811,7 +816,6 @@ reduCGFuncImpl(handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
811
816
// Tree-reduction: reduce the local array LocalReds[:] to LocalReds[0]
812
817
// LocalReds[WGSize] accumulates last/odd elements when the step
813
818
// of tree-reduction loop is not even.
814
- typename Reduction::binary_operation BOp;
815
819
size_t PrevStep = WGSize;
816
820
for (size_t CurStep = PrevStep >> 1 ; CurStep > 0 ; CurStep >>= 1 ) {
817
821
if (LID < CurStep)
@@ -925,6 +929,7 @@ reduAuxCGFuncImpl(handler &CGH, size_t NWorkItems, size_t NWorkGroups,
925
929
auto LocalReds = Redu.getReadWriteLocalAcc (NumLocalElements, CGH);
926
930
927
931
auto ReduIdentity = Redu.getIdentity ();
932
+ auto BOp = Redu.getBinaryOperation ();
928
933
using Name = typename get_reduction_aux_kernel_name_t <
929
934
KernelName, KernelType, Reduction::is_usm, UniformPow2WG, OutputT>::name;
930
935
nd_range<1 > Range{range<1 >(NWorkItems), range<1 >(WGSize)};
@@ -943,7 +948,6 @@ reduAuxCGFuncImpl(handler &CGH, size_t NWorkItems, size_t NWorkGroups,
943
948
// Tree-reduction: reduce the local array LocalReds[:] to LocalReds[0]
944
949
// LocalReds[WGSize] accumulates last/odd elements when the step
945
950
// of tree-reduction loop is not even.
946
- typename Reduction::binary_operation BOp;
947
951
size_t PrevStep = WGSize;
948
952
for (size_t CurStep = PrevStep >> 1 ; CurStep > 0 ; CurStep >>= 1 ) {
949
953
if (LID < CurStep)
@@ -1022,10 +1026,10 @@ template <typename T, class BinaryOperation, int Dims, access::mode AccMode,
1022
1026
access::placeholder IsPH>
1023
1027
detail::reduction_impl<T, BinaryOperation, Dims, false , AccMode, IsPH>
1024
1028
reduction (accessor<T, Dims, AccMode, access::target::global_buffer, IsPH> &Acc,
1025
- const T &Identity, BinaryOperation) {
1029
+ const T &Identity, BinaryOperation BOp ) {
1026
1030
// The binary operation object is forwarded to reduction_impl together with
// the accessor and the identity value.
1027
1031
return detail::reduction_impl<T, BinaryOperation, Dims, false , AccMode, IsPH>(
1028
- Acc, Identity);
1032
+ Acc, Identity, BOp );
1029
1033
}
1030
1034
1031
1035
/// Creates and returns an object implementing the reduction functionality.
@@ -1050,9 +1054,10 @@ reduction(accessor<T, Dims, AccMode, access::target::global_buffer, IsPH> &Acc,
1050
1054
/// \param Identity, and the binary operation used in the reduction.
1051
1055
template <typename T, class BinaryOperation >
1052
1056
detail::reduction_impl<T, BinaryOperation, 0 , true , access::mode::read_write>
1053
- reduction (T *VarPtr, const T &Identity, BinaryOperation) {
1057
+ reduction (T *VarPtr, const T &Identity, BinaryOperation BOp ) {
1054
1058
return detail::reduction_impl<T, BinaryOperation, 0 , true ,
1055
- access::mode::read_write>(VarPtr, Identity);
1059
+ access::mode::read_write>(VarPtr, Identity,
1060
+ BOp);
1056
1061
}
1057
1062
1058
1063
/// Creates and returns an object implementing the reduction functionality.
0 commit comments