Skip to content

Commit 6e0b92e

Browse files
committed
[SYCL] Support lambda functions passed to reduction
Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent aaa575d commit 6e0b92e

File tree

2 files changed

+95
-19
lines changed

2 files changed

+95
-19
lines changed

sycl/include/CL/sycl/intel/reduction.hpp

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,17 @@ using IsKnownIdentityOp =
147147
template <typename T, class BinaryOperation, typename Subst = void>
148148
class reducer {
149149
public:
150-
reducer(const T &Identity) : MValue(Identity), MIdentity(Identity) {}
151-
void combine(const T &Partial) {
152-
BinaryOperation BOp;
153-
MValue = BOp(MValue, Partial);
154-
}
150+
reducer(const T &Identity, BinaryOperation BOp)
151+
: MValue(Identity), MIdentity(Identity), MBinaryOp(BOp) {}
152+
void combine(const T &Partial) { MValue = MBinaryOp(MValue, Partial); }
155153

156154
T getIdentity() const { return MIdentity; }
157155

158156
T MValue;
159157

160158
private:
161159
const T MIdentity;
160+
BinaryOperation MBinaryOp;
162161
};
163162

164163
/// Specialization of the generic class 'reducer'. It is used for reductions
@@ -183,7 +182,7 @@ class reducer<T, BinaryOperation,
183182
enable_if_t<IsKnownIdentityOp<T, BinaryOperation>::value>> {
184183
public:
185184
reducer() : MValue(getIdentity()) {}
186-
reducer(const T &) : MValue(getIdentity()) {}
185+
reducer(const T &, BinaryOperation) : MValue(getIdentity()) {}
187186

188187
void combine(const T &Partial) {
189188
BinaryOperation BOp;
@@ -405,7 +404,7 @@ class reduction_impl {
405404
template <
406405
typename _T = T, class _BinaryOperation = BinaryOperation,
407406
enable_if_t<IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr>
408-
reduction_impl(accessor_type &Acc, const T &Identity)
407+
reduction_impl(accessor_type &Acc, const T &Identity, BinaryOperation)
409408
: MAcc(shared_ptr_class<accessor_type>(shared_ptr_class<accessor_type>{},
410409
&Acc)),
411410
MIdentity(getIdentity()) {
@@ -431,10 +430,10 @@ class reduction_impl {
431430
template <
432431
typename _T = T, class _BinaryOperation = BinaryOperation,
433432
enable_if_t<!IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr>
434-
reduction_impl(accessor_type &Acc, const T &Identity)
433+
reduction_impl(accessor_type &Acc, const T &Identity, BinaryOperation BOp)
435434
: MAcc(shared_ptr_class<accessor_type>(shared_ptr_class<accessor_type>{},
436435
&Acc)),
437-
MIdentity(Identity) {
436+
MIdentity(Identity), MBinaryOp(BOp) {
438437
assert(Acc.get_count() == 1 &&
439438
"Only scalar/1-element reductions are supported now.");
440439
}
@@ -456,7 +455,7 @@ class reduction_impl {
456455
template <
457456
typename _T = T, class _BinaryOperation = BinaryOperation,
458457
enable_if_t<IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr>
459-
reduction_impl(T *VarPtr, const T &Identity)
458+
reduction_impl(T *VarPtr, const T &Identity, BinaryOperation)
460459
: MIdentity(Identity), MUSMPointer(VarPtr) {
461460
// For now the implementation ignores the identity value given by user
462461
// when the implementation knows the identity.
@@ -478,8 +477,8 @@ class reduction_impl {
478477
template <
479478
typename _T = T, class _BinaryOperation = BinaryOperation,
480479
enable_if_t<!IsKnownIdentityOp<_T, _BinaryOperation>::value> * = nullptr>
481-
reduction_impl(T *VarPtr, const T &Identity)
482-
: MIdentity(Identity), MUSMPointer(VarPtr) {}
480+
reduction_impl(T *VarPtr, const T &Identity, BinaryOperation BOp)
481+
: MIdentity(Identity), MUSMPointer(VarPtr), MBinaryOp(BOp) {}
483482

484483
/// Associates reduction accessor with the given handler and saves reduction
485484
/// buffer so that it is alive until the command group finishes the work.
@@ -563,6 +562,9 @@ class reduction_impl {
563562
return OutPtr;
564563
}
565564

565+
/// Returns the binary operation associated with the reduction.
566+
BinaryOperation getBinaryOperation() const { return MBinaryOp; }
567+
566568
private:
567569
/// Identity of the BinaryOperation.
568570
/// The result of BinaryOperation(X, MIdentity) is equal to X for any X.
@@ -576,6 +578,8 @@ class reduction_impl {
576578
/// USM pointer referencing the memory to where the result of the reduction
577579
/// must be written. Applicable/used only for USM reductions.
578580
T *MUSMPointer = nullptr;
581+
582+
BinaryOperation MBinaryOp;
579583
};
580584

581585
/// These are the forward declaration for the classes that help to create
@@ -794,9 +798,10 @@ reduCGFuncImpl(handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
794798
typename Reduction::result_type ReduIdentity = Redu.getIdentity();
795799
using Name = typename get_reduction_main_kernel_name_t<
796800
KernelName, KernelType, Reduction::is_usm, UniformPow2WG, OutputT>::name;
801+
auto BOp = Redu.getBinaryOperation();
797802
CGH.parallel_for<Name>(Range, [=](nd_item<Dims> NDIt) {
798803
// Call user's functions. Reducer.MValue gets initialized there.
799-
typename Reduction::reducer_type Reducer(ReduIdentity);
804+
typename Reduction::reducer_type Reducer(ReduIdentity, BOp);
800805
KernelFunc(NDIt, Reducer);
801806

802807
size_t WGSize = NDIt.get_local_range().size();
@@ -811,7 +816,6 @@ reduCGFuncImpl(handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
811816
// Tree-reduction: reduce the local array LocalReds[:] to LocalReds[0]
812817
// LocalReds[WGSize] accumulates last/odd elements when the step
813818
// of tree-reduction loop is not even.
814-
typename Reduction::binary_operation BOp;
815819
size_t PrevStep = WGSize;
816820
for (size_t CurStep = PrevStep >> 1; CurStep > 0; CurStep >>= 1) {
817821
if (LID < CurStep)
@@ -925,6 +929,7 @@ reduAuxCGFuncImpl(handler &CGH, size_t NWorkItems, size_t NWorkGroups,
925929
auto LocalReds = Redu.getReadWriteLocalAcc(NumLocalElements, CGH);
926930

927931
auto ReduIdentity = Redu.getIdentity();
932+
auto BOp = Redu.getBinaryOperation();
928933
using Name = typename get_reduction_aux_kernel_name_t<
929934
KernelName, KernelType, Reduction::is_usm, UniformPow2WG, OutputT>::name;
930935
nd_range<1> Range{range<1>(NWorkItems), range<1>(WGSize)};
@@ -943,7 +948,6 @@ reduAuxCGFuncImpl(handler &CGH, size_t NWorkItems, size_t NWorkGroups,
943948
// Tree-reduction: reduce the local array LocalReds[:] to LocalReds[0]
944949
// LocalReds[WGSize] accumulates last/odd elements when the step
945950
// of tree-reduction loop is not even.
946-
typename Reduction::binary_operation BOp;
947951
size_t PrevStep = WGSize;
948952
for (size_t CurStep = PrevStep >> 1; CurStep > 0; CurStep >>= 1) {
949953
if (LID < CurStep)
@@ -1022,10 +1026,10 @@ template <typename T, class BinaryOperation, int Dims, access::mode AccMode,
10221026
access::placeholder IsPH>
10231027
detail::reduction_impl<T, BinaryOperation, Dims, false, AccMode, IsPH>
10241028
reduction(accessor<T, Dims, AccMode, access::target::global_buffer, IsPH> &Acc,
1025-
const T &Identity, BinaryOperation) {
1029+
const T &Identity, BinaryOperation BOp) {
10261030
// The Combiner argument was needed only to define the BinaryOperation param.
10271031
return detail::reduction_impl<T, BinaryOperation, Dims, false, AccMode, IsPH>(
1028-
Acc, Identity);
1032+
Acc, Identity, BOp);
10291033
}
10301034

10311035
/// Creates and returns an object implementing the reduction functionality.
@@ -1050,9 +1054,10 @@ reduction(accessor<T, Dims, AccMode, access::target::global_buffer, IsPH> &Acc,
10501054
/// \param Identity, and the binary operation used in the reduction.
10511055
template <typename T, class BinaryOperation>
10521056
detail::reduction_impl<T, BinaryOperation, 0, true, access::mode::read_write>
1053-
reduction(T *VarPtr, const T &Identity, BinaryOperation) {
1057+
reduction(T *VarPtr, const T &Identity, BinaryOperation BOp) {
10541058
return detail::reduction_impl<T, BinaryOperation, 0, true,
1055-
access::mode::read_write>(VarPtr, Identity);
1059+
access::mode::read_write>(VarPtr, Identity,
1060+
BOp);
10561061
}
10571062

10581063
/// Creates and returns an object implementing the reduction functionality.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// UNSUPPORTED: cuda
2+
// OpenCL C 2.x alike work-group functions not yet supported by CUDA.
3+
//
4+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
5+
// RUNx: env SYCL_DEVICE_TYPE=HOST %t.out
6+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
7+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
8+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
9+
10+
// This test performs basic checks of parallel_for(nd_range, reduction, lambda)
11+
12+
#include "reduction_utils.hpp"
13+
#include <CL/sycl.hpp>
14+
#include <cassert>
15+
16+
using namespace cl::sycl;
17+
18+
template <class KernelName, typename T, class BinaryOperation>
19+
void test(T Identity, BinaryOperation BOp, size_t WGSize, size_t NWItems) {
20+
buffer<T, 1> InBuf(NWItems);
21+
buffer<T, 1> OutBuf(1);
22+
23+
// Initialize.
24+
T CorrectOut;
25+
initInputData(InBuf, CorrectOut, Identity, BOp, NWItems);
26+
27+
// Compute.
28+
queue Q;
29+
Q.submit([&](handler &CGH) {
30+
auto In = InBuf.template get_access<access::mode::read>(CGH);
31+
auto Out = OutBuf.template get_access<access::mode::discard_write>(CGH);
32+
auto Redu = intel::reduction(Out, Identity, BOp);
33+
34+
range<1> GlobalRange(NWItems);
35+
range<1> LocalRange(WGSize);
36+
nd_range<1> NDRange(GlobalRange, LocalRange);
37+
CGH.parallel_for<KernelName>(NDRange, Redu,
38+
[=](nd_item<1> NDIt, auto &Sum) {
39+
Sum.combine(In[NDIt.get_global_linear_id()]);
40+
});
41+
});
42+
43+
// Check correctness.
44+
auto Out = OutBuf.template get_access<access::mode::read>();
45+
T ComputedOut = *(Out.get_pointer());
46+
if (ComputedOut != CorrectOut) {
47+
std::cout << "NWItems = " << NWItems << ", WGSize = " << WGSize << "\n";
48+
std::cout << "Computed value: " << ComputedOut
49+
<< ", Expected value: " << CorrectOut << "\n";
50+
assert(0 && "Wrong value.");
51+
}
52+
}
53+
54+
int main() {
55+
test<class AddTestName, int>(
56+
0, [](auto x, auto y) { return (x + y); }, 8, 32);
57+
test<class MulTestName, int>(
58+
0, [](auto x, auto y) { return (x * y); }, 8, 32);
59+
60+
// Check with CUSTOM type.
61+
test<class CustomAddTestname, CustomVec<long long>>(
62+
CustomVec<long long>(0),
63+
[](auto x, auto y) {
64+
CustomVecPlus<long long> BOp;
65+
return BOp(x, y);
66+
},
67+
4, 64);
68+
69+
std::cout << "Test passed\n";
70+
return 0;
71+
}

0 commit comments

Comments
 (0)