Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 519f704

Browse files
committed
Merge remote-tracking branch 'jack/bfloat16-class-tests' into 9-may-22-cuda
2 parents 613ede6 + 4e1d6e4 commit 519f704

File tree

2 files changed

+395
-93
lines changed

2 files changed

+395
-93
lines changed

SYCL/BFloat16/bfloat16_builtins.cpp

Lines changed: 251 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,251 @@
1+
// REQUIRES: cuda
2+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out -Xsycl-target-backend --cuda-gpu-arch=sm_80
3+
// RUN: %t.out
4+
5+
#include <CL/sycl.hpp>

#include <cmath>
#include <cstring>
#include <vector>
9+
10+
using namespace cl::sycl;
11+
using sycl::ext::oneapi::experimental::bfloat16;
12+
13+
// Number of test elements; 60 is divisible by all tested array sizes.
constexpr int N{60};
// Tolerance unit for the relative-difference comparison (2^-8).
constexpr float bf16_eps{0.00390625f};
15+
16+
// Widens a raw 16-bit pattern to a float by placing it in the upper half of a
// 32-bit word and reinterpreting those 32 bits as a float (the lower 16 bits
// are zero).
//
// x: raw bit pattern, e.g. the value returned by bfloat16::raw().
// Returns the float whose object representation is (x << 16).
float make_fp32(uint16_t x) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "float must be 32 bits wide");
  const uint32_t bits = static_cast<uint32_t>(x) << 16;
  float res;
  // Copy the bytes rather than dereferencing a reinterpret_cast'ed pointer:
  // reading a uint32_t object through a float lvalue violates the
  // strict-aliasing rules and is undefined behavior.
  std::memcpy(&res, &bits, sizeof(res));
  return res;
}
22+
23+
bool check(float a, float b) {
24+
return fabs(2 * (a - b) / (a + b)) > bf16_eps * 2;
25+
}
26+
27+
// Scalar test of a one-argument builtin NAME.
// For each of the N elements of `a`, a device kernel evaluates NAME on the
// element converted to bfloat16 and on the float element itself, and writes 1
// to `err` when check() reports a mismatch. Asserts `err == 0` afterwards.
// Expects `q`, `a` and `err` to be in scope at the expansion site.
#define TEST_BUILTIN_1_SCAL_IMPL(NAME)                                        \
  {                                                                           \
    buffer<float> a_buf(&a[0], N);                                            \
    buffer<int> err_buf(&err, 1);                                             \
    q.submit([&](handler &cgh) {                                              \
      accessor<float, 1, access::mode::read_write, target::device> acc_a(     \
          a_buf, cgh);                                                        \
      accessor<int, 1, access::mode::write, target::device> acc_err(          \
          err_buf, cgh);                                                      \
      cgh.parallel_for(N, [=](id<1> idx) {                                    \
        const float actual = make_fp32(NAME(bfloat16{acc_a[idx]}).raw());     \
        const float expected = NAME(acc_a[idx]);                              \
        if (check(actual, expected)) {                                        \
          acc_err[0] = 1;                                                     \
        }                                                                     \
      });                                                                     \
    });                                                                       \
  }                                                                           \
  assert(err == 0);
44+
45+
// marray test of a one-argument builtin NAME with SZ elements per marray.
// The N values of `a` are viewed as an (N / SZ) x SZ buffer; each work-item
// loads one row into a marray<bfloat16, SZ>, applies NAME to the whole marray
// and compares every element of the result with NAME applied to the
// corresponding float. A mismatch reported by check() writes 1 to `err`.
// Asserts `err == 0` afterwards.
// Expects `q`, `a` and `err` to be in scope at the expansion site.
//
// The buffer is constructed from the host data of `a`: a buffer created from
// a range alone has no defined contents, so the kernel would be reading
// uninitialized memory instead of the prepared test values.
#define TEST_BUILTIN_1_ARR_IMPL(NAME, SZ)                                     \
  {                                                                           \
    buffer<float, 2> a_buf(&a[0], range<2>{N / SZ, SZ});                      \
    buffer<int> err_buf(&err, 1);                                             \
    q.submit([&](handler &cgh) {                                              \
      accessor<float, 2, access::mode::read_write, target::device> A(a_buf,   \
                                                                     cgh);    \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf,      \
                                                                cgh);         \
      cgh.parallel_for(N / SZ, [=](id<1> index) {                             \
        marray<bfloat16, SZ> arg;                                             \
        for (int i = 0; i < SZ; i++) {                                        \
          arg[i] = A[index][i];                                               \
        }                                                                     \
        marray<bfloat16, SZ> res = NAME(arg);                                 \
        for (int i = 0; i < SZ; i++) {                                        \
          if (check(make_fp32(res[i].raw()), NAME(A[index][i]))) {            \
            ERR[0] = 1;                                                       \
          }                                                                   \
        }                                                                     \
      });                                                                     \
    });                                                                       \
  }                                                                           \
  assert(err == 0);
68+
69+
// Full test of a one-argument builtin NAME: the scalar variant followed by
// the marray variants for sizes 1 through 5.
#define TEST_BUILTIN_1(NAME)                                                  \
  TEST_BUILTIN_1_SCAL_IMPL(NAME)                                              \
  TEST_BUILTIN_1_ARR_IMPL(NAME, 1)                                            \
  TEST_BUILTIN_1_ARR_IMPL(NAME, 2)                                            \
  TEST_BUILTIN_1_ARR_IMPL(NAME, 3)                                            \
  TEST_BUILTIN_1_ARR_IMPL(NAME, 4)                                            \
  TEST_BUILTIN_1_ARR_IMPL(NAME, 5)
76+
77+
// Scalar test of a two-argument builtin NAME.
// For each index, a device kernel evaluates NAME on the elements of `a` and
// `b` converted to bfloat16 and on the float elements themselves, and writes
// 1 to `err` when check() reports a mismatch. Asserts `err == 0` afterwards.
// Expects `q`, `a`, `b` and `err` to be in scope at the expansion site.
#define TEST_BUILTIN_2_SCAL_IMPL(NAME)                                        \
  {                                                                           \
    buffer<float> a_buf(&a[0], N);                                            \
    buffer<float> b_buf(&b[0], N);                                            \
    buffer<int> err_buf(&err, 1);                                             \
    q.submit([&](handler &cgh) {                                              \
      accessor<float, 1, access::mode::read_write, target::device> acc_a(     \
          a_buf, cgh);                                                        \
      accessor<float, 1, access::mode::read_write, target::device> acc_b(     \
          b_buf, cgh);                                                        \
      accessor<int, 1, access::mode::write, target::device> acc_err(          \
          err_buf, cgh);                                                      \
      cgh.parallel_for(N, [=](id<1> idx) {                                    \
        const float actual = make_fp32(                                       \
            NAME(bfloat16{acc_a[idx]}, bfloat16{acc_b[idx]}).raw());          \
        const float expected = NAME(acc_a[idx], acc_b[idx]);                  \
        if (check(actual, expected)) {                                        \
          acc_err[0] = 1;                                                     \
        }                                                                     \
      });                                                                     \
    });                                                                       \
  }                                                                           \
  assert(err == 0);
98+
99+
// marray test of a two-argument builtin NAME with SZ elements per marray.
// The N values of `a` and `b` are each viewed as an (N / SZ) x SZ buffer;
// each work-item loads one row of each into marray<bfloat16, SZ> arguments,
// applies NAME to the marrays and compares every element of the result with
// NAME applied to the corresponding floats. A mismatch reported by check()
// writes 1 to `err`. Asserts `err == 0` afterwards.
// Expects `q`, `a`, `b` and `err` to be in scope at the expansion site.
//
// The buffers are constructed from the host data of `a` and `b`: buffers
// created from a range alone have no defined contents, so the kernel would be
// reading uninitialized memory instead of the prepared test values.
#define TEST_BUILTIN_2_ARR_IMPL(NAME, SZ)                                     \
  {                                                                           \
    buffer<float, 2> a_buf(&a[0], range<2>{N / SZ, SZ});                      \
    buffer<float, 2> b_buf(&b[0], range<2>{N / SZ, SZ});                      \
    buffer<int> err_buf(&err, 1);                                             \
    q.submit([&](handler &cgh) {                                              \
      accessor<float, 2, access::mode::read_write, target::device> A(a_buf,   \
                                                                     cgh);    \
      accessor<float, 2, access::mode::read_write, target::device> B(b_buf,   \
                                                                     cgh);    \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf,      \
                                                                cgh);         \
      cgh.parallel_for(N / SZ, [=](id<1> index) {                             \
        marray<bfloat16, SZ> arg0, arg1;                                      \
        for (int i = 0; i < SZ; i++) {                                        \
          arg0[i] = A[index][i];                                              \
          arg1[i] = B[index][i];                                              \
        }                                                                     \
        marray<bfloat16, SZ> res = NAME(arg0, arg1);                          \
        for (int i = 0; i < SZ; i++) {                                        \
          if (check(make_fp32(res[i].raw()),                                  \
                    NAME(A[index][i], B[index][i]))) {                        \
            ERR[0] = 1;                                                       \
          }                                                                   \
        }                                                                     \
      });                                                                     \
    });                                                                       \
  }                                                                           \
  assert(err == 0);
127+
128+
// Full test of a two-argument builtin NAME: the scalar variant followed by
// the marray variants for sizes 1 through 5.
#define TEST_BUILTIN_2(NAME)                                                  \
  TEST_BUILTIN_2_SCAL_IMPL(NAME)                                              \
  TEST_BUILTIN_2_ARR_IMPL(NAME, 1)                                            \
  TEST_BUILTIN_2_ARR_IMPL(NAME, 2)                                            \
  TEST_BUILTIN_2_ARR_IMPL(NAME, 3)                                            \
  TEST_BUILTIN_2_ARR_IMPL(NAME, 4)                                            \
  TEST_BUILTIN_2_ARR_IMPL(NAME, 5)
135+
136+
// Scalar test of a three-argument builtin NAME.
// For each index, a device kernel evaluates NAME on the elements of `a`, `b`
// and `c` converted to bfloat16 and on the float elements themselves, and
// writes 1 to `err` when check() reports a mismatch. Asserts `err == 0`
// afterwards.
// Expects `q`, `a`, `b`, `c` and `err` to be in scope at the expansion site.
#define TEST_BUILTIN_3_SCAL_IMPL(NAME)                                        \
  {                                                                           \
    buffer<float> a_buf(&a[0], N);                                            \
    buffer<float> b_buf(&b[0], N);                                            \
    buffer<float> c_buf(&c[0], N);                                            \
    buffer<int> err_buf(&err, 1);                                             \
    q.submit([&](handler &cgh) {                                              \
      accessor<float, 1, access::mode::read_write, target::device> acc_a(     \
          a_buf, cgh);                                                        \
      accessor<float, 1, access::mode::read_write, target::device> acc_b(     \
          b_buf, cgh);                                                        \
      accessor<float, 1, access::mode::read_write, target::device> acc_c(     \
          c_buf, cgh);                                                        \
      accessor<int, 1, access::mode::write, target::device> acc_err(          \
          err_buf, cgh);                                                      \
      cgh.parallel_for(N, [=](id<1> idx) {                                    \
        const float actual = make_fp32(NAME(bfloat16{acc_a[idx]},             \
                                            bfloat16{acc_b[idx]},             \
                                            bfloat16{acc_c[idx]})             \
                                           .raw());                           \
        const float expected = NAME(acc_a[idx], acc_b[idx], acc_c[idx]);      \
        if (check(actual, expected)) {                                        \
          acc_err[0] = 1;                                                     \
        }                                                                     \
      });                                                                     \
    });                                                                       \
  }                                                                           \
  assert(err == 0);
161+
162+
// marray test of a three-argument builtin NAME with SZ elements per marray.
// The N values of `a`, `b` and `c` are each viewed as an (N / SZ) x SZ
// buffer; each work-item loads one row of each into marray<bfloat16, SZ>
// arguments, applies NAME to the marrays and compares every element of the
// result with NAME applied to the corresponding floats. A mismatch reported
// by check() writes 1 to `err`. Asserts `err == 0` afterwards.
// Expects `q`, `a`, `b`, `c` and `err` to be in scope at the expansion site.
//
// The buffers are constructed from the host data of `a`, `b` and `c`: buffers
// created from a range alone have no defined contents, so the kernel would be
// reading uninitialized memory instead of the prepared test values.
#define TEST_BUILTIN_3_ARR_IMPL(NAME, SZ)                                     \
  {                                                                           \
    buffer<float, 2> a_buf(&a[0], range<2>{N / SZ, SZ});                      \
    buffer<float, 2> b_buf(&b[0], range<2>{N / SZ, SZ});                      \
    buffer<float, 2> c_buf(&c[0], range<2>{N / SZ, SZ});                      \
    buffer<int> err_buf(&err, 1);                                             \
    q.submit([&](handler &cgh) {                                              \
      accessor<float, 2, access::mode::read_write, target::device> A(a_buf,   \
                                                                     cgh);    \
      accessor<float, 2, access::mode::read_write, target::device> B(b_buf,   \
                                                                     cgh);    \
      accessor<float, 2, access::mode::read_write, target::device> C(c_buf,   \
                                                                     cgh);    \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf,      \
                                                                cgh);         \
      cgh.parallel_for(N / SZ, [=](id<1> index) {                             \
        marray<bfloat16, SZ> arg0, arg1, arg2;                                \
        for (int i = 0; i < SZ; i++) {                                        \
          arg0[i] = A[index][i];                                              \
          arg1[i] = B[index][i];                                              \
          arg2[i] = C[index][i];                                              \
        }                                                                     \
        marray<bfloat16, SZ> res = NAME(arg0, arg1, arg2);                    \
        for (int i = 0; i < SZ; i++) {                                        \
          if (check(make_fp32(res[i].raw()),                                  \
                    NAME(A[index][i], B[index][i], C[index][i]))) {           \
            ERR[0] = 1;                                                       \
          }                                                                   \
        }                                                                     \
      });                                                                     \
    });                                                                       \
  }                                                                           \
  assert(err == 0);
194+
195+
// Full test of a three-argument builtin NAME: the scalar variant followed by
// the marray variants for sizes 1 through 5.
#define TEST_BUILTIN_3(NAME)                                                  \
  TEST_BUILTIN_3_SCAL_IMPL(NAME)                                              \
  TEST_BUILTIN_3_ARR_IMPL(NAME, 1)                                            \
  TEST_BUILTIN_3_ARR_IMPL(NAME, 2)                                            \
  TEST_BUILTIN_3_ARR_IMPL(NAME, 3)                                            \
  TEST_BUILTIN_3_ARR_IMPL(NAME, 4)                                            \
  TEST_BUILTIN_3_ARR_IMPL(NAME, 5)
202+
203+
// NaN-handling test of a two-argument builtin NAME on bfloat16 operands.
// A single device task
//   - stores NAME(NaN, NaN) into `check_nan`, and
//   - writes 1 to `err` unless both NAME(2, NaN) and NAME(NaN, 2) equal 2.
// Afterwards asserts `err == 0` and that `check_nan` is NaN.
// Expects `q`, `err` and `check_nan` to be in scope at the expansion site.
//
// `check_nan` is reset on every expansion. Otherwise a NaN left behind by a
// previous expansion would satisfy the final isnan assertion even if this
// kernel never wrote its own result.
#define TEST_BUILTIN_2_NAN(NAME)                                              \
  {                                                                           \
    check_nan = 0;                                                            \
    buffer<int> err_buf(&err, 1);                                             \
    buffer<float> nan_buf(&check_nan, 1);                                     \
    q.submit([&](handler &cgh) {                                              \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf,      \
                                                                cgh);         \
      accessor<float, 1, access::mode::write, target::device> checkNAN(       \
          nan_buf, cgh);                                                      \
      cgh.single_task([=]() {                                                 \
        checkNAN[0] = make_fp32(NAME(bfloat16{NAN}, bfloat16{NAN}).raw());    \
        if ((make_fp32(NAME(bfloat16{2}, bfloat16{NAN}).raw()) != 2) ||       \
            (make_fp32(NAME(bfloat16{NAN}, bfloat16{2}).raw()) != 2)) {       \
          ERR[0] = 1;                                                         \
        }                                                                     \
      });                                                                     \
    });                                                                       \
  }                                                                           \
  assert(err == 0);                                                           \
  assert(std::isnan(check_nan));
222+
223+
int main() {
224+
queue q;
225+
226+
auto computeCapability =
227+
std::stof(q.get_device().get_info<sycl::info::device::backend_version>());
228+
// TODO check for "ext_oneapi_bfloat16" aspect instead once aspect is
229+
// supported. Since this test only covers CUDA the current check is
230+
// functionally equivalent to "ext_oneapi_bfloat16".
231+
if (computeCapability >= 8.0) {
232+
std::vector<float> a(N), b(N), c(N);
233+
int err = 0;
234+
235+
for (int i = 0; i < N; i++) {
236+
a[i] = (i - N / 2) / (float)N;
237+
b[i] = (N / 2 - i) / (float)N;
238+
c[i] = (float)(3 * i);
239+
}
240+
241+
TEST_BUILTIN_1(fabs);
242+
TEST_BUILTIN_2(fmin);
243+
TEST_BUILTIN_2(fmax);
244+
TEST_BUILTIN_3(fma);
245+
246+
float check_nan = 0;
247+
TEST_BUILTIN_2_NAN(fmin);
248+
TEST_BUILTIN_2_NAN(fmax);
249+
}
250+
return 0;
251+
}

0 commit comments

Comments
 (0)