@@ -132,69 +132,175 @@ class bfloat16 {
132
132
#endif
133
133
}
134
134
135
- // Increment and decrement operators overloading
135
+ bfloat16 &operator +=(const bfloat16 &rhs) {
136
+ value = from_float (to_float (value) + to_float (rhs.value ));
137
+ return *this ;
138
+ }
139
+
140
+ bfloat16 &operator -=(const bfloat16 &rhs) {
141
+ value = from_float (to_float (value) - to_float (rhs.value ));
142
+ return *this ;
143
+ }
144
+
145
+ bfloat16 &operator *=(const bfloat16 &rhs) {
146
+ value = from_float (to_float (value) * to_float (rhs.value ));
147
+ return *this ;
148
+ }
149
+
150
+ bfloat16 &operator /=(const bfloat16 &rhs) {
151
+ value = from_float (to_float (value) / to_float (rhs.value ));
152
+ return *this ;
153
+ }
154
+
155
+ // Operator ++, --
156
+ bfloat16 &operator ++() {
157
+ float f = to_float (value);
158
+ value = from_float (++f);
159
+ return *this ;
160
+ }
161
+
162
+ bfloat16 operator ++(int ) {
163
+ bfloat16 ret (*this );
164
+ operator ++();
165
+ return ret;
166
+ }
167
+
168
+ bfloat16 &operator --() {
169
+ float f = to_float (value);
170
+ value = from_float (--f);
171
+ return *this ;
172
+ }
173
+
174
+ bfloat16 operator --(int ) {
175
+ bfloat16 ret (*this );
176
+ operator --();
177
+ return ret;
178
+ }
179
+
180
+ // Operator +, -, *, /
136
181
#define OP (op ) \
137
- friend bfloat16 &operator op (bfloat16 &lhs) { \
138
- float f = to_float (lhs.value ); \
139
- lhs.value = from_float (op f); \
140
- return lhs; \
141
- } \
142
- friend bfloat16 operator op (bfloat16 &lhs, int ) { \
143
- bfloat16 old = lhs; \
144
- operator op (lhs); \
145
- return old; \
146
- }
147
- OP (++)
148
- OP (--)
182
+ friend bfloat16 operator op (const bfloat16 lhs, const bfloat16 rhs) { \
183
+ return to_float (lhs.value ) op to_float (rhs.value ); \
184
+ } \
185
+ friend double operator op (const bfloat16 lhs, const double rhs) { \
186
+ return to_float (lhs.value ) op rhs; \
187
+ } \
188
+ friend double operator op (const double lhs, const bfloat16 rhs) { \
189
+ return lhs op to_float (rhs.value ); \
190
+ } \
191
+ friend float operator op (const bfloat16 lhs, const float rhs) { \
192
+ return to_float (lhs.value ) op rhs; \
193
+ } \
194
+ friend float operator op (const float lhs, const bfloat16 rhs) { \
195
+ return lhs op to_float (rhs.value ); \
196
+ } \
197
+ friend bfloat16 operator op (const bfloat16 lhs, const int rhs) { \
198
+ return to_float (lhs.value ) op rhs; \
199
+ } \
200
+ friend bfloat16 operator op (const int lhs, const bfloat16 rhs) { \
201
+ return lhs op to_float (rhs.value ); \
202
+ } \
203
+ friend bfloat16 operator op (const bfloat16 lhs, const long rhs) { \
204
+ return to_float (lhs.value ) op rhs; \
205
+ } \
206
+ friend bfloat16 operator op (const long lhs, const bfloat16 rhs) { \
207
+ return lhs op to_float (rhs.value ); \
208
+ } \
209
+ friend bfloat16 operator op (const bfloat16 lhs, const long long rhs) { \
210
+ return to_float (lhs.value ) op rhs; \
211
+ } \
212
+ friend bfloat16 operator op (const long long lhs, const bfloat16 rhs) { \
213
+ return lhs op to_float (rhs.value ); \
214
+ } \
215
+ friend bfloat16 operator op (const bfloat16 &lhs, const unsigned int &rhs) { \
216
+ return to_float (lhs.value ) op rhs; \
217
+ } \
218
+ friend bfloat16 operator op (const unsigned int &lhs, const bfloat16 &rhs) { \
219
+ return lhs op to_float (rhs.value ); \
220
+ } \
221
+ friend bfloat16 operator op (const bfloat16 &lhs, const unsigned long &rhs) { \
222
+ return to_float (lhs.value ) op rhs; \
223
+ } \
224
+ friend bfloat16 operator op (const unsigned long &lhs, const bfloat16 &rhs) { \
225
+ return lhs op to_float (rhs.value ); \
226
+ } \
227
+ friend bfloat16 operator op (const bfloat16 &lhs, \
228
+ const unsigned long long &rhs) { \
229
+ return to_float (lhs.value ) op rhs; \
230
+ } \
231
+ friend bfloat16 operator op (const unsigned long long &lhs, \
232
+ const bfloat16 &rhs) { \
233
+ return lhs op to_float (rhs.value ); \
234
+ }
235
+ OP (+)
236
+ OP (-)
237
+ OP (*)
238
+ OP (/)
239
+
149
240
#undef OP
150
241
151
- // Assignment operators overloading
242
+ // Operator ==, !=, <, >, <=, >=
152
243
#define OP (op ) \
153
- friend bfloat16 &operator op (bfloat16 &lhs, const bfloat16 &rhs) { \
154
- float f = static_cast <float >(lhs); \
155
- f op static_cast <float >(rhs); \
156
- return lhs = f; \
157
- } \
158
- template <typename T> \
159
- friend bfloat16 &operator op (bfloat16 &lhs, const T &rhs) { \
160
- float f = static_cast <float >(lhs); \
161
- f op static_cast <float >(rhs); \
162
- return lhs = f; \
163
- } \
164
- template <typename T> friend T &operator op (T &lhs, const bfloat16 &rhs) { \
165
- float f = static_cast <float >(lhs); \
166
- f op static_cast <float >(rhs); \
167
- return lhs = f; \
168
- }
169
- OP (+=)
170
- OP (-=)
171
- OP (*=)
172
- OP (/=)
173
- #undef OP
244
+ friend bool operator op (const bfloat16 &lhs, const bfloat16 &rhs) { \
245
+ return to_float (lhs.value ) op to_float (rhs.value ); \
246
+ } \
247
+ friend bool operator op (const bfloat16 &lhs, const double &rhs) { \
248
+ return to_float (lhs.value ) op rhs; \
249
+ } \
250
+ friend bool operator op (const double &lhs, const bfloat16 &rhs) { \
251
+ return lhs op to_float (rhs.value ); \
252
+ } \
253
+ friend bool operator op (const bfloat16 &lhs, const float &rhs) { \
254
+ return to_float (lhs.value ) op rhs; \
255
+ } \
256
+ friend bool operator op (const float &lhs, const bfloat16 &rhs) { \
257
+ return lhs op to_float (rhs.value ); \
258
+ } \
259
+ friend bool operator op (const bfloat16 &lhs, const int &rhs) { \
260
+ return to_float (lhs.value ) op rhs; \
261
+ } \
262
+ friend bool operator op (const int &lhs, const bfloat16 &rhs) { \
263
+ return lhs op to_float (rhs.value ); \
264
+ } \
265
+ friend bool operator op (const bfloat16 &lhs, const long &rhs) { \
266
+ return to_float (lhs.value ) op rhs; \
267
+ } \
268
+ friend bool operator op (const long &lhs, const bfloat16 &rhs) { \
269
+ return lhs op to_float (rhs.value ); \
270
+ } \
271
+ friend bool operator op (const bfloat16 &lhs, const long long &rhs) { \
272
+ return to_float (lhs.value ) op rhs; \
273
+ } \
274
+ friend bool operator op (const long long &lhs, const bfloat16 &rhs) { \
275
+ return lhs op to_float (rhs.value ); \
276
+ } \
277
+ friend bool operator op (const bfloat16 &lhs, const unsigned int &rhs) { \
278
+ return to_float (lhs.value ) op rhs; \
279
+ } \
280
+ friend bool operator op (const unsigned int &lhs, const bfloat16 &rhs) { \
281
+ return lhs op to_float (rhs.value ); \
282
+ } \
283
+ friend bool operator op (const bfloat16 &lhs, const unsigned long &rhs) { \
284
+ return to_float (lhs.value ) op rhs; \
285
+ } \
286
+ friend bool operator op (const unsigned long &lhs, const bfloat16 &rhs) { \
287
+ return lhs op to_float (rhs.value ); \
288
+ } \
289
+ friend bool operator op (const bfloat16 &lhs, \
290
+ const unsigned long long &rhs) { \
291
+ return to_float (lhs.value ) op rhs; \
292
+ } \
293
+ friend bool operator op (const unsigned long long &lhs, \
294
+ const bfloat16 &rhs) { \
295
+ return lhs op to_float (rhs.value ); \
296
+ }
297
+ OP (==)
298
+ OP (!=)
299
+ OP (<)
300
+ OP (>)
301
+ OP (<=)
302
+ OP (>=)
174
303
175
- // Binary operators overloading
176
- #define OP (type, op ) \
177
- friend type operator op (const bfloat16 &lhs, const bfloat16 &rhs) { \
178
- return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
179
- } \
180
- template <typename T> \
181
- friend type operator op (const bfloat16 &lhs, const T &rhs) { \
182
- return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
183
- } \
184
- template <typename T> \
185
- friend type operator op (const T &lhs, const bfloat16 &rhs) { \
186
- return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
187
- }
188
- OP (bfloat16, +)
189
- OP (bfloat16, -)
190
- OP (bfloat16, *)
191
- OP (bfloat16, /)
192
- OP (bool , ==)
193
- OP (bool , !=)
194
- OP (bool , <)
195
- OP (bool , >)
196
- OP (bool , <=)
197
- OP (bool , >=)
198
304
#undef OP
199
305
200
306
// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
0 commit comments