8
8
9
9
#pragma once
10
10
11
- #include < cmath>
12
- #include < cstdint>
13
- #include < cstring>
14
- #include < limits>
15
- #include < ostream>
16
-
17
- namespace executorch {
18
- namespace runtime {
19
- namespace etensor {
11
+ #include < c10/util/BFloat16.h>
20
12
13
+ namespace executorch ::runtime::etensor {
14
+ using c10::BFloat16;
21
15
namespace internal {
22
/**
 * Expands a raw bfloat16 bit pattern into the float it represents.
 *
 * The 16 input bits are placed in the upper half of a 32-bit word (the lower
 * half is zero) and that word is reinterpreted as a float through memcpy.
 *
 * @param src Raw bfloat16 bits.
 * @return The float whose bit pattern is `src << 16`.
 */
inline float f32_from_bits(uint16_t src) {
  // The bit reinterpretation below is only meaningful for a 32-bit float.
  static_assert(
      sizeof(float) == sizeof(uint32_t),
      "f32_from_bits requires a 32-bit float");
  const uint32_t widened = static_cast<uint32_t>(src) << 16;
  float res = 0;
  // memcpy rather than a pointer cast, to avoid aliasing violations.
  std::memcpy(&res, &widened, sizeof(res));
  return res;
}
29
-
30
/**
 * Converts a float to bfloat16 bits, rounding to nearest with ties to even.
 *
 * NaN inputs are mapped to the fixed pattern 0x7FC0. For every other input
 * the float's bits are biased by 0x7FFF plus the lowest bit that will be
 * kept, and the upper 16 bits of the sum are returned.
 *
 * @param src Value to convert.
 * @return The 16-bit bfloat16 bit pattern.
 */
inline uint16_t round_to_nearest_even(float src) {
  // The bit manipulation below is only meaningful for a 32-bit float.
  static_assert(
      sizeof(float) == sizeof(uint32_t),
      "round_to_nearest_even requires a 32-bit float");
  if (std::isnan(src)) {
    return UINT16_C(0x7FC0);
  }
  uint32_t bits = 0;
  std::memcpy(&bits, &src, sizeof(bits));
  // Adding the kept LSB makes exact halfway cases round towards an even
  // result; everything else rounds to nearest.
  const uint32_t kept_lsb = (bits >> 16) & 1;
  const uint32_t rounding_bias = kept_lsb + UINT32_C(0x7FFF);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}
16
+ using c10::detail::f32_from_bits;
17
+ using c10::detail::round_to_nearest_even;
39
18
} // namespace internal
40
-
41
- /* *
42
- * The "brain floating-point" type, compatible with c10/util/BFloat16.h from
43
- * pytorch core.
44
- *
45
- * This representation uses 1 bit for the sign, 8 bits for the exponent and 7
46
- * bits for the mantissa.
47
- */
48
- struct alignas (2 ) BFloat16 {
49
- uint16_t x;
50
-
51
- BFloat16 () = default ;
52
- struct from_bits_t {};
53
- static constexpr from_bits_t from_bits () {
54
- return from_bits_t ();
55
- }
56
-
57
- constexpr BFloat16 (unsigned short bits, from_bits_t ) : x (bits) {}
58
- /* implicit */ BFloat16 (float value)
59
- : x (internal::round_to_nearest_even (value)) {}
60
- operator float () const {
61
- return internal::f32_from_bits (x);
62
- }
63
- };
64
-
65
- inline std::ostream& operator <<(std::ostream& out, const BFloat16& value) {
66
- out << (float )value;
67
- return out;
68
- }
69
-
70
- // / Arithmetic
71
-
72
- inline BFloat16 operator +(const BFloat16& a, const BFloat16& b) {
73
- return static_cast <float >(a) + static_cast <float >(b);
74
- }
75
-
76
- inline BFloat16 operator -(const BFloat16& a, const BFloat16& b) {
77
- return static_cast <float >(a) - static_cast <float >(b);
78
- }
79
-
80
- inline BFloat16 operator *(const BFloat16& a, const BFloat16& b) {
81
- return static_cast <float >(a) * static_cast <float >(b);
82
- }
83
-
84
- inline BFloat16 operator /(const BFloat16& a, const BFloat16& b) {
85
- return static_cast <float >(a) / static_cast <float >(b);
86
- }
87
-
88
- inline BFloat16 operator -(const BFloat16& a) {
89
- return -static_cast <float >(a);
90
- }
91
-
92
- inline BFloat16& operator +=(BFloat16& a, const BFloat16& b) {
93
- a = a + b;
94
- return a;
95
- }
96
-
97
- inline BFloat16& operator -=(BFloat16& a, const BFloat16& b) {
98
- a = a - b;
99
- return a;
100
- }
101
-
102
- inline BFloat16& operator *=(BFloat16& a, const BFloat16& b) {
103
- a = a * b;
104
- return a;
105
- }
106
-
107
- inline BFloat16& operator /=(BFloat16& a, const BFloat16& b) {
108
- a = a / b;
109
- return a;
110
- }
111
-
112
- inline BFloat16& operator |(BFloat16& a, const BFloat16& b) {
113
- a.x = a.x | b.x ;
114
- return a;
115
- }
116
-
117
- inline BFloat16& operator ^(BFloat16& a, const BFloat16& b) {
118
- a.x = a.x ^ b.x ;
119
- return a;
120
- }
121
-
122
- inline BFloat16& operator &(BFloat16& a, const BFloat16& b) {
123
- a.x = a.x & b.x ;
124
- return a;
125
- }
126
-
127
- // / Arithmetic with floats
128
-
129
- inline float operator +(BFloat16 a, float b) {
130
- return static_cast <float >(a) + b;
131
- }
132
- inline float operator -(BFloat16 a, float b) {
133
- return static_cast <float >(a) - b;
134
- }
135
- inline float operator *(BFloat16 a, float b) {
136
- return static_cast <float >(a) * b;
137
- }
138
- inline float operator /(BFloat16 a, float b) {
139
- return static_cast <float >(a) / b;
140
- }
141
-
142
- inline float operator +(float a, BFloat16 b) {
143
- return a + static_cast <float >(b);
144
- }
145
- inline float operator -(float a, BFloat16 b) {
146
- return a - static_cast <float >(b);
147
- }
148
- inline float operator *(float a, BFloat16 b) {
149
- return a * static_cast <float >(b);
150
- }
151
- inline float operator /(float a, BFloat16 b) {
152
- return a / static_cast <float >(b);
153
- }
154
-
155
- inline float & operator +=(float & a, const BFloat16& b) {
156
- return a += static_cast <float >(b);
157
- }
158
- inline float & operator -=(float & a, const BFloat16& b) {
159
- return a -= static_cast <float >(b);
160
- }
161
- inline float & operator *=(float & a, const BFloat16& b) {
162
- return a *= static_cast <float >(b);
163
- }
164
- inline float & operator /=(float & a, const BFloat16& b) {
165
- return a /= static_cast <float >(b);
166
- }
167
-
168
- // / Arithmetic with doubles
169
-
170
- inline double operator +(BFloat16 a, double b) {
171
- return static_cast <double >(a) + b;
172
- }
173
- inline double operator -(BFloat16 a, double b) {
174
- return static_cast <double >(a) - b;
175
- }
176
- inline double operator *(BFloat16 a, double b) {
177
- return static_cast <double >(a) * b;
178
- }
179
- inline double operator /(BFloat16 a, double b) {
180
- return static_cast <double >(a) / b;
181
- }
182
-
183
- inline double operator +(double a, BFloat16 b) {
184
- return a + static_cast <double >(b);
185
- }
186
- inline double operator -(double a, BFloat16 b) {
187
- return a - static_cast <double >(b);
188
- }
189
- inline double operator *(double a, BFloat16 b) {
190
- return a * static_cast <double >(b);
191
- }
192
- inline double operator /(double a, BFloat16 b) {
193
- return a / static_cast <double >(b);
194
- }
195
-
196
- // / Arithmetic with ints
197
-
198
- inline BFloat16 operator +(BFloat16 a, int b) {
199
- return a + static_cast <BFloat16>(b);
200
- }
201
- inline BFloat16 operator -(BFloat16 a, int b) {
202
- return a - static_cast <BFloat16>(b);
203
- }
204
- inline BFloat16 operator *(BFloat16 a, int b) {
205
- return a * static_cast <BFloat16>(b);
206
- }
207
- inline BFloat16 operator /(BFloat16 a, int b) {
208
- return a / static_cast <BFloat16>(b);
209
- }
210
-
211
- inline BFloat16 operator +(int a, BFloat16 b) {
212
- return static_cast <BFloat16>(a) + b;
213
- }
214
- inline BFloat16 operator -(int a, BFloat16 b) {
215
- return static_cast <BFloat16>(a) - b;
216
- }
217
- inline BFloat16 operator *(int a, BFloat16 b) {
218
- return static_cast <BFloat16>(a) * b;
219
- }
220
- inline BFloat16 operator /(int a, BFloat16 b) {
221
- return static_cast <BFloat16>(a) / b;
222
- }
223
-
224
- // // Arithmetic with int64_t
225
-
226
- inline BFloat16 operator +(BFloat16 a, int64_t b) {
227
- return a + static_cast <BFloat16>(b);
228
- }
229
- inline BFloat16 operator -(BFloat16 a, int64_t b) {
230
- return a - static_cast <BFloat16>(b);
231
- }
232
- inline BFloat16 operator *(BFloat16 a, int64_t b) {
233
- return a * static_cast <BFloat16>(b);
234
- }
235
- inline BFloat16 operator /(BFloat16 a, int64_t b) {
236
- return a / static_cast <BFloat16>(b);
237
- }
238
-
239
- inline BFloat16 operator +(int64_t a, BFloat16 b) {
240
- return static_cast <BFloat16>(a) + b;
241
- }
242
- inline BFloat16 operator -(int64_t a, BFloat16 b) {
243
- return static_cast <BFloat16>(a) - b;
244
- }
245
- inline BFloat16 operator *(int64_t a, BFloat16 b) {
246
- return static_cast <BFloat16>(a) * b;
247
- }
248
- inline BFloat16 operator /(int64_t a, BFloat16 b) {
249
- return static_cast <BFloat16>(a) / b;
250
- }
251
-
252
- // Overloading < and > operators, because std::max and std::min use them.
253
-
254
- inline bool operator >(BFloat16& lhs, BFloat16& rhs) {
255
- return float (lhs) > float (rhs);
256
- }
257
-
258
- inline bool operator <(BFloat16& lhs, BFloat16& rhs) {
259
- return float (lhs) < float (rhs);
260
- }
261
-
262
- } // namespace etensor
263
- } // namespace runtime
264
- } // namespace executorch
19
+ } // namespace executorch::runtime::etensor
265
20
266
21
namespace torch {
267
22
namespace executor {
@@ -270,74 +25,3 @@ namespace executor {
270
25
using ::executorch::runtime::etensor::BFloat16;
271
26
} // namespace executor
272
27
} // namespace torch
273
-
274
- namespace std {
275
-
276
- template <>
277
- class numeric_limits <executorch::runtime::etensor::BFloat16> {
278
- public:
279
- static constexpr bool is_signed = true ;
280
- static constexpr bool is_specialized = true ;
281
- static constexpr bool is_integer = false ;
282
- static constexpr bool is_exact = false ;
283
- static constexpr bool has_infinity = true ;
284
- static constexpr bool has_quiet_NaN = true ;
285
- static constexpr bool has_signaling_NaN = true ;
286
- static constexpr auto has_denorm = numeric_limits<float >::has_denorm;
287
- static constexpr auto has_denorm_loss =
288
- numeric_limits<float >::has_denorm_loss;
289
- static constexpr auto round_style = numeric_limits<float >::round_style;
290
- static constexpr bool is_iec559 = false ;
291
- static constexpr bool is_bounded = true ;
292
- static constexpr bool is_modulo = false ;
293
- static constexpr int digits = 8 ;
294
- static constexpr int digits10 = 2 ;
295
- static constexpr int max_digits10 = 4 ;
296
- static constexpr int radix = 2 ;
297
- static constexpr int min_exponent = -125 ;
298
- static constexpr int min_exponent10 = -37 ;
299
- static constexpr int max_exponent = 128 ;
300
- static constexpr int max_exponent10 = 38 ;
301
- static constexpr auto traps = numeric_limits<float >::traps;
302
- static constexpr auto tinyness_before =
303
- numeric_limits<float >::tinyness_before;
304
-
305
- static constexpr torch::executor::BFloat16 min () {
306
- return torch::executor::BFloat16 (
307
- 0x0080 , torch::executor::BFloat16::from_bits ());
308
- }
309
- static constexpr torch::executor::BFloat16 lowest () {
310
- return torch::executor::BFloat16 (
311
- 0xFF7F , torch::executor::BFloat16::from_bits ());
312
- }
313
- static constexpr torch::executor::BFloat16 max () {
314
- return torch::executor::BFloat16 (
315
- 0x7F7F , torch::executor::BFloat16::from_bits ());
316
- }
317
- static constexpr torch::executor::BFloat16 epsilon () {
318
- return torch::executor::BFloat16 (
319
- 0x3C00 , torch::executor::BFloat16::from_bits ());
320
- }
321
- static constexpr torch::executor::BFloat16 round_error () {
322
- return torch::executor::BFloat16 (
323
- 0x3F00 , torch::executor::BFloat16::from_bits ());
324
- }
325
- static constexpr torch::executor::BFloat16 infinity () {
326
- return torch::executor::BFloat16 (
327
- 0x7F80 , torch::executor::BFloat16::from_bits ());
328
- }
329
- static constexpr torch::executor::BFloat16 quiet_NaN () {
330
- return torch::executor::BFloat16 (
331
- 0x7FC0 , torch::executor::BFloat16::from_bits ());
332
- }
333
- static constexpr torch::executor::BFloat16 signaling_NaN () {
334
- return torch::executor::BFloat16 (
335
- 0x7F80 , torch::executor::BFloat16::from_bits ());
336
- }
337
- static constexpr torch::executor::BFloat16 denorm_min () {
338
- return torch::executor::BFloat16 (
339
- 0x0001 , torch::executor::BFloat16::from_bits ());
340
- }
341
- };
342
-
343
- } // namespace std
0 commit comments