Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit a1180cb

Browse files
authored
[SYCL] Add a LIT test for +=,*=,|=,^=,&= operations usable for reduce… (#140)
This test verifies this change-set: intel/llvm#3193 Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent 8e2ab8f commit a1180cb

File tree

1 file changed

+175
-0
lines changed

1 file changed

+175
-0
lines changed
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -fsycl-unnamed-lambda %s -o %t.out
2+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
3+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
4+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
5+
6+
// This test checks that if the custom type supports operations like +=, then
7+
// such operations can be used for the reduction objects in kernels.
8+
9+
#include <CL/sycl.hpp>
10+
#include <cmath>
11+
#include <iostream>
12+
13+
using namespace sycl;
14+
using namespace sycl::ONEAPI;
15+
16+
struct XY {
17+
constexpr XY() : X(0), Y(0) {}
18+
constexpr XY(int64_t X, int64_t Y) : X(X), Y(Y) {}
19+
int64_t X;
20+
int64_t Y;
21+
int64_t x() const { return X; };
22+
int64_t y() const { return Y; };
23+
};
24+
25+
enum OperationEqual {
26+
PlusEq,
27+
MultipliesEq,
28+
BitwiseOREq,
29+
BitwiseXOREq,
30+
BitwiseANDEq
31+
};
32+
33+
namespace std {
34+
template <> struct plus<XY> {
35+
using result_type = XY;
36+
using first_argument_type = XY;
37+
using second_argument_type = XY;
38+
constexpr XY operator()(const XY &lhs, const XY &rhs) const {
39+
return XY(lhs.X + rhs.X, lhs.Y + rhs.Y);
40+
}
41+
};
42+
43+
template <> struct multiplies<XY> {
44+
using result_type = XY;
45+
using first_argument_type = XY;
46+
using second_argument_type = XY;
47+
constexpr XY operator()(const XY &lhs, const XY &rhs) const {
48+
return XY(lhs.X * rhs.X, lhs.Y * rhs.Y);
49+
}
50+
};
51+
52+
template <> struct bit_or<XY> {
53+
using result_type = XY;
54+
using first_argument_type = XY;
55+
using second_argument_type = XY;
56+
constexpr XY operator()(const XY &lhs, const XY &rhs) const {
57+
return XY(lhs.X | rhs.X, lhs.Y | rhs.Y);
58+
}
59+
};
60+
61+
template <> struct bit_xor<XY> {
62+
using result_type = XY;
63+
using first_argument_type = XY;
64+
using second_argument_type = XY;
65+
constexpr XY operator()(const XY &lhs, const XY &rhs) const {
66+
return XY(lhs.X ^ rhs.X, lhs.Y ^ rhs.Y);
67+
}
68+
};
69+
70+
template <> struct bit_and<XY> {
71+
using result_type = XY;
72+
using first_argument_type = XY;
73+
using second_argument_type = XY;
74+
constexpr XY operator()(const XY &lhs, const XY &rhs) const {
75+
return XY(lhs.X & rhs.X, lhs.Y & rhs.Y);
76+
}
77+
};
78+
} // namespace std
79+
80+
template <typename T, typename BinaryOperation, OperationEqual OpEq,
81+
bool IsFP = false>
82+
int test(T Identity) {
83+
constexpr size_t N = 16;
84+
constexpr size_t L = 4;
85+
86+
queue Q;
87+
T *Data = malloc_shared<T>(N, Q);
88+
T *Res = malloc_shared<T>(1, Q);
89+
T Expected = Identity;
90+
BinaryOperation BOp;
91+
for (int I = 0; I < N; I++) {
92+
Data[I] = T{I, I + 1};
93+
Expected = BOp(Expected, T{I, I + 1});
94+
}
95+
96+
*Res = Identity;
97+
auto Red = reduction(Res, Identity, BOp);
98+
nd_range<1> NDR{N, L};
99+
if constexpr (OpEq == PlusEq) {
100+
auto Lambda = [=](nd_item<1> ID, auto &Sum) {
101+
Sum += Data[ID.get_global_id(0)];
102+
};
103+
Q.submit([&](handler &H) { H.parallel_for(NDR, Red, Lambda); }).wait();
104+
} else if constexpr (OpEq == MultipliesEq) {
105+
auto Lambda = [=](nd_item<1> ID, auto &Sum) {
106+
Sum *= Data[ID.get_global_id(0)];
107+
};
108+
Q.submit([&](handler &H) { H.parallel_for(NDR, Red, Lambda); }).wait();
109+
} else if constexpr (OpEq == BitwiseOREq) {
110+
auto Lambda = [=](nd_item<1> ID, auto &Sum) {
111+
Sum |= Data[ID.get_global_id(0)];
112+
};
113+
Q.submit([&](handler &H) { H.parallel_for(NDR, Red, Lambda); }).wait();
114+
} else if constexpr (OpEq == BitwiseXOREq) {
115+
auto Lambda = [=](nd_item<1> ID, auto &Sum) {
116+
Sum ^= Data[ID.get_global_id(0)];
117+
};
118+
Q.submit([&](handler &H) { H.parallel_for(NDR, Red, Lambda); }).wait();
119+
} else if constexpr (OpEq == BitwiseANDEq) {
120+
auto Lambda = [=](nd_item<1> ID, auto &Sum) {
121+
Sum &= Data[ID.get_global_id(0)];
122+
};
123+
Q.submit([&](handler &H) { H.parallel_for(NDR, Red, Lambda); }).wait();
124+
}
125+
126+
int Error = 0;
127+
if constexpr (IsFP) {
128+
T Diff = (Expected / *Res) - T{1};
129+
Error = (std::abs(Diff.x()) > 0.5 || std::abs(Diff.y()) > 0.5) ? 1 : 0;
130+
} else {
131+
Error = (Expected.x() != Res->x() || Expected.y() != Res->y()) ? 1 : 0;
132+
}
133+
if (Error)
134+
std::cerr << "Error: expected = (" << Expected.x() << ", " << Expected.y()
135+
<< "); computed = (" << Res->x() << ", " << Res->y() << ")\n";
136+
137+
free(Res, Q);
138+
free(Data, Q);
139+
return Error;
140+
}
141+
142+
template <typename T> int testFPPack() {
143+
int Error = 0;
144+
Error += test<T, std::plus<>, PlusEq, true>(T{});
145+
Error += test<T, std::plus<T>, PlusEq, true>(T{});
146+
Error += test<T, std::multiplies<>, MultipliesEq, true>(T{1, 1});
147+
Error += test<T, std::multiplies<T>, MultipliesEq, true>(T{1, 1});
148+
return Error;
149+
}
150+
151+
template <typename T> int testINTPack() {
152+
int Error = 0;
153+
Error += test<T, std::plus<T>, PlusEq>(T{});
154+
Error += test<T, std::multiplies<T>, MultipliesEq>(T{1, 1});
155+
Error += test<T, std::bit_or<T>, BitwiseOREq>(T{});
156+
Error += test<T, std::bit_xor<T>, BitwiseXOREq>(T{});
157+
Error += test<T, std::bit_and<T>, BitwiseANDEq>(T{~0, ~0});
158+
return Error;
159+
}
160+
161+
int main() {
162+
int Error = 0;
163+
Error += testFPPack<float2>();
164+
Error += testINTPack<XY>();
165+
166+
// TODO: enable this test for int vetors as well.
167+
// This test revealed an existing/unrelated problem with the type trait
168+
// known_identity_impl. It returns true for 'int2' type, but the
169+
// corrsponding functionality returning identity value is not implemented
170+
// correctly.
171+
// Error += testINTPack<int2>();
172+
173+
std::cout << (Error ? "Failed\n" : "Passed.\n");
174+
return Error;
175+
}

0 commit comments

Comments
 (0)