1
1
#include < iostream>
2
- #include < sycl/ext/oneapi/experimental/ bfloat16.hpp>
2
+ #include < sycl/ext/oneapi/bfloat16.hpp>
3
3
#include < sycl/sycl.hpp>
4
4
5
5
#include < cmath>
@@ -11,8 +11,7 @@ constexpr size_t N = 100;
11
11
template <typename T> void assert_close (const T &C, const float ref) {
12
12
for (size_t i = 0 ; i < N; i++) {
13
13
auto diff = C[i] - ref;
14
- assert (std::fabs (static_cast <float >(diff)) <
15
- std::numeric_limits<float >::epsilon ());
14
+ assert (std::fabs (static_cast <float >(diff)) < 0.1 );
16
15
}
17
16
}
18
17
@@ -21,7 +20,7 @@ void verify_conv_implicit(queue &q, buffer<float, 1> &a, range<1> &r,
21
20
q.submit ([&](handler &cgh) {
22
21
auto A = a.get_access <access::mode::read_write>(cgh);
23
22
cgh.parallel_for <class calc_conv >(r, [=](id<1 > index) {
24
- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
23
+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
25
24
A[index] = AVal;
26
25
});
27
26
});
@@ -34,9 +33,8 @@ void verify_conv_explicit(queue &q, buffer<float, 1> &a, range<1> &r,
34
33
q.submit ([&](handler &cgh) {
35
34
auto A = a.get_access <access::mode::read_write>(cgh);
36
35
cgh.parallel_for <class calc_conv_impl >(r, [=](id<1 > index) {
37
- uint16_t AVal =
38
- sycl::ext::oneapi::experimental::bfloat16::from_float (A[index]);
39
- A[index] = sycl::ext::oneapi::experimental::bfloat16::to_float (AVal);
36
+ sycl::ext::oneapi::bfloat16 AVal = A[index];
37
+ A[index] = float (AVal);
40
38
});
41
39
});
42
40
@@ -52,9 +50,9 @@ void verify_add(queue &q, buffer<float, 1> &a, buffer<float, 1> &b, range<1> &r,
52
50
auto B = b.get_access <access::mode::read>(cgh);
53
51
auto C = c.get_access <access::mode::write>(cgh);
54
52
cgh.parallel_for <class calc_add_expl >(r, [=](id<1 > index) {
55
- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
56
- sycl::ext::oneapi::experimental:: bfloat16 BVal{B[index]};
57
- sycl::ext::oneapi::experimental:: bfloat16 CVal = AVal + BVal;
53
+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
54
+ sycl::ext::oneapi::bfloat16 BVal{B[index]};
55
+ sycl::ext::oneapi::bfloat16 CVal = AVal + BVal;
58
56
C[index] = CVal;
59
57
});
60
58
});
@@ -71,9 +69,9 @@ void verify_sub(queue &q, buffer<float, 1> &a, buffer<float, 1> &b, range<1> &r,
71
69
auto B = b.get_access <access::mode::read>(cgh);
72
70
auto C = c.get_access <access::mode::write>(cgh);
73
71
cgh.parallel_for <class calc_sub >(r, [=](id<1 > index) {
74
- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
75
- sycl::ext::oneapi::experimental:: bfloat16 BVal{B[index]};
76
- sycl::ext::oneapi::experimental:: bfloat16 CVal = AVal - BVal;
72
+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
73
+ sycl::ext::oneapi::bfloat16 BVal{B[index]};
74
+ sycl::ext::oneapi::bfloat16 CVal = AVal - BVal;
77
75
C[index] = CVal;
78
76
});
79
77
});
@@ -88,8 +86,8 @@ void verify_minus(queue &q, buffer<float, 1> &a, range<1> &r, const float ref) {
88
86
auto A = a.get_access <access::mode::read>(cgh);
89
87
auto C = c.get_access <access::mode::write>(cgh);
90
88
cgh.parallel_for <class calc_minus >(r, [=](id<1 > index) {
91
- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
92
- sycl::ext::oneapi::experimental:: bfloat16 CVal = -AVal;
89
+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
90
+ sycl::ext::oneapi::bfloat16 CVal = -AVal;
93
91
C[index] = CVal;
94
92
});
95
93
});
@@ -106,9 +104,9 @@ void verify_mul(queue &q, buffer<float, 1> &a, buffer<float, 1> &b, range<1> &r,
106
104
auto B = b.get_access <access::mode::read>(cgh);
107
105
auto C = c.get_access <access::mode::write>(cgh);
108
106
cgh.parallel_for <class calc_mul >(r, [=](id<1 > index) {
109
- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
110
- sycl::ext::oneapi::experimental:: bfloat16 BVal{B[index]};
111
- sycl::ext::oneapi::experimental:: bfloat16 CVal = AVal * BVal;
107
+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
108
+ sycl::ext::oneapi::bfloat16 BVal{B[index]};
109
+ sycl::ext::oneapi::bfloat16 CVal = AVal * BVal;
112
110
C[index] = CVal;
113
111
});
114
112
});
@@ -125,9 +123,9 @@ void verify_div(queue &q, buffer<float, 1> &a, buffer<float, 1> &b, range<1> &r,
125
123
auto B = b.get_access <access::mode::read>(cgh);
126
124
auto C = c.get_access <access::mode::write>(cgh);
127
125
cgh.parallel_for <class calc_div >(r, [=](id<1 > index) {
128
- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
129
- sycl::ext::oneapi::experimental:: bfloat16 BVal{B[index]};
130
- sycl::ext::oneapi::experimental:: bfloat16 CVal = AVal / BVal;
126
+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
127
+ sycl::ext::oneapi::bfloat16 BVal{B[index]};
128
+ sycl::ext::oneapi::bfloat16 CVal = AVal / BVal;
131
129
C[index] = CVal;
132
130
});
133
131
});
@@ -144,19 +142,18 @@ void verify_logic(queue &q, buffer<float, 1> &a, buffer<float, 1> &b,
144
142
auto B = b.get_access <access::mode::read>(cgh);
145
143
auto C = c.get_access <access::mode::write>(cgh);
146
144
cgh.parallel_for <class logic >(r, [=](id<1 > index) {
147
- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
148
- sycl::ext::oneapi::experimental:: bfloat16 BVal{B[index]};
145
+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
146
+ sycl::ext::oneapi::bfloat16 BVal{B[index]};
149
147
if (AVal) {
150
148
if (AVal > BVal || AVal >= BVal || AVal < BVal || AVal <= BVal ||
151
149
!BVal) {
152
- sycl::ext::oneapi::experimental::bfloat16 CVal =
153
- AVal != BVal ? AVal : BVal;
150
+ sycl::ext::oneapi::bfloat16 CVal = AVal != BVal ? AVal : BVal;
154
151
CVal--;
155
152
CVal++;
156
153
if (AVal == BVal) {
157
154
CVal -= AVal;
158
- CVal *= 3.0 ;
159
- CVal /= 2.0 ;
155
+ CVal *= 3 .0f ;
156
+ CVal /= 2 .0f ;
160
157
} else
161
158
CVal += BVal;
162
159
C[index] = CVal;
@@ -179,9 +176,9 @@ int run_tests() {
179
176
return 0 ;
180
177
}
181
178
182
- std::vector<float > vec_a (N, 5.0 );
183
- std::vector<float > vec_b (N, 2.0 );
184
- std::vector<float > vec_b_neg (N, -2.0 );
179
+ std::vector<float > vec_a (N, 5 .0f );
180
+ std::vector<float > vec_b (N, 2 .0f );
181
+ std::vector<float > vec_b_neg (N, -2 .0f );
185
182
186
183
range<1 > r (N);
187
184
buffer<float , 1 > a{vec_a.data (), r};
@@ -190,19 +187,32 @@ int run_tests() {
190
187
191
188
queue q{dev};
192
189
193
- verify_conv_implicit (q, a, r, 5.0 );
194
- verify_conv_explicit (q, a, r, 5.0 );
195
- verify_add (q, a, b, r, 7.0 );
196
- verify_sub (q, a, b, r, 3.0 );
197
- verify_mul (q, a, b, r, 10.0 );
198
- verify_div (q, a, b, r, 2.5 );
199
- verify_logic (q, a, b, r, 7.0 );
200
- verify_add (q, a, b_neg, r, 3.0 );
201
- verify_sub (q, a, b_neg, r, 7.0 );
202
- verify_minus (q, a, r, -5.0 );
203
- verify_mul (q, a, b_neg, r, -10.0 );
204
- verify_div (q, a, b_neg, r, -2.5 );
205
- verify_logic (q, a, b_neg, r, 3.0 );
190
+ verify_conv_implicit (q, a, r, 5 .0f );
191
+ std::cout << " PASS verify_conv_implicit\n " ;
192
+ verify_conv_explicit (q, a, r, 5 .0f );
193
+ std::cout << " PASS verify_conv_explicit\n " ;
194
+ verify_add (q, a, b, r, 7 .0f );
195
+ std::cout << " PASS verify_add\n " ;
196
+ verify_sub (q, a, b, r, 3 .0f );
197
+ std::cout << " PASS verify_sub\n " ;
198
+ verify_mul (q, a, b, r, 10 .0f );
199
+ std::cout << " PASS verify_mul\n " ;
200
+ verify_div (q, a, b, r, 2 .5f );
201
+ std::cout << " PASS verify_div\n " ;
202
+ verify_logic (q, a, b, r, 7 .0f );
203
+ std::cout << " PASS verify_logic\n " ;
204
+ verify_add (q, a, b_neg, r, 3 .0f );
205
+ std::cout << " PASS verify_add\n " ;
206
+ verify_sub (q, a, b_neg, r, 7 .0f );
207
+ std::cout << " PASS verify_sub\n " ;
208
+ verify_minus (q, a, r, -5 .0f );
209
+ std::cout << " PASS verify_minus\n " ;
210
+ verify_mul (q, a, b_neg, r, -10 .0f );
211
+ std::cout << " PASS verify_mul\n " ;
212
+ verify_div (q, a, b_neg, r, -2 .5f );
213
+ std::cout << " PASS verify_div\n " ;
214
+ verify_logic (q, a, b_neg, r, 3 .0f );
215
+ std::cout << " PASS verify_logic\n " ;
206
216
207
217
return 0 ;
208
218
}
0 commit comments