Skip to content

Commit b8a1899

Browse files
authored
[ExecuTorch] Implement BFloat16 and hook it up to scalar_type_util
Differential Revision: D61981361 Pull Request resolved: #4975
1 parent 6ccb290 commit b8a1899

File tree

8 files changed

+744
-87
lines changed

8 files changed

+744
-87
lines changed

kernels/portable/cpu/scalar_utils.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,6 @@ struct promote_type_with_scalar_type {
9494
static_assert(
9595
!is_bits_type<T1>::value,
9696
"promote_type_with_scalar_type not valid for bits dtypes");
97-
static_assert(
98-
!std::is_same<
99-
T1,
100-
typename ScalarTypeToCppType<exec_aten::ScalarType::BFloat16>::type>::
101-
value,
102-
"promote_type_with_scalar_type not valid for BFloat16");
10397
using promote_type_with_scalar_type_not_respecting_half_to_float =
10498
typename std::conditional<
10599
is_complex_type<T1>::value ||
@@ -119,10 +113,14 @@ struct promote_type_with_scalar_type {
119113
public:
120114
using type = typename std::conditional<
121115
half_to_float &&
122-
std::is_same<
123-
promote_type_with_scalar_type_not_respecting_half_to_float,
124-
typename ScalarTypeToCppType<exec_aten::ScalarType::Half>::type>::
125-
value,
116+
(std::is_same<
117+
promote_type_with_scalar_type_not_respecting_half_to_float,
118+
typename ScalarTypeToCppType<
119+
exec_aten::ScalarType::Half>::type>::value ||
120+
std::is_same<
121+
promote_type_with_scalar_type_not_respecting_half_to_float,
122+
typename ScalarTypeToCppType<
123+
exec_aten::ScalarType::BFloat16>::type>::value),
126124
typename ScalarTypeToCppType<exec_aten::ScalarType::Float>::type,
127125
promote_type_with_scalar_type_not_respecting_half_to_float>::type;
128126
};

runtime/core/exec_aten/util/genScalarTypeTable.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,35 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
indexToType = ["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1"]
7+
indexToType = [
8+
"U1",
9+
"I1",
10+
"I2",
11+
"I4",
12+
"I8",
13+
"F2",
14+
"F4",
15+
"F8",
16+
"C2",
17+
"C4",
18+
"C8",
19+
"B1",
20+
"BF",
21+
]
822
promoteTypesLookup = [
9-
["U1", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "U1"],
10-
["I2", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I1"],
11-
["I2", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I2"],
12-
["I4", "I4", "I4", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I4"],
13-
["I8", "I8", "I8", "I8", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I8"],
14-
["F2", "F2", "F2", "F2", "F2", "F2", "F4", "F8", "C2", "C4", "C8", "F2"],
15-
["F4", "F4", "F4", "F4", "F4", "F4", "F4", "F8", "C4", "C4", "C8", "F4"],
16-
["F8", "F8", "F8", "F8", "F8", "F8", "F8", "F8", "C8", "C8", "C8", "F8"],
17-
["C2", "C2", "C2", "C2", "C2", "C2", "C4", "C8", "C2", "C4", "C8", "C2"],
18-
["C4", "C4", "C4", "C4", "C4", "C4", "C4", "C8", "C4", "C4", "C8", "C4"],
19-
["C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8"],
20-
["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1"],
23+
["U1", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "U1", "BF"],
24+
["I2", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I1", "BF"],
25+
["I2", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I2", "BF"],
26+
["I4", "I4", "I4", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I4", "BF"],
27+
["I8", "I8", "I8", "I8", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I8", "BF"],
28+
["F2", "F2", "F2", "F2", "F2", "F2", "F4", "F8", "C2", "C4", "C8", "F2", "F4"],
29+
["F4", "F4", "F4", "F4", "F4", "F4", "F4", "F8", "C4", "C4", "C8", "F4", "F4"],
30+
["F8", "F8", "F8", "F8", "F8", "F8", "F8", "F8", "C8", "C8", "C8", "F8", "F8"],
31+
["C2", "C2", "C2", "C2", "C2", "C2", "C4", "C8", "C2", "C4", "C8", "C2", "C4"],
32+
["C4", "C4", "C4", "C4", "C4", "C4", "C4", "C8", "C4", "C4", "C8", "C4", "C4"],
33+
["C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8"],
34+
["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1", "BF"],
35+
["BF", "BF", "BF", "BF", "BF", "F4", "F4", "F8", "C4", "C4", "C8", "BF", "BF"],
2136
]
2237
for rowIndex, row in enumerate(promoteTypesLookup):
2338
for colIndex, col in enumerate(row):

0 commit comments

Comments
 (0)