Skip to content

Commit c7759bb

Browse files
jinge90 and JackAKirk
authored
[SYCL] Add generic impl for some bf16 math functions (#8583)
Signed-off-by: jinge90 <[email protected]> Co-authored-by: JackAKirk <[email protected]>
1 parent f578eef commit c7759bb

File tree

3 files changed

+294
-15
lines changed

3 files changed

+294
-15
lines changed

sycl/doc/extensions/experimental/sycl_ext_oneapi_bfloat16_math_functions.asciidoc

Lines changed: 199 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
== Notice
2424

25-
Copyright © 2022-2022 Intel Corporation. All rights reserved.
25+
Copyright © 2022-2023 Intel Corporation. All rights reserved.
2626

2727
Khronos® is a registered trademark and SYCL™ and SPIR™ are trademarks of
2828
The Khronos Group Inc. OpenCL™ is a trademark of Apple Inc. used by permission
@@ -55,13 +55,16 @@ specification.
5555

5656
== Overview
5757

58-
This extension adds `bfloat16` support to the `fma`, `fmin`, `fmax`, `fabs`
59-
and `isnan` SYCL floating point math functions. These functions can be used as
60-
element wise operations on matrices, supplementing the `bfloat16` support
61-
in the sycl_ext_oneapi_matrix extension.
58+
This extension adds `bfloat16` support to the `fma`, `fmin`, `fmax`, `fabs`,
59+
`isnan`, `ceil`, `floor`, `cos`, `sin`, `exp`, `exp2`, `exp10`, `log`, `log2`,
60+
`log10`, `rint`, `sqrt`, `rsqrt` and `trunc` SYCL floating point math functions.
61+
These functions can be used as element wise operations on matrices, supplementing
62+
the `bfloat16` support in the sycl_ext_oneapi_matrix extension.
6263

63-
The descriptions of the `fma`, `fmin`, `fmax`, `fabs` and `isnan` SYCL floating
64-
point math functions can be found in the SYCL specification:
64+
The descriptions of the `fma`, `fmin`, `fmax`, `fabs`, `isnan`, `ceil`, `floor`,
65+
`cos`, `sin`, `exp`, `exp2`, `exp10`, `log`, `log2`, `log10`, `rint`, `sqrt`,
66+
`rsqrt` and `trunc` SYCL floating point math functions can be found in the SYCL
67+
specification:
6568
https://www.khronos.org/registry/SYCL/specs/sycl-2020/html/sycl-2020.html#_math_functions.
6669

6770
== Specification
@@ -80,7 +83,7 @@ supports.
8083
[%header,cols="1,5"]
8184
|===
8285
|Value |Description
83-
|1 |Initial extension version. Base features are supported.
86+
|1 |The APIs of this experimental extension are not versioned, so the feature-test macro always has this value.
8487
|===
8588

8689
=== Extension to `enum class aspect`
@@ -184,7 +187,194 @@ T fabs(T x);
184187

185188
===== Description
186189

187-
Compute absolute value of a `bfloat16`.
190+
Compute absolute value of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
191+
192+
==== ceil
193+
194+
```c++
195+
namespace sycl::ext::oneapi::experimental {
196+
template <typename T>
197+
T ceil(T x);
198+
} // namespace sycl::ext::oneapi::experimental
199+
```
200+
201+
===== Description
202+
203+
Returns `x` rounded to an integral value using the round to positive infinity rounding mode for a `bfloat16` value or `sycl::marray<bfloat16, N>`.
204+
205+
==== floor
206+
207+
```c++
208+
namespace sycl::ext::oneapi::experimental {
209+
template <typename T>
210+
T floor(T x);
211+
} // namespace sycl::ext::oneapi::experimental
212+
```
213+
214+
===== Description
215+
216+
Returns `x` rounded to an integral value using the round to negative infinity rounding mode
217+
for a `bfloat16` value or `sycl::marray<bfloat16, N>`.
218+
219+
==== cos
220+
221+
```c++
222+
namespace sycl::ext::oneapi::experimental {
223+
template <typename T>
224+
T cos(T x);
225+
} // namespace sycl::ext::oneapi::experimental
226+
```
227+
228+
===== Description
229+
230+
Compute cosine of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
231+
232+
==== sin
233+
234+
```c++
235+
namespace sycl::ext::oneapi::experimental {
236+
template <typename T>
237+
T sin(T x);
238+
} // namespace sycl::ext::oneapi::experimental
239+
```
240+
241+
===== Description
242+
243+
Compute sine of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
244+
245+
246+
==== exp
247+
248+
```c++
249+
namespace sycl::ext::oneapi::experimental {
250+
template <typename T>
251+
T exp(T x);
252+
} // namespace sycl::ext::oneapi::experimental
253+
```
254+
255+
===== Description
256+
257+
Compute the base-e exponential of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
258+
259+
==== exp2
260+
261+
```c++
262+
namespace sycl::ext::oneapi::experimental {
263+
template <typename T>
264+
T exp2(T x);
265+
} // namespace sycl::ext::oneapi::experimental
266+
```
267+
268+
===== Description
269+
270+
Compute the base-2 exponential of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
271+
272+
==== exp10
273+
274+
```c++
275+
namespace sycl::ext::oneapi::experimental {
276+
template <typename T>
277+
T exp10(T x);
278+
} // namespace sycl::ext::oneapi::experimental
279+
```
280+
281+
===== Description
282+
283+
Compute the base-10 exponential of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
284+
285+
==== log
286+
287+
```c++
288+
namespace sycl::ext::oneapi::experimental {
289+
template <typename T>
290+
T log(T x);
291+
} // namespace sycl::ext::oneapi::experimental
292+
```
293+
294+
===== Description
295+
296+
Compute natural logarithm of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
297+
298+
==== log2
299+
300+
```c++
301+
namespace sycl::ext::oneapi::experimental {
302+
template <typename T>
303+
T log2(T x);
304+
} // namespace sycl::ext::oneapi::experimental
305+
```
306+
307+
===== Description
308+
309+
Compute base-2 logarithm of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
310+
311+
==== log10
312+
313+
```c++
314+
namespace sycl::ext::oneapi::experimental {
315+
template <typename T>
316+
T log10(T x);
317+
} // namespace sycl::ext::oneapi::experimental
318+
```
319+
320+
===== Description
321+
322+
Compute base-10 logarithm of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
323+
324+
325+
==== rint
326+
327+
```c++
328+
namespace sycl::ext::oneapi::experimental {
329+
template <typename T>
330+
T rint(T x);
331+
} // namespace sycl::ext::oneapi::experimental
332+
```
333+
334+
===== Description
335+
336+
Returns `x` rounded to an integral value using the round to nearest even rounding mode
337+
for a `bfloat16` value or `sycl::marray<bfloat16, N>`.
338+
339+
==== sqrt
340+
341+
```c++
342+
namespace sycl::ext::oneapi::experimental {
343+
template <typename T>
344+
T sqrt(T x);
345+
} // namespace sycl::ext::oneapi::experimental
346+
```
347+
348+
===== Description
349+
350+
Compute square root of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
351+
352+
==== rsqrt
353+
354+
```c++
355+
namespace sycl::ext::oneapi::experimental {
356+
template <typename T>
357+
T rsqrt(T x);
358+
} // namespace sycl::ext::oneapi::experimental
359+
```
360+
361+
===== Description
362+
363+
Compute inverse square root of a `bfloat16` value or `sycl::marray<bfloat16, N>`.
364+
365+
==== trunc
366+
367+
```c++
368+
namespace sycl::ext::oneapi::experimental {
369+
template <typename T>
370+
T trunc(T x);
371+
} // namespace sycl::ext::oneapi::experimental
372+
```
373+
374+
===== Description
375+
376+
Returns `x` rounded to an integral value using the round to zero rounding mode
377+
for a `bfloat16` value or `sycl::marray<bfloat16, N>`.
188378

189379
== Issues
190380

sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,53 @@ sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
232232
return res;
233233
}
234234

235+
#define BFLOAT16_MATH_FP32_WRAPPERS(op) \
236+
template <typename T> \
237+
std::enable_if_t<std::is_same<T, bfloat16>::value, T> op(T x) { \
238+
return sycl::ext::oneapi::bfloat16{sycl::op(float{x})}; \
239+
}
240+
241+
#define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op) \
242+
template <size_t N> \
243+
sycl::marray<bfloat16, N> op(sycl::marray<bfloat16, N> x) { \
244+
sycl::marray<bfloat16, N> res; \
245+
for (size_t i = 0; i < N; i++) { \
246+
res[i] = op(x[i]); \
247+
} \
248+
return res; \
249+
}
250+
251+
BFLOAT16_MATH_FP32_WRAPPERS(ceil)
252+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(ceil)
253+
BFLOAT16_MATH_FP32_WRAPPERS(cos)
254+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(cos)
255+
BFLOAT16_MATH_FP32_WRAPPERS(exp)
256+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(exp)
257+
BFLOAT16_MATH_FP32_WRAPPERS(exp10)
258+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(exp10)
259+
BFLOAT16_MATH_FP32_WRAPPERS(exp2)
260+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(exp2)
261+
BFLOAT16_MATH_FP32_WRAPPERS(floor)
262+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(floor)
263+
BFLOAT16_MATH_FP32_WRAPPERS(log)
264+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(log)
265+
BFLOAT16_MATH_FP32_WRAPPERS(log2)
266+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(log2)
267+
BFLOAT16_MATH_FP32_WRAPPERS(log10)
268+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(log10)
269+
BFLOAT16_MATH_FP32_WRAPPERS(rint)
270+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(rint)
271+
BFLOAT16_MATH_FP32_WRAPPERS(rsqrt)
272+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(rsqrt)
273+
BFLOAT16_MATH_FP32_WRAPPERS(sin)
274+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(sin)
275+
BFLOAT16_MATH_FP32_WRAPPERS(sqrt)
276+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(sqrt)
277+
BFLOAT16_MATH_FP32_WRAPPERS(trunc)
278+
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(trunc)
279+
280+
#undef BFLOAT16_MATH_FP32_WRAPPERS
281+
#undef BFLOAT16_MATH_FP32_WRAPPERS_MARRAY
235282
} // namespace ext::oneapi::experimental
236283
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
237284
} // namespace sycl

sycl/test-e2e/BFloat16/bfloat16_builtins.cpp

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ bool check(bool a, bool b) { return (a != b); }
3737
cgh); \
3838
accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh); \
3939
cgh.parallel_for(N, [=](id<1> index) { \
40+
float ABF16 = float{bfloat16{A[index]}}; \
4041
if (check(sycl::ext::oneapi::experimental::NAME(bfloat16{A[index]}), \
41-
sycl::NAME(A[index]))) { \
42+
sycl::NAME(ABF16))) { \
4243
ERR[0] = 1; \
4344
} \
4445
}); \
@@ -61,7 +62,8 @@ bool check(bool a, bool b) { return (a != b); }
6162
} \
6263
marray<RETTY, SZ> res = NAME(arg); \
6364
for (int i = 0; i < SZ; i++) { \
64-
if (check(res[i], sycl::NAME(A[index][i]))) { \
65+
float ABF16 = float{bfloat16{A[index][i]}}; \
66+
if (check(res[i], sycl::NAME(ABF16))) { \
6567
ERR[0] = 1; \
6668
} \
6769
} \
@@ -90,8 +92,10 @@ bool check(bool a, bool b) { return (a != b); }
9092
cgh); \
9193
accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh); \
9294
cgh.parallel_for(N, [=](id<1> index) { \
95+
float ABF16 = float{bfloat16{A[index]}}; \
96+
float BBF16 = float{bfloat16{B[index]}}; \
9397
if (check(NAME(bfloat16{A[index]}, bfloat16{B[index]}), \
94-
NAME(A[index], B[index]))) { \
98+
NAME(ABF16, BBF16))) { \
9599
ERR[0] = 1; \
96100
} \
97101
}); \
@@ -118,7 +122,9 @@ bool check(bool a, bool b) { return (a != b); }
118122
} \
119123
marray<bfloat16, SZ> res = NAME(arg0, arg1); \
120124
for (int i = 0; i < SZ; i++) { \
121-
if (check(res[i], NAME(A[index][i], B[index][i]))) { \
125+
float ABF16 = float{bfloat16{A[index][i]}}; \
126+
float BBF16 = float{bfloat16{B[index][i]}}; \
127+
if (check(res[i], NAME(ABF16, BBF16))) { \
122128
ERR[0] = 1; \
123129
} \
124130
} \
@@ -150,9 +156,12 @@ bool check(bool a, bool b) { return (a != b); }
150156
cgh); \
151157
accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh); \
152158
cgh.parallel_for(N, [=](id<1> index) { \
159+
float ABF16 = float{bfloat16{A[index]}}; \
160+
float BBF16 = float{bfloat16{B[index]}}; \
161+
float CBF16 = float{bfloat16{C[index]}}; \
153162
if (check(NAME(bfloat16{A[index]}, bfloat16{B[index]}, \
154163
bfloat16{C[index]}), \
155-
NAME(A[index], B[index], C[index]))) { \
164+
NAME(ABF16, BBF16, CBF16))) { \
156165
ERR[0] = 1; \
157166
} \
158167
}); \
@@ -183,7 +192,10 @@ bool check(bool a, bool b) { return (a != b); }
183192
} \
184193
marray<bfloat16, SZ> res = NAME(arg0, arg1, arg2); \
185194
for (int i = 0; i < SZ; i++) { \
186-
if (check(res[i], NAME(A[index][i], B[index][i], C[index][i]))) { \
195+
float ABF16 = float{bfloat16{A[index][i]}}; \
196+
float BBF16 = float{bfloat16{B[index][i]}}; \
197+
float CBF16 = float{bfloat16{C[index][i]}}; \
198+
if (check(res[i], NAME(ABF16, BBF16, CBF16))) { \
187199
ERR[0] = 1; \
188200
} \
189201
} \
@@ -245,5 +257,35 @@ int main() {
245257
a[0] = a[N - 1] = NAN;
246258
TEST_BUILTIN_1(isnan, bool);
247259

260+
// Original input 'a[0...N-1]' is in range [-0.5, 0.5),
261+
// need to update it for generic math testing.
262+
// sin, cos testing
263+
for (int i = 0; i < N; ++i) {
264+
a[i] = (i / (float)(N - 1)) * 6.28;
265+
if ((i & 0x1) == 0x1)
266+
a[i] = -a[i];
267+
}
268+
TEST_BUILTIN_1(cos, sycl::ext::oneapi::bfloat16);
269+
TEST_BUILTIN_1(sin, sycl::ext::oneapi::bfloat16);
270+
271+
// ceil, floor, trunc, exp, exp2, exp10, rint testing
272+
TEST_BUILTIN_1(ceil, sycl::ext::oneapi::bfloat16);
273+
TEST_BUILTIN_1(floor, sycl::ext::oneapi::bfloat16);
274+
TEST_BUILTIN_1(trunc, sycl::ext::oneapi::bfloat16);
275+
TEST_BUILTIN_1(exp, sycl::ext::oneapi::bfloat16);
276+
TEST_BUILTIN_1(exp10, sycl::ext::oneapi::bfloat16);
277+
TEST_BUILTIN_1(exp2, sycl::ext::oneapi::bfloat16);
278+
TEST_BUILTIN_1(rint, sycl::ext::oneapi::bfloat16);
279+
280+
// log, log2, log10, sqrt, rsqrt testing, the input
281+
// must be positive.
282+
for (int i = 0; i < N; ++i)
283+
a[i] = a[i] + 8.5;
284+
TEST_BUILTIN_1(sqrt, sycl::ext::oneapi::bfloat16);
285+
TEST_BUILTIN_1(rsqrt, sycl::ext::oneapi::bfloat16);
286+
TEST_BUILTIN_1(log, sycl::ext::oneapi::bfloat16);
287+
TEST_BUILTIN_1(log2, sycl::ext::oneapi::bfloat16);
288+
TEST_BUILTIN_1(log10, sycl::ext::oneapi::bfloat16);
289+
248290
return 0;
249291
}

0 commit comments

Comments
 (0)