Skip to content

Commit 068f6a9

Browse files
committed
Additional fix in reduction LIT test to be in sync with patch enabling lambdas for reduction
Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent 6e0b92e commit 068f6a9

File tree

1 file changed

+33
-60
lines changed

1 file changed

+33
-60
lines changed

sycl/test/reduction/reduction_ctor.cpp

Lines changed: 33 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,12 @@ void test_reducer(Reduction &Redu, T A, T B) {
2323
"Wrong result of binary operation.");
2424
}
2525

26-
template <typename T, typename Reduction>
27-
void test_reducer(Reduction &Redu, T Identity, T A, T B) {
28-
typename Reduction::reducer_type Reducer(Identity);
26+
template <typename T, typename Reduction, typename BinaryOperation>
27+
void test_reducer(Reduction &Redu, T Identity, BinaryOperation BOp, T A, T B) {
28+
typename Reduction::reducer_type Reducer(Identity, BOp);
2929
Reducer.combine(A);
3030
Reducer.combine(B);
3131

32-
typename Reduction::binary_operation BOp;
3332
T ExpectedValue = BOp(A, B);
3433
assert(ExpectedValue == Reducer.MValue &&
3534
"Wrong result of binary operation.");
@@ -40,35 +39,8 @@ class Known;
4039
template <typename T, int Dim, class BinaryOperation>
4140
class Unknown;
4241

43-
template <typename T>
44-
struct Point {
45-
Point() : X(0), Y(0) {}
46-
Point(T X, T Y) : X(X), Y(Y) {}
47-
Point(T V) : X(V), Y(V) {}
48-
bool operator==(const Point &P) const {
49-
return P.X == X && P.Y == Y;
50-
}
51-
T X;
52-
T Y;
53-
};
54-
55-
template <typename T>
56-
bool operator==(const Point<T> &A, const Point<T> &B) {
57-
return A.X == B.X && A.Y == B.Y;
58-
}
59-
60-
template <class T>
61-
struct PointPlus {
62-
using P = Point<T>;
63-
P operator()(const P &A, const P &B) const {
64-
return P(A.X + B.X, A.Y + B.Y);
65-
}
66-
};
67-
6842
template <typename T, int Dim, class BinaryOperation>
69-
void testKnown(T Identity, T A, T B) {
70-
71-
BinaryOperation BOp;
43+
void testKnown(T Identity, BinaryOperation BOp, T A, T B) {
7244
buffer<T, 1> ReduBuf(1);
7345

7446
queue Q;
@@ -81,17 +53,15 @@ void testKnown(T Identity, T A, T B) {
8153
assert(Redu.getIdentity() == Identity &&
8254
"Failed getIdentity() check().");
8355
test_reducer(Redu, A, B);
84-
test_reducer(Redu, Identity, A, B);
56+
test_reducer(Redu, Identity, BOp, A, B);
8557

8658
// Command group must have at least one task in it. Use an empty one.
8759
CGH.single_task<Known<T, Dim, BinaryOperation>>([=]() {});
8860
});
8961
}
9062

91-
template <typename T, int Dim, class BinaryOperation>
92-
void testUnknown(T Identity, T A, T B) {
93-
94-
BinaryOperation BOp;
63+
template <typename T, int Dim, typename KernelName, class BinaryOperation>
64+
void testUnknown(T Identity, BinaryOperation BOp, T A, T B) {
9565
buffer<T, 1> ReduBuf(1);
9666
queue Q;
9767
Q.submit([&](handler &CGH) {
@@ -102,38 +72,41 @@ void testUnknown(T Identity, T A, T B) {
10272
auto Redu = intel::reduction(ReduAcc, Identity, BOp);
10373
assert(Redu.getIdentity() == Identity &&
10474
"Failed getIdentity() check().");
105-
test_reducer(Redu, Identity, A, B);
75+
test_reducer(Redu, Identity, BOp, A, B);
10676

10777
// Command group must have at least one task in it. Use an empty one.
108-
CGH.single_task<Unknown<T, Dim, BinaryOperation>>([=]() {});
78+
CGH.single_task<KernelName>([=]() {});
10979
});
11080
}
11181

11282
template <typename T, class BinaryOperation>
113-
void testBoth(T Identity, T A, T B) {
114-
testKnown<T, 0, BinaryOperation>(Identity, A, B);
115-
testKnown<T, 1, BinaryOperation>(Identity, A, B);
116-
testUnknown<T, 0, BinaryOperation>(Identity, A, B);
117-
testUnknown<T, 1, BinaryOperation>(Identity, A, B);
83+
void testBoth(T Identity, BinaryOperation BOp, T A, T B) {
84+
testKnown<T, 0 >(Identity, BOp, A, B);
85+
testKnown<T, 1>(Identity, BOp, A, B);
86+
testUnknown<T, 0, Unknown<T, 0, BinaryOperation>>(Identity, BOp, A, B);
87+
testUnknown<T, 1, Unknown<T, 1, BinaryOperation>>(Identity, BOp, A, B);
11888
}
11989

12090
int main() {
121-
// testKnown does not pass identity to reduction ctor.
122-
testBoth<int, intel::plus<int>>(0, 1, 7);
123-
testBoth<int, std::multiplies<int>>(1, 1, 7);
124-
testBoth<int, intel::bit_or<int>>(0, 1, 8);
125-
testBoth<int, intel::bit_xor<int>>(0, 7, 3);
126-
testBoth<int, intel::bit_and<int>>(~0, 7, 3);
127-
testBoth<int, intel::minimum<int>>((std::numeric_limits<int>::max)(), 7, 3);
128-
testBoth<int, intel::maximum<int>>((std::numeric_limits<int>::min)(), 7, 3);
129-
130-
testBoth<float, intel::plus<float>>(0, 1, 7);
131-
testBoth<float, std::multiplies<float>>(1, 1, 7);
132-
testBoth<float, intel::minimum<float>>(getMaximumFPValue<float>(), 7, 3);
133-
testBoth<float, intel::maximum<float>>(getMinimumFPValue<float>(), 7, 3);
134-
135-
testUnknown<Point<float>, 0, PointPlus<float>>(Point<float>(0), Point<float>(1), Point<float>(7));
136-
testUnknown<Point<float>, 1, PointPlus<float>>(Point<float>(0), Point<float>(1), Point<float>(7));
91+
testBoth<int>(0, intel::plus<int>(), 1, 7);
92+
testBoth<int>(1, std::multiplies<int>(), 1, 7);
93+
testBoth<int>(0, intel::bit_or<int>(), 1, 8);
94+
testBoth<int>(0, intel::bit_xor<int>(), 7, 3);
95+
testBoth<int>(~0, intel::bit_and<int>(), 7, 3);
96+
testBoth<int>((std::numeric_limits<int>::max)(), intel::minimum<int>(), 7, 3);
97+
testBoth<int>((std::numeric_limits<int>::min)(), intel::maximum<int>(), 7, 3);
98+
99+
testBoth<float>(0, intel::plus<float>(), 1, 7);
100+
testBoth<float>(1, std::multiplies<float>(), 1, 7);
101+
testBoth<float>(getMaximumFPValue<float>(), intel::minimum<float>(), 7, 3);
102+
testBoth<float>(getMinimumFPValue<float>(), intel::maximum<float>(), 7, 3);
103+
104+
testUnknown<CustomVec<float>, 0, Unknown<CustomVec<float>, 0, CustomVecPlus<float>>>(
105+
CustomVec<float>(0), CustomVecPlus<float>(), CustomVec<float>(1), CustomVec<float>(7));
106+
testUnknown<CustomVec<float>, 1, Unknown<CustomVec<float>, 1, CustomVecPlus<float>>>(
107+
CustomVec<float>(0), CustomVecPlus<float>(), CustomVec<float>(1), CustomVec<float>(7));
108+
109+
testUnknown<int, 0, class BitOrName>(0, [](auto a, auto b) { return a | b; }, 1, 8);
137110

138111
std::cout << "Test passed\n";
139112
return 0;

0 commit comments

Comments
 (0)