Skip to content

Commit 03a0168

Browse files
authored
[ExecuTorch] support BF16 in op_add
Differential Revision: D61981362
Pull Request resolved: #4983
1 parent 234f948 commit 03a0168

File tree

3 files changed: +40 −20 lines changed

kernels/optimized/cpu/op_add.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ Tensor& opt_add_out(
8383
ScalarType out_type = out.scalar_type();
8484

8585
if (b.numel() == 1) {
86-
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
86+
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
87+
a_type != ScalarType::BFloat16) {
8788
auto error = resize_tensor(out, a.sizes());
8889
ET_KERNEL_CHECK_MSG(
8990
ctx,
@@ -186,12 +187,12 @@ Tensor& opt_add_out(
186187
InvalidArgument,
187188
out);
188189

189-
ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
190-
ET_SWITCH_REALHB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
190+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
191+
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
191192
using CTYPE_IN = typename torch::executor::
192193
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
193194
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
194-
ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
195+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
195196
CTYPE_IN alpha_val;
196197
ET_KERNEL_CHECK(
197198
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
@@ -226,7 +227,7 @@ Tensor& opt_add_scalar_out(
226227

227228
ET_CHECK(common_type == out_type);
228229

229-
if (common_type == ScalarType::Half) {
230+
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
230231
common_type = ScalarType::Float;
231232
}
232233

@@ -235,7 +236,7 @@ Tensor& opt_add_scalar_out(
235236
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
236237

237238
if (a_type == common_type && a_type == out_type &&
238-
a_type != ScalarType::Half) {
239+
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
239240
ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() {
240241
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
241242
CTYPE_B b_val;
@@ -255,11 +256,11 @@ Tensor& opt_add_scalar_out(
255256
});
256257
});
257258
} else {
258-
ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
259+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
259260
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
260261
ET_SWITCH_REALB_TYPES(
261262
common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
262-
ET_SWITCH_REALHB_TYPES(
263+
ET_SWITCH_REALHBBF16_TYPES(
263264
out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
264265
CTYPE_B b_val;
265266
ET_EXTRACT_SCALAR(b, b_val);

kernels/portable/cpu/op_add.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,11 @@ Tensor& add_out(
7878
InvalidArgument,
7979
out);
8080

81-
ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
81+
ET_KERNEL_CHECK(
82+
ctx,
83+
executorch::runtime::tensor_is_realhbbf16_type(out),
84+
InvalidArgument,
85+
out);
8286
ET_KERNEL_CHECK(
8387
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
8488

@@ -94,15 +98,15 @@ Tensor& add_out(
9498

9599
constexpr auto name = "add.out";
96100

97-
ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
98-
ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
101+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
102+
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
99103
using CTYPE_IN = typename torch::executor::
100104
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
101105
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
102106
CTYPE_IN alpha_val;
103107
utils::extract_scalar(alpha, &alpha_val);
104108

105-
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
109+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
106110
AddInner<
107111
can_cast<CTYPE_IN, CTYPE_OUT>::value,
108112
CTYPE_A,
@@ -132,7 +136,11 @@ Tensor& add_scalar_out(
132136
out,
133137
"Failed to resize output tensor.");
134138

135-
ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
139+
ET_KERNEL_CHECK(
140+
ctx,
141+
executorch::runtime::tensor_is_realhbbf16_type(out),
142+
InvalidArgument,
143+
out);
136144
ET_KERNEL_CHECK(
137145
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
138146

@@ -153,7 +161,7 @@ Tensor& add_scalar_out(
153161

154162
constexpr auto name = "add.Scalar_out";
155163

156-
ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
164+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
157165
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
158166
using CTYPE_IN = typename utils::promote_type_with_scalar_type<
159167
CTYPE_A,

kernels/test/op_add_test.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class OpAddOutKernelTest : public OperatorTest {
5858

5959
template <ScalarType DTYPE_A, ScalarType DTYPE_B>
6060
void test_add_enumerate_out_types() {
61+
test_add<DTYPE_A, DTYPE_B, ScalarType::BFloat16>();
6162
test_add<DTYPE_A, DTYPE_B, ScalarType::Half>();
6263
test_add<DTYPE_A, DTYPE_B, ScalarType::Float>();
6364
test_add<DTYPE_A, DTYPE_B, ScalarType::Double>();
@@ -73,7 +74,7 @@ class OpAddOutKernelTest : public OperatorTest {
7374
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
7475
test_add_enumerate_out_types<DTYPE_A, ScalarType::dtype>();
7576

76-
ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
77+
ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
7778

7879
#undef ENUMERATE_TEST_ENTRY
7980
}
@@ -82,7 +83,7 @@ class OpAddOutKernelTest : public OperatorTest {
8283
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
8384
test_add_enumerate_b_types<ScalarType::dtype>();
8485

85-
ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
86+
ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
8687

8788
#undef ENUMERATE_TEST_ENTRY
8889
}
@@ -99,13 +100,15 @@ class OpAddOutKernelTest : public OperatorTest {
99100

100101
// Add two tensors.
101102
op_add_out(
102-
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
103+
tf.make(sizes, /*data=*/{1.25, 2.25, 4.5, 8.875}),
103104
tf.ones(sizes),
104-
/*alpha=*/1.1,
105+
/*alpha=*/1.25,
105106
out);
106107

107-
// Check that it matches the expected output.
108-
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.2, 3.3, 5.5, 9.9}));
108+
// Check that it matches the expected output. Values selected to
109+
// be exactly representable to avoid throwing off half/bfloat16
110+
// tests.
111+
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.5, 3.5, 5.75, 10.125}));
109112
}
110113
};
111114

@@ -136,6 +139,14 @@ TEST_F(OpAddOutKernelTest, DoubleTensors) {
136139
test_floating_point_add_out<ScalarType::Double>();
137140
}
138141

142+
TEST_F(OpAddOutKernelTest, HalfTensors) {
143+
test_floating_point_add_out<ScalarType::Half>();
144+
}
145+
146+
TEST_F(OpAddOutKernelTest, BFloat16Tensors) {
147+
test_floating_point_add_out<ScalarType::BFloat16>();
148+
}
149+
139150
TEST_F(OpAddOutKernelTest, BoolAndIntInputTensor) {
140151
TensorFactory<ScalarType::Bool> tf;
141152
TensorFactory<ScalarType::Int> tfi;

0 commit comments

Comments (0)