Skip to content

Commit e55d06a

Browse files
committed
Use c10 version of half/bfloat16 in executorch
Pull Request resolved: #7040. Accomplished by importing relevant files from c10 into executorch/runtime/core/portable_type/c10, and then using `using` declarations in the top-level ExecuTorch headers. This approach should keep the ExecuTorch build hermetic for embedded use cases. In the future, we should add a CI job to ensure the c10 files stay identical to the PyTorch ones. ghstack-source-id: 255100694 @exported-using-ghexport Differential Revision: [D66106969](https://our.internmc.facebook.com/intern/diff/D66106969/)
1 parent abd739e commit e55d06a

File tree

18 files changed

+2257
-1118
lines changed

18 files changed

+2257
-1118
lines changed

.lintrunner.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ exclude_patterns = [
7777
# File contains @generated
7878
'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h',
7979
'extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h',
80+
# Want to be able to keep c10 in sync with PyTorch core.
81+
'runtime/core/portable_type/c10/**',
8082
]
8183
command = [
8284
'python',
@@ -260,6 +262,8 @@ exclude_patterns = [
260262
'extension/**',
261263
'kernels/optimized/**',
262264
'runtime/core/exec_aten/**',
265+
# Want to be able to keep c10 in sync with PyTorch core.
266+
'runtime/core/portable_type/c10/**',
263267
'runtime/executor/tensor_parser_aten.cpp',
264268
'scripts/**',
265269
'test/**',

CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,9 @@ if(NOT "${_repo_dir_name}" STREQUAL "executorch")
337337
"fix for this restriction."
338338
)
339339
endif()
340-
set(_common_include_directories ${CMAKE_CURRENT_SOURCE_DIR}/..)
340+
set(_common_include_directories ${CMAKE_CURRENT_SOURCE_DIR}/.. ${CMAKE_CURRENT_SOURCE_DIR}/runtime/core/portable_type)
341+
# We don't need any of C10's CMake macros.
342+
add_definitions(-DC10_USING_CUSTOM_GENERATED_MACROS)
341343

342344
#
343345
# The `_<target>_srcs` lists are defined by including ${EXECUTORCH_SRCS_FILE}.

runtime/core/exec_aten/exec_aten.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
#else // use executor
3434
#include <executorch/runtime/core/array_ref.h> // @manual
3535
#include <executorch/runtime/core/portable_type/bfloat16.h> // @manual
36-
#include <executorch/runtime/core/portable_type/bfloat16_math.h> // @manual
3736
#include <executorch/runtime/core/portable_type/complex.h> // @manual
3837
#include <executorch/runtime/core/portable_type/device.h> // @manual
3938
#include <executorch/runtime/core/portable_type/half.h> // @manual

runtime/core/portable_type/bfloat16.h

Lines changed: 6 additions & 322 deletions
Original file line numberDiff line numberDiff line change
@@ -8,260 +8,15 @@
88

99
#pragma once
1010

11-
#include <cmath>
12-
#include <cstdint>
13-
#include <cstring>
14-
#include <limits>
15-
#include <ostream>
16-
17-
namespace executorch {
18-
namespace runtime {
19-
namespace etensor {
11+
#include <c10/util/BFloat16.h>
2012

13+
namespace executorch::runtime::etensor {
14+
using c10::BFloat16;
2115
namespace internal {
22-
inline float f32_from_bits(uint16_t src) {
23-
float res = 0;
24-
uint32_t tmp = src;
25-
tmp <<= 16;
26-
std::memcpy(&res, &tmp, sizeof(tmp));
27-
return res;
28-
}
29-
30-
inline uint16_t round_to_nearest_even(float src) {
31-
if (std::isnan(src)) {
32-
return UINT16_C(0x7FC0);
33-
}
34-
uint32_t U32 = 0;
35-
std::memcpy(&U32, &src, sizeof(U32));
36-
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
37-
return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
38-
}
16+
using c10::detail::f32_from_bits;
17+
using c10::detail::round_to_nearest_even;
3918
} // namespace internal
40-
41-
/**
42-
* The "brain floating-point" type, compatible with c10/util/BFloat16.h from
43-
* pytorch core.
44-
*
45-
* This representation uses 1 bit for the sign, 8 bits for the exponent and 7
46-
* bits for the mantissa.
47-
*/
48-
struct alignas(2) BFloat16 {
49-
uint16_t x;
50-
51-
BFloat16() = default;
52-
struct from_bits_t {};
53-
static constexpr from_bits_t from_bits() {
54-
return from_bits_t();
55-
}
56-
57-
constexpr BFloat16(unsigned short bits, from_bits_t) : x(bits) {}
58-
/* implicit */ BFloat16(float value)
59-
: x(internal::round_to_nearest_even(value)) {}
60-
operator float() const {
61-
return internal::f32_from_bits(x);
62-
}
63-
};
64-
65-
inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) {
66-
out << (float)value;
67-
return out;
68-
}
69-
70-
/// Arithmetic
71-
72-
inline BFloat16 operator+(const BFloat16& a, const BFloat16& b) {
73-
return static_cast<float>(a) + static_cast<float>(b);
74-
}
75-
76-
inline BFloat16 operator-(const BFloat16& a, const BFloat16& b) {
77-
return static_cast<float>(a) - static_cast<float>(b);
78-
}
79-
80-
inline BFloat16 operator*(const BFloat16& a, const BFloat16& b) {
81-
return static_cast<float>(a) * static_cast<float>(b);
82-
}
83-
84-
inline BFloat16 operator/(const BFloat16& a, const BFloat16& b) {
85-
return static_cast<float>(a) / static_cast<float>(b);
86-
}
87-
88-
inline BFloat16 operator-(const BFloat16& a) {
89-
return -static_cast<float>(a);
90-
}
91-
92-
inline BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
93-
a = a + b;
94-
return a;
95-
}
96-
97-
inline BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
98-
a = a - b;
99-
return a;
100-
}
101-
102-
inline BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
103-
a = a * b;
104-
return a;
105-
}
106-
107-
inline BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
108-
a = a / b;
109-
return a;
110-
}
111-
112-
inline BFloat16& operator|(BFloat16& a, const BFloat16& b) {
113-
a.x = a.x | b.x;
114-
return a;
115-
}
116-
117-
inline BFloat16& operator^(BFloat16& a, const BFloat16& b) {
118-
a.x = a.x ^ b.x;
119-
return a;
120-
}
121-
122-
inline BFloat16& operator&(BFloat16& a, const BFloat16& b) {
123-
a.x = a.x & b.x;
124-
return a;
125-
}
126-
127-
/// Arithmetic with floats
128-
129-
inline float operator+(BFloat16 a, float b) {
130-
return static_cast<float>(a) + b;
131-
}
132-
inline float operator-(BFloat16 a, float b) {
133-
return static_cast<float>(a) - b;
134-
}
135-
inline float operator*(BFloat16 a, float b) {
136-
return static_cast<float>(a) * b;
137-
}
138-
inline float operator/(BFloat16 a, float b) {
139-
return static_cast<float>(a) / b;
140-
}
141-
142-
inline float operator+(float a, BFloat16 b) {
143-
return a + static_cast<float>(b);
144-
}
145-
inline float operator-(float a, BFloat16 b) {
146-
return a - static_cast<float>(b);
147-
}
148-
inline float operator*(float a, BFloat16 b) {
149-
return a * static_cast<float>(b);
150-
}
151-
inline float operator/(float a, BFloat16 b) {
152-
return a / static_cast<float>(b);
153-
}
154-
155-
inline float& operator+=(float& a, const BFloat16& b) {
156-
return a += static_cast<float>(b);
157-
}
158-
inline float& operator-=(float& a, const BFloat16& b) {
159-
return a -= static_cast<float>(b);
160-
}
161-
inline float& operator*=(float& a, const BFloat16& b) {
162-
return a *= static_cast<float>(b);
163-
}
164-
inline float& operator/=(float& a, const BFloat16& b) {
165-
return a /= static_cast<float>(b);
166-
}
167-
168-
/// Arithmetic with doubles
169-
170-
inline double operator+(BFloat16 a, double b) {
171-
return static_cast<double>(a) + b;
172-
}
173-
inline double operator-(BFloat16 a, double b) {
174-
return static_cast<double>(a) - b;
175-
}
176-
inline double operator*(BFloat16 a, double b) {
177-
return static_cast<double>(a) * b;
178-
}
179-
inline double operator/(BFloat16 a, double b) {
180-
return static_cast<double>(a) / b;
181-
}
182-
183-
inline double operator+(double a, BFloat16 b) {
184-
return a + static_cast<double>(b);
185-
}
186-
inline double operator-(double a, BFloat16 b) {
187-
return a - static_cast<double>(b);
188-
}
189-
inline double operator*(double a, BFloat16 b) {
190-
return a * static_cast<double>(b);
191-
}
192-
inline double operator/(double a, BFloat16 b) {
193-
return a / static_cast<double>(b);
194-
}
195-
196-
/// Arithmetic with ints
197-
198-
inline BFloat16 operator+(BFloat16 a, int b) {
199-
return a + static_cast<BFloat16>(b);
200-
}
201-
inline BFloat16 operator-(BFloat16 a, int b) {
202-
return a - static_cast<BFloat16>(b);
203-
}
204-
inline BFloat16 operator*(BFloat16 a, int b) {
205-
return a * static_cast<BFloat16>(b);
206-
}
207-
inline BFloat16 operator/(BFloat16 a, int b) {
208-
return a / static_cast<BFloat16>(b);
209-
}
210-
211-
inline BFloat16 operator+(int a, BFloat16 b) {
212-
return static_cast<BFloat16>(a) + b;
213-
}
214-
inline BFloat16 operator-(int a, BFloat16 b) {
215-
return static_cast<BFloat16>(a) - b;
216-
}
217-
inline BFloat16 operator*(int a, BFloat16 b) {
218-
return static_cast<BFloat16>(a) * b;
219-
}
220-
inline BFloat16 operator/(int a, BFloat16 b) {
221-
return static_cast<BFloat16>(a) / b;
222-
}
223-
224-
//// Arithmetic with int64_t
225-
226-
inline BFloat16 operator+(BFloat16 a, int64_t b) {
227-
return a + static_cast<BFloat16>(b);
228-
}
229-
inline BFloat16 operator-(BFloat16 a, int64_t b) {
230-
return a - static_cast<BFloat16>(b);
231-
}
232-
inline BFloat16 operator*(BFloat16 a, int64_t b) {
233-
return a * static_cast<BFloat16>(b);
234-
}
235-
inline BFloat16 operator/(BFloat16 a, int64_t b) {
236-
return a / static_cast<BFloat16>(b);
237-
}
238-
239-
inline BFloat16 operator+(int64_t a, BFloat16 b) {
240-
return static_cast<BFloat16>(a) + b;
241-
}
242-
inline BFloat16 operator-(int64_t a, BFloat16 b) {
243-
return static_cast<BFloat16>(a) - b;
244-
}
245-
inline BFloat16 operator*(int64_t a, BFloat16 b) {
246-
return static_cast<BFloat16>(a) * b;
247-
}
248-
inline BFloat16 operator/(int64_t a, BFloat16 b) {
249-
return static_cast<BFloat16>(a) / b;
250-
}
251-
252-
// Overloading < and > operators, because std::max and std::min use them.
253-
254-
inline bool operator>(BFloat16& lhs, BFloat16& rhs) {
255-
return float(lhs) > float(rhs);
256-
}
257-
258-
inline bool operator<(BFloat16& lhs, BFloat16& rhs) {
259-
return float(lhs) < float(rhs);
260-
}
261-
262-
} // namespace etensor
263-
} // namespace runtime
264-
} // namespace executorch
19+
} // namespace executorch::runtime::etensor
26520

26621
namespace torch {
26722
namespace executor {
@@ -270,74 +25,3 @@ namespace executor {
27025
using ::executorch::runtime::etensor::BFloat16;
27126
} // namespace executor
27227
} // namespace torch
273-
274-
namespace std {
275-
276-
template <>
277-
class numeric_limits<executorch::runtime::etensor::BFloat16> {
278-
public:
279-
static constexpr bool is_signed = true;
280-
static constexpr bool is_specialized = true;
281-
static constexpr bool is_integer = false;
282-
static constexpr bool is_exact = false;
283-
static constexpr bool has_infinity = true;
284-
static constexpr bool has_quiet_NaN = true;
285-
static constexpr bool has_signaling_NaN = true;
286-
static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
287-
static constexpr auto has_denorm_loss =
288-
numeric_limits<float>::has_denorm_loss;
289-
static constexpr auto round_style = numeric_limits<float>::round_style;
290-
static constexpr bool is_iec559 = false;
291-
static constexpr bool is_bounded = true;
292-
static constexpr bool is_modulo = false;
293-
static constexpr int digits = 8;
294-
static constexpr int digits10 = 2;
295-
static constexpr int max_digits10 = 4;
296-
static constexpr int radix = 2;
297-
static constexpr int min_exponent = -125;
298-
static constexpr int min_exponent10 = -37;
299-
static constexpr int max_exponent = 128;
300-
static constexpr int max_exponent10 = 38;
301-
static constexpr auto traps = numeric_limits<float>::traps;
302-
static constexpr auto tinyness_before =
303-
numeric_limits<float>::tinyness_before;
304-
305-
static constexpr torch::executor::BFloat16 min() {
306-
return torch::executor::BFloat16(
307-
0x0080, torch::executor::BFloat16::from_bits());
308-
}
309-
static constexpr torch::executor::BFloat16 lowest() {
310-
return torch::executor::BFloat16(
311-
0xFF7F, torch::executor::BFloat16::from_bits());
312-
}
313-
static constexpr torch::executor::BFloat16 max() {
314-
return torch::executor::BFloat16(
315-
0x7F7F, torch::executor::BFloat16::from_bits());
316-
}
317-
static constexpr torch::executor::BFloat16 epsilon() {
318-
return torch::executor::BFloat16(
319-
0x3C00, torch::executor::BFloat16::from_bits());
320-
}
321-
static constexpr torch::executor::BFloat16 round_error() {
322-
return torch::executor::BFloat16(
323-
0x3F00, torch::executor::BFloat16::from_bits());
324-
}
325-
static constexpr torch::executor::BFloat16 infinity() {
326-
return torch::executor::BFloat16(
327-
0x7F80, torch::executor::BFloat16::from_bits());
328-
}
329-
static constexpr torch::executor::BFloat16 quiet_NaN() {
330-
return torch::executor::BFloat16(
331-
0x7FC0, torch::executor::BFloat16::from_bits());
332-
}
333-
static constexpr torch::executor::BFloat16 signaling_NaN() {
334-
return torch::executor::BFloat16(
335-
0x7F80, torch::executor::BFloat16::from_bits());
336-
}
337-
static constexpr torch::executor::BFloat16 denorm_min() {
338-
return torch::executor::BFloat16(
339-
0x0001, torch::executor::BFloat16::from_bits());
340-
}
341-
};
342-
343-
} // namespace std
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets()

0 commit comments

Comments
 (0)