@@ -60,22 +60,45 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
60
60
// Get raw bits representation of bfloat16
61
61
operator storage_t () const { return value; }
62
62
63
- // Assignment operators overloading
63
+ // Logical operators (!,||,&&) are covered if we can cast to bool
64
+ explicit operator bool () { return to_float (value) != 0 .0f ; }
65
+
66
+ // Unary minus operator overloading
67
+ friend bfloat16 operator -(bfloat16 &lhs) {
68
+ return bfloat16{-to_float (lhs.value )};
69
+ }
70
+
71
+ // Increment and decrement operators overloading
72
+ #define OP (op ) \
73
+ friend bfloat16 &operator op (bfloat16 &lhs) { \
74
+ float f = to_float (lhs.value ); \
75
+ lhs.value = from_float (op f); \
76
+ return lhs; \
77
+ } \
78
+ friend bfloat16 operator op (bfloat16 &lhs, int ) { \
79
+ bfloat16 old = lhs; \
80
+ operator op (lhs); \
81
+ return old; \
82
+ }
83
+ OP (++)
84
+ OP (--)
85
+ #undef OP
86
+
87
+ // Assignment operators overloading
64
88
#define OP (op ) \
65
- friend bfloat16 operator op (bfloat16 &lhs, const bfloat16 &rhs) { \
89
+ friend bfloat16 & operator op (bfloat16 &lhs, const bfloat16 &rhs) { \
66
90
float f = static_cast <float >(lhs); \
67
91
f op static_cast <float >(rhs); \
68
92
return lhs = f; \
69
93
} \
70
94
\
71
95
template <typename T> \
72
- friend bfloat16 operator op (bfloat16 &lhs, const T &rhs) { \
96
+ friend bfloat16 & operator op (bfloat16 &lhs, const T &rhs) { \
73
97
float f = static_cast <float >(lhs); \
74
- \
75
98
f op static_cast <float >(rhs); \
76
99
return lhs = f; \
77
100
} \
78
- template <typename T> friend T operator op (T &lhs, const bfloat16 &rhs) { \
101
+ template <typename T> friend T & operator op (T &lhs, const bfloat16 &rhs) { \
79
102
float f = static_cast <float >(lhs); \
80
103
f op static_cast <float >(rhs); \
81
104
return lhs = f; \
@@ -86,28 +109,6 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
86
109
OP (/=)
87
110
#undef OP
88
111
89
- // Increment and decrement operators overloading
90
- #define OP (op ) \
91
- bfloat16 &operator op () { \
92
- float f = to_float (value); \
93
- value = from_float (op f); \
94
- return *this ; \
95
- } \
96
- bfloat16 operator op (int ) { \
97
- bfloat16 old = *this ; \
98
- operator op (); \
99
- return old; \
100
- }
101
- OP (++)
102
- OP (--)
103
- #undef OP
104
-
105
- // Unary minus operator overloading
106
- bfloat16 operator -() { return bfloat16{-to_float (value)}; }
107
-
108
- // Logical operators (!,||,&&) are covered if we can cast to bool
109
- explicit operator bool () { return to_float (value) != 0 .0f ; }
110
-
111
112
// Binary operators overloading
112
113
#define OP (type, op ) \
113
114
friend type operator op (const bfloat16 &lhs, const bfloat16 &rhs) { \
@@ -132,6 +133,7 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
132
133
OP (bool , <=)
133
134
OP (bool , >=)
134
135
#undef OP
136
+
135
137
// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
136
138
// for floating-point types.
137
139
};
0 commit comments