Skip to content

Commit 6014cef

Browse files
committed
[SYCL] Move bfloat support from experimental to supported.
Signed-off-by: Rajiv Deodhar <[email protected]>
1 parent 62c36e9 commit 6014cef

File tree

5 files changed

+47
-56
lines changed

5 files changed

+47
-56
lines changed

sycl/doc/extensions/experimental/sycl_ext_oneapi_bfloat16.asciidoc renamed to sycl/doc/extensions/supported/sycl_ext_oneapi_bfloat16.asciidoc

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,11 @@ public:
135135
bfloat16(const float &a);
136136
bfloat16 &operator=(const float &a);
137137
138-
// Convert from bfloat16 to float
138+
// Convert bfloat16 to floating-point types
139139
operator float() const;
140+
operator sycl::half() const;
140141
141-
// Get bfloat16 as uint16.
142-
operator storage_t() const;
143-
144-
// Convert to bool type
142+
// Convert bfloat16 to bool type
145143
explicit operator bool();
146144
147145
friend bfloat16 operator-(bfloat16 &bf) { /* ... */ }
@@ -195,11 +193,11 @@ Table 1. Member functions of `bfloat16` class.
195193
| `operator float() const;`
196194
| Return `bfloat16` value converted to `float`.
197195

198-
| `operator storage_t() const;`
199-
| Return `uint16_t` value, whose bits represent `bfloat16` value.
196+
| `operator sycl::half() const;`
197+
| Return `bfloat16` value converted to `sycl::half`.
200198

201199
| `explicit operator bool() { /* ... */ }`
202-
| Convert `bfloat16` to `bool` type. Return `false` if the value equals to
200+
| Convert `bfloat16` to `bool` type. Return `false` if the `value` equals to
203201
zero, return `true` otherwise.
204202

205203
| `friend bfloat16 operator-(bfloat16 &bf) { /* ... */ }`
@@ -408,4 +406,5 @@ Compute absolute value of a `bfloat16`.
408406
|3|2021-08-18|Alexey Sotkin |Remove `uint16_t` constructor
409407
|4|2022-03-07|Aidan Belton and Jack Kirk |Switch from Intel vendor specific to oneapi
410408
|5|2022-04-05|Jack Kirk | Added section for bfloat16 math builtins
409+
|6|2022-08-03|Alexey Sotkin |Add `operator sycl::half()`
411410
|========================================

sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp renamed to sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ __SYCL_INLINE_NAMESPACE(cl) {
1515
namespace sycl {
1616
namespace ext {
1717
namespace oneapi {
18-
namespace experimental {
1918

2019
class bfloat16 {
2120
using storage_t = uint16_t;
@@ -165,7 +164,6 @@ class bfloat16 {
165164
// for floating-point types.
166165
};
167166

168-
} // namespace experimental
169167
} // namespace oneapi
170168
} // namespace ext
171169

sycl/include/sycl/ext/oneapi/experimental/builtins.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include <sycl/detail/type_traits.hpp>
1616

1717
#include <CL/__spirv/spirv_ops.hpp>
18-
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
18+
#include <sycl/ext/oneapi/bfloat16.hpp>
1919

2020
// TODO Decide whether to mark functions with this attribute.
2121
#define __NOEXC /*noexcept*/

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 37 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#include <CL/__spirv/spirv_ops.hpp>
1212
#include <sycl/detail/defines_elementary.hpp>
13-
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
13+
#include <sycl/ext/oneapi/bfloat16.hpp>
1414
#include <sycl/feature_test.hpp>
1515

1616
__SYCL_INLINE_NAMESPACE(cl) {
@@ -458,18 +458,16 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
458458
};
459459

460460
template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
461-
class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
462-
Layout, Group> {
463-
joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
464-
Layout, Group> &M;
461+
class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group> {
462+
joint_matrix<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group> &M;
465463
std::size_t idx;
466464

467465
public:
468-
wi_element(joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows,
469-
NumCols, Layout, Group> &Mat,
466+
wi_element(joint_matrix<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout,
467+
Group> &Mat,
470468
std::size_t i)
471469
: M(Mat), idx(i) {}
472-
operator sycl::ext::oneapi::experimental::bfloat16() {
470+
operator sycl::ext::oneapi::bfloat16() {
473471
#ifdef __SYCL_DEVICE_ONLY__
474472
return __spirv_VectorExtractDynamic(M.spvm, idx);
475473
#else
@@ -488,7 +486,7 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
488486
#endif // __SYCL_DEVICE_ONLY__
489487
}
490488

491-
wi_element &operator=(const sycl::ext::oneapi::experimental::bfloat16 &rhs) {
489+
wi_element &operator=(const sycl::ext::oneapi::bfloat16 &rhs) {
492490
#ifdef __SYCL_DEVICE_ONLY__
493491
M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
494492
return *this;
@@ -499,9 +497,8 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
499497
#endif // __SYCL_DEVICE_ONLY__
500498
}
501499

502-
wi_element &
503-
operator=(const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows,
504-
NumCols, Layout, Group> &rhs) {
500+
wi_element &operator=(const wi_element<sycl::ext::oneapi::bfloat16, NumRows,
501+
NumCols, Layout, Group> &rhs) {
505502
#ifdef __SYCL_DEVICE_ONLY__
506503
M.spvm = __spirv_VectorInsertDynamic(
507504
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
@@ -515,16 +512,14 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
515512

516513
#if __SYCL_DEVICE_ONLY__
517514
#define OP(opassign, op) \
518-
wi_element &operator opassign( \
519-
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
515+
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
520516
M.spvm = __spirv_VectorInsertDynamic( \
521517
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \
522518
return *this; \
523519
}
524520
#else // __SYCL_DEVICE_ONLY__
525521
#define OP(opassign, op) \
526-
wi_element &operator opassign( \
527-
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
522+
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
528523
(void)rhs; \
529524
throw runtime_error("joint matrix is not supported on host device.", \
530525
PI_ERROR_INVALID_DEVICE); \
@@ -539,34 +534,34 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
539534
#if __SYCL_DEVICE_ONLY__
540535
#define OP(type, op) \
541536
friend type operator op( \
542-
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
543-
NumCols, Layout, Group> &lhs, \
544-
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
537+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
538+
Group> &lhs, \
539+
const sycl::ext::oneapi::bfloat16 &rhs) { \
545540
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \
546541
} \
547542
friend type operator op( \
548-
const sycl::ext::oneapi::experimental::bfloat16 &lhs, \
549-
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
550-
NumCols, Layout, Group> &rhs) { \
543+
const sycl::ext::oneapi::bfloat16 &lhs, \
544+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
545+
Group> &rhs) { \
551546
return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \
552547
}
553-
OP(sycl::ext::oneapi::experimental::bfloat16, +)
554-
OP(sycl::ext::oneapi::experimental::bfloat16, -)
555-
OP(sycl::ext::oneapi::experimental::bfloat16, *)
556-
OP(sycl::ext::oneapi::experimental::bfloat16, /)
548+
OP(sycl::ext::oneapi::bfloat16, +)
549+
OP(sycl::ext::oneapi::bfloat16, -)
550+
OP(sycl::ext::oneapi::bfloat16, *)
551+
OP(sycl::ext::oneapi::bfloat16, /)
557552
#undef OP
558553
#define OP(type, op) \
559554
friend type operator op( \
560-
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
561-
NumCols, Layout, Group> &lhs, \
562-
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
555+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
556+
Group> &lhs, \
557+
const sycl::ext::oneapi::bfloat16 &rhs) { \
563558
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
564559
lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \
565560
} \
566561
friend type operator op( \
567-
const sycl::ext::oneapi::experimental::bfloat16 &lhs, \
568-
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
569-
NumCols, Layout, Group> &rhs) { \
562+
const sycl::ext::oneapi::bfloat16 &lhs, \
563+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
564+
Group> &rhs) { \
570565
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
571566
rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \
572567
}
@@ -579,24 +574,23 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
579574
#undef OP
580575
#else // __SYCL_DEVICE_ONLY__
581576
#define OP(type, op) \
582-
friend type operator op( \
583-
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
584-
NumCols, Layout, Group> &, \
585-
const sycl::ext::oneapi::experimental::bfloat16 &) { \
577+
friend type operator op(const wi_element<sycl::ext::oneapi::bfloat16, \
578+
NumRows, NumCols, Layout, Group> &, \
579+
const sycl::ext::oneapi::bfloat16 &) { \
586580
throw runtime_error("joint matrix is not supported on host device.", \
587581
PI_ERROR_INVALID_DEVICE); \
588582
} \
589583
friend type operator op( \
590-
const sycl::ext::oneapi::experimental::bfloat16 &, \
591-
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
592-
NumCols, Layout, Group> &) { \
584+
const sycl::ext::oneapi::bfloat16 &, \
585+
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
586+
Group> &) { \
593587
throw runtime_error("joint matrix is not supported on host device.", \
594588
PI_ERROR_INVALID_DEVICE); \
595589
}
596-
OP(sycl::ext::oneapi::experimental::bfloat16, +)
597-
OP(sycl::ext::oneapi::experimental::bfloat16, -)
598-
OP(sycl::ext::oneapi::experimental::bfloat16, *)
599-
OP(sycl::ext::oneapi::experimental::bfloat16, /)
590+
OP(sycl::ext::oneapi::bfloat16, +)
591+
OP(sycl::ext::oneapi::bfloat16, -)
592+
OP(sycl::ext::oneapi::bfloat16, *)
593+
OP(sycl::ext::oneapi::bfloat16, /)
600594
OP(bool, ==)
601595
OP(bool, !=)
602596
OP(bool, <)

sycl/test/extensions/bfloat16.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
// UNSUPPORTED: cuda || hip_amd
44

5-
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
5+
#include <sycl/ext/oneapi/bfloat16.hpp>
66
#include <sycl/sycl.hpp>
77

8-
using sycl::ext::oneapi::experimental::bfloat16;
8+
using sycl::ext::oneapi::bfloat16;
99

1010
SYCL_EXTERNAL uint16_t some_bf16_intrinsic(uint16_t x, uint16_t y);
1111
SYCL_EXTERNAL void foo(long x, sycl::half y);

0 commit comments

Comments
 (0)