Skip to content

Commit e055600

Browse files
author
aidan.belton
committed
Update bf16_conversion extension
1 parent 25a6707 commit e055600

File tree

2 files changed

+337
-25
lines changed

2 files changed

+337
-25
lines changed
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
= SYCL_ONEAPI_bfloat16
2+
3+
:source-highlighter: coderay
4+
:coderay-linenums-mode: table
5+
6+
// This section needs to be after the document title.
7+
:doctype: book
8+
:toc2:
9+
:toc: left
10+
:encoding: utf-8
11+
:lang: en
12+
13+
:blank: pass:[ +]
14+
15+
// Set the default source code type in this document to C++,
16+
// for syntax highlighting purposes. This is needed because
17+
// docbook uses c++ and html5 uses cpp.
18+
:language: {basebackend@docbook:c++:cpp}
19+
20+
// This is necessary for asciidoc, but not for asciidoctor
21+
:cpp: C++
22+
23+
== Notice
24+
25+
IMPORTANT: This specification is a draft.
26+
27+
Copyright (c) 2021 Intel Corporation. All rights reserved.
28+
29+
NOTE: Khronos(R) is a registered trademark and SYCL(TM) and SPIR(TM) are
30+
trademarks of The Khronos Group Inc. OpenCL(TM) is a trademark of Apple Inc.
31+
used by permission by Khronos.
32+
33+
== Dependencies
34+
35+
This extension is written against the SYCL 2020 specification, Revision 3.
36+
37+
== Status
38+
39+
Draft
40+
41+
This is a preview extension specification, intended to provide early access to
42+
a feature for review and community feedback. When the feature matures, this
43+
specification may be released as a formal extension.
44+
45+
Because the interfaces defined by this specification are not final and are
46+
subject to change they are not intended to be used by shipping software
47+
products.
48+
49+
== Version
50+
51+
Revision: 3
52+
53+
== Introduction
54+
55+
This extension adds functionality to convert value of single-precision
56+
floating-point type(`float`) to `bfloat16` type and vice versa. The extension
57+
doesn't add support for `bfloat16` type as such, instead it uses 16-bit integer
58+
type(`uint16_t`) as a storage for `bfloat16` values.
59+
60+
The purpose of conversion from float to bfloat16 is to reduce ammount of memory
61+
required to store floating-point numbers. Computations are expected to be done with
62+
32-bit floating-point values.
63+
64+
This extension is an optional kernel feature as described in
65+
https://www.khronos.org/registry/SYCL/specs/sycl-2020/html/sycl-2020.html#sec:optional-kernel-features[section 5.7]
66+
of the SYCL 2020 spec. Therefore, attempting to submit a kernel using this
67+
feature to a device that does not support it should cause a synchronous
68+
`errc::kernel_not_supported` exception to be thrown from the kernel invocation
69+
command (e.g. from `parallel_for`).
70+
71+
== Feature test macro
72+
73+
This extension provides a feature-test macro as described in the core SYCL
74+
specification section 6.3.3 "Feature test macros". Therefore, an implementation
75+
supporting this extension must predefine the macro
76+
`SYCL_EXT_ONEAPI_BFLOAT16_CONVERSION` to one of the values defined in the table
77+
below. Applications can test for the existence of this macro to determine if
78+
the implementation supports this feature, or applications can test the macro’s
79+
value to determine which of the extension’s APIs the implementation supports.
80+
81+
[%header,cols="1,5"]
82+
|===
83+
|Value |Description
84+
|1 |Initial extension version. Base features are supported.
85+
|===
86+
87+
== Extension to `enum class aspect`
88+
89+
[source]
90+
----
91+
namespace sycl {
92+
enum class aspect {
93+
...
94+
ext_oneapi_bfloat16
95+
}
96+
}
97+
----
98+
99+
If a SYCL device has the `ext_oneapi_bfloat16` aspect, then it natively
100+
supports conversion of values of `float` type to `bfloat16` and back.
101+
102+
If the device doesn't have the aspect, objects of `bfloat16` class must not be
103+
used in the device code.
104+
105+
**NOTE**: The `ext_oneapi_bfloat16` aspect is not yet supported. The
106+
`bfloat16` class is currently supported only on Xe HP GPU.
107+
108+
== New `bfloat16` class
109+
110+
The `bfloat16` class below provides the conversion functionality. Conversion
111+
from `float` to `bfloat16` is done with round to nearest even(RTE) rounding
112+
mode.
113+
114+
[source]
115+
----
116+
namespace sycl {
117+
namespace ext {
118+
namespace oneapi {
119+
namespace experimental {
120+
121+
class bfloat16 {
122+
using storage_t = uint16_t;
123+
storage_t value;
124+
125+
public:
126+
bfloat16() = default;
127+
bfloat16(const bfloat16 &) = default;
128+
~bfloat16() = default;
129+
130+
// Explicit conversion functions
131+
static storage_t from_float(const float &a);
132+
static float to_float(const storage_t &a);
133+
134+
// Convert from float to bfloat16
135+
bfloat16(const float &a);
136+
bfloat16 &operator=(const float &a);
137+
138+
// Convert from bfloat16 to float
139+
operator float() const;
140+
141+
// Get bfloat16 as uint16.
142+
operator storage_t() const;
143+
144+
// Convert to bool type
145+
explicit operator bool();
146+
147+
friend bfloat16 operator-(bfloat16 &bf) { /* ... */ }
148+
149+
// OP is: prefix ++, --
150+
friend bfloat16 &operatorOP(bfloat16 &bf) { /* ... */ }
151+
152+
// OP is: postfix ++, --
153+
friend bfloat16 operatorOP(bfloat16 &bf, int) { /* ... */ }
154+
155+
// OP is: +=, -=, *=, /=
156+
friend bfloat16 &operatorOP(bfloat16 &lhs, const bfloat16 &rhs) { /* ... */ }
157+
158+
// OP is +, -, *, /
159+
friend bfloat16 operatorOP(const bfloat16 &lhs, const bfloat16 &rhs) { /* ... */ }
160+
template <typename T>
161+
friend bfloat16 operatorOP(const bfloat16 &lhs, const T &rhs) { /* ... */ }
162+
template <typename T>
163+
friend bfloat16 operatorOP(const T &lhs, const bfloat16 &rhs) { /* ... */ }
164+
165+
// OP is ==,!=, <, >, <=, >=
166+
friend bool operatorOP(const bfloat16 &lhs, const bfloat16 &rhs) { /* ... */ }
167+
template <typename T>
168+
friend bool operatorOP(const bfloat16 &lhs, const T &rhs) { /* ... */ }
169+
template <typename T>
170+
friend bool operatorOP(const T &lhs, const bfloat16 &rhs) { /* ... */ }
171+
};
172+
173+
} // namespace experimental
174+
} // namespace oneapi
175+
} // namespace ext
176+
} // namespace sycl
177+
----
178+
179+
Table 1. Member functions of `bfloat16` class.
180+
|===
181+
| Member Function | Description
182+
183+
| `static storage_t from_float(const float &a);`
184+
| Explicitly convert from `float` to `bfloat16`.
185+
186+
| `static float to_float(const storage_t &a);`
187+
| Interpret `a` as `bfloat16` and explicitly convert it to `float`.
188+
189+
| `bfloat16(const float& a);`
190+
| Construct `bfloat16` from `float`. Converts `float` to `bfloat16`.
191+
192+
| `bfloat16 &operator=(const float &a);`
193+
| Replace the value with `a` converted to `bfloat16`
194+
195+
| `operator float() const;`
196+
| Return `bfloat16` value converted to `float`.
197+
198+
| `operator storage_t() const;`
199+
| Return `uint16_t` value, whose bits represent `bfloat16` value.
200+
201+
| `explicit operator bool() { /* ... */ }`
202+
| Convert `bfloat16` to `bool` type. Return `false` if the value equals to
203+
zero, return `true` otherwise.
204+
205+
| `friend bfloat16 operator-(bfloat16 &bf) { /* ... */ }`
206+
| Construct new instance of `bfloat16` class with negated value of the `bf`.
207+
208+
| `friend bfloat16 &operatorOP(bfloat16 &bf) { /* ... */ }`
209+
| Perform an in-place `OP` prefix arithmetic operation on the `bf`,
210+
assigning the result to the `bf` and return the `bf`.
211+
212+
OP is: `++, --`
213+
214+
| `friend bfloat16 operatorOP(bfloat16 &bf, int) { /* ... */ }`
215+
| Perform an in-place `OP` postfix arithmetic operation on `bf`, assigning
216+
the result to the `bf` and return a copy of `bf` before the operation is
217+
performed.
218+
219+
OP is: `++, --`
220+
221+
| `friend bfloat16 operatorOP(const bfloat16 &lhs, const bfloat16 &rhs) { /* ... */ }`
222+
| Perform an in-place `OP` arithmetic operation between the `lhs` and the `rhs`
223+
and return the `lhs`.
224+
225+
OP is: `+=, -=, *=, /=`
226+
227+
| `friend type operatorOP(const bfloat16 &lhs, const bfloat16 &rhs) { /* ... */ }`
228+
| Construct a new instance of the `bfloat16` class with the value of the new
229+
`bfloat16` instance being the result of an OP arithmetic operation between
230+
the `lhs` `bfloat16` and `rhs` `bfloat16` values.
231+
232+
OP is `+, -, *, /`
233+
234+
| `template <typename T>
235+
friend bfloat16 operatorOP(const bfloat16 &lhs, const T &rhs) { /* ... */ }`
236+
| Construct a new instance of the `bfloat16` class with the value of the new
237+
`bfloat16` instance being the result of an OP arithmetic operation between
238+
the `lhs` `bfloat16` value and `rhs` of template type `T`. Type `T` must be
239+
convertible to `float`.
240+
241+
OP is `+, -, *, /`
242+
243+
| `template <typename T>
244+
friend bfloat16 operatorOP(const T &lhs, const bfloat16 &rhs) { /* ... */ }`
245+
| Construct a new instance of the `bfloat16` class with the value of the new
246+
`bfloat16` instance being the result of an OP arithmetic operation between
247+
the `lhs` of template type `T` and `rhs` `bfloat16` value. Type `T` must be
248+
convertible to `float`.
249+
250+
OP is `+, -, *, /`
251+
252+
| `friend bool operatorOP(const bfloat16 &lhs, const bfloat16 &rhs) { /* ... */ }`
253+
| Perform comparison operation OP between `lhs` `bfloat16` and `rhs` `bfloat16`
254+
values and return the result as a boolean value.
255+
256+
OP is `==, !=, <, >, <=, >=`
257+
258+
| `template <typename T>
259+
friend bool operatorOP(const bfloat16 &lhs, const T &rhs) { /* ... */ }`
260+
| Perform comparison operation OP between `lhs` `bfloat16` and `rhs` of
261+
template type `T` and return the result as a boolean value. Type `T` must be
262+
convertible to `float`.
263+
264+
OP is `==, !=, <, >, <=, >=`
265+
266+
| `template <typename T>
267+
friend bool operatorOP(const T &lhs, const bfloat16 &rhs) { /* ... */ }`
268+
| Perform comparison operation OP between `lhs` of template type `T` and `rhs`
269+
`bfloat16` value and return the result as a boolean value. Type `T` must be
270+
convertible to `float`.
271+
272+
OP is `==, !=, <, >, <=, >=`
273+
|===
274+
275+
== Example
276+
277+
[source]
278+
----
279+
#include <sycl/sycl.hpp>
280+
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
281+
282+
using sycl::ext::oneapi::experimental::bfloat16;
283+
284+
bfloat16 operator+(const bfloat16 &lhs, const bfloat16 &rhs) {
285+
return static_cast<float>(lhs) + static_cast<float>(rhs);
286+
}
287+
288+
float foo(float a, float b) {
289+
// Convert from float to bfloat16.
290+
bfloat16 A {a};
291+
bfloat16 B {b};
292+
293+
// Convert A and B from bfloat16 to float, do addition on floating-pointer
294+
// numbers, then convert the result to bfloat16 and store it in C.
295+
bfloat16 C = A + B;
296+
297+
// Return the result converted from bfloat16 to float.
298+
return C;
299+
}
300+
301+
int main (int argc, char *argv[]) {
302+
float data[3] = {7.0, 8.1, 0.0};
303+
sycl::device dev;
304+
sycl::queue deviceQueue{dev};
305+
sycl::buffer<float, 1> buf {data, sycl::range<1> {3}};
306+
307+
if (dev.has(sycl::aspect::ext_oneapi_bfloat16)) {
308+
deviceQueue.submit ([&] (sycl::handler& cgh) {
309+
auto numbers = buf.get_access<sycl::access::mode::read_write> (cgh);
310+
cgh.single_task<class simple_kernel> ([=] () {
311+
numbers[2] = foo(numbers[0], numbers[1]);
312+
});
313+
});
314+
}
315+
return 0;
316+
}
317+
----
318+
319+
== Issues
320+
321+
None.
322+
323+
== Revision History
324+
325+
[cols="5,15,15,70"]
326+
[grid="rows"]
327+
[options="header"]
328+
|========================================
329+
|Rev|Date|Author|Changes
330+
|1|2021-08-02|Alexey Sotkin |Initial public working draft
331+
|2|2021-08-17|Alexey Sotkin |Add explicit conversion functions +
332+
Add operator overloadings +
333+
Apply code review suggestions
334+
|3|2021-08-18|Alexey Sotkin |Remove `uint16_t` constructor
335+
|========================================

sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace ext {
1717
namespace oneapi {
1818
namespace experimental {
1919

20-
class bfloat16 {
20+
class [[sycl_detail::uses_aspects(ext_oneapi_bfloat16)]] bfloat16 {
2121
using storage_t = uint16_t;
2222
storage_t value;
2323

@@ -29,29 +29,16 @@ class bfloat16 {
2929
// Explicit conversion functions
3030
static storage_t from_float(const float &a) {
3131
#if defined(__SYCL_DEVICE_ONLY__)
32-
#if defined(__NVPTX__)
33-
return __nvvm_f2bf16_rn(a);
34-
#else
3532
return __spirv_ConvertFToBF16INTEL(a);
36-
#endif
3733
#else
38-
(void)a;
3934
throw exception{errc::feature_not_supported,
4035
"Bfloat16 conversion is not supported on host device"};
4136
#endif
4237
}
4338
static float to_float(const storage_t &a) {
4439
#if defined(__SYCL_DEVICE_ONLY__)
45-
#if defined(__NVPTX__)
46-
uint32_t y = a;
47-
y = y << 16;
48-
float *res = reinterpret_cast<float *>(&y);
49-
return *res;
50-
#else
5140
return __spirv_ConvertBF16ToFINTEL(a);
52-
#endif
5341
#else
54-
(void)a;
5542
throw exception{errc::feature_not_supported,
5643
"Bfloat16 conversion is not supported on host device"};
5744
#endif
@@ -83,17 +70,7 @@ class bfloat16 {
8370

8471
// Unary minus operator overloading
8572
friend bfloat16 operator-(bfloat16 &lhs) {
86-
#if defined(__SYCL_DEVICE_ONLY__)
87-
#if defined(__NVPTX__)
88-
return from_bits(__nvvm_neg_bf16(lhs.value));
89-
#else
90-
return bfloat16{-__spirv_ConvertBF16ToFINTEL(lhs.value)};
91-
#endif
92-
#else
93-
(void)lhs;
94-
throw exception{errc::feature_not_supported,
95-
"Bfloat16 unary minus is not supported on host device"};
96-
#endif
73+
return bfloat16{-to_float(lhs.value)};
9774
}
9875

9976
// Increment and decrement operators overloading

0 commit comments

Comments
 (0)