@@ -25,31 +25,115 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
25
25
bfloat16 (const bfloat16 &) = default ;
26
26
~bfloat16 () = default ;
27
27
28
- // Direct initialization
29
- bfloat16 (const storage_t &a) : value (a) {}
30
-
31
- // convert from float to bfloat16
32
- bfloat16 (const float &a) {
28
+ // Explicit conversion functions
29
+ static storage_t from_float (const float &a) {
33
30
#if defined(__SYCL_DEVICE_ONLY__)
34
- value = __spirv_ConvertFToBF16INTEL (a);
31
+ return __spirv_ConvertFToBF16INTEL (a);
35
32
#else
36
33
throw runtime_error (" Bfloat16 conversion is not supported on HOST device." ,
37
34
PI_INVALID_DEVICE);
38
35
#endif
39
36
}
40
-
41
- // convert from bfloat16 to float
42
- operator float () const {
37
+ static float to_float (const storage_t &a) {
43
38
#if defined(__SYCL_DEVICE_ONLY__)
44
- return __spirv_ConvertBF16ToFINTEL (value );
39
+ return __spirv_ConvertBF16ToFINTEL (a );
45
40
#else
46
41
throw runtime_error (" Bfloat16 conversion is not supported on HOST device." ,
47
42
PI_INVALID_DEVICE);
48
43
#endif
49
44
}
50
45
51
- // Get bfloat16 as uint16.
46
+ // Direct initialization
47
+ bfloat16 (const storage_t &a) : value (a) {}
48
+
49
+ // Implicit conversion from float to bfloat16
50
+ bfloat16 (const float &a) { value = from_float (a); }
51
+
52
+ bfloat16 &operator =(const float &rhs) {
53
+ value = from_float (rhs);
54
+ return *this ;
55
+ }
56
+
57
+ // Implicit conversion from bfloat16 to float
58
+ operator float () const { return to_float (value); }
59
+
60
+ // Get raw bits representation of bfloat16
52
61
operator storage_t () const { return value; }
62
+
63
+ // Assignment operators overloading
64
+ #define OP (op ) \
65
+ friend bfloat16 operator op (bfloat16 &lhs, const bfloat16 &rhs) { \
66
+ float f = static_cast <float >(lhs); \
67
+ f op static_cast <float >(rhs); \
68
+ return lhs = f; \
69
+ } \
70
+ \
71
+ template <typename T> \
72
+ friend bfloat16 operator op (bfloat16 &lhs, const T &rhs) { \
73
+ float f = static_cast <float >(lhs); \
74
+ \
75
+ f op static_cast <float >(rhs); \
76
+ return lhs = f; \
77
+ } \
78
+ template <typename T> friend T operator op (T &lhs, const bfloat16 &rhs) { \
79
+ float f = static_cast <float >(lhs); \
80
+ f op static_cast <float >(rhs); \
81
+ return lhs = f; \
82
+ }
83
+ OP (+=)
84
+ OP (-=)
85
+ OP (*=)
86
+ OP (/=)
87
+ #undef OP
88
+
89
+ // Increment and decrement operators overloading
90
+ #define OP (op ) \
91
+ bfloat16 &operator op () { \
92
+ float f = to_float (value); \
93
+ value = from_float (op f); \
94
+ return *this ; \
95
+ } \
96
+ bfloat16 operator op (int ) { \
97
+ bfloat16 old = *this ; \
98
+ operator op (); \
99
+ return old; \
100
+ }
101
+ OP (++)
102
+ OP (--)
103
+ #undef OP
104
+
105
+ // Unary minus operator overloading
106
+ bfloat16 operator -() { return bfloat16{-to_float (value)}; }
107
+
108
+ // Logical operators (!,||,&&) are covered if we can cast to bool
109
+ explicit operator bool () { return to_float (value) != 0 .0f ; }
110
+
111
+ // Binary operators overloading
112
+ #define OP (type, op ) \
113
+ friend type operator op (const bfloat16 &lhs, const bfloat16 &rhs) { \
114
+ return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
115
+ } \
116
+ template <typename T> \
117
+ friend type operator op (const bfloat16 &lhs, const T &rhs) { \
118
+ return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
119
+ } \
120
+ template <typename T> \
121
+ friend type operator op (const T &lhs, const bfloat16 &rhs) { \
122
+ return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
123
+ }
124
+ OP (bfloat16, +)
125
+ OP (bfloat16, -)
126
+ OP (bfloat16, *)
127
+ OP (bfloat16, /)
128
+ OP (bool , ==)
129
+ OP (bool , !=)
130
+ OP (bool , <)
131
+ OP (bool , >)
132
+ OP (bool , <=)
133
+ OP (bool , >=)
134
+ #undef OP
135
+ // Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
136
+ // for floating-point types.
53
137
};
54
138
55
139
} // namespace experimental
0 commit comments