Skip to content

Commit 10288a2

Browse files
authored
[ExecuTorch] support BF16 in op_mul
Differential Revision: D61981355
Pull Request resolved: #4977
1 parent c9ac212 commit 10288a2

File tree

6 files changed, with 130 additions and 77 deletions.

kernels/optimized/cpu/binary_ops.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,8 @@ ElementwiseOptimizedPath inline select_optimized_path(
7575
ScalarType b_type = b.scalar_type();
7676
ScalarType out_type = out.scalar_type();
7777

78-
if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half) {
78+
if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half ||
79+
a_type == ScalarType::BFloat16) {
7980
return ElementwiseOptimizedPath::kNone;
8081
}
8182
if (a.sizes().equals(b.sizes()) ||

kernels/optimized/cpu/op_mul.cpp

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,8 @@ Tensor& opt_mul_out(
8080
ScalarType out_type = out.scalar_type();
8181

8282
if (b.numel() == 1) {
83-
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
83+
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
84+
a_type != ScalarType::BFloat16) {
8485
auto error = resize_tensor(out, a.sizes());
8586
ET_KERNEL_CHECK_MSG(
8687
ctx,
@@ -170,12 +171,12 @@ Tensor& opt_mul_out(
170171
InvalidArgument,
171172
out);
172173

173-
ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
174-
ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
174+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
175+
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
175176
using CTYPE_IN = typename torch::executor::
176177
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
177178
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
178-
ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
179+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
179180
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
180181
[](const CTYPE_A val_a, const CTYPE_B val_b) {
181182
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
@@ -210,7 +211,7 @@ Tensor& opt_mul_scalar_out(
210211

211212
ET_CHECK(common_type == out_type);
212213

213-
if (common_type == ScalarType::Half) {
214+
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
214215
common_type = ScalarType::Float;
215216
}
216217

@@ -219,7 +220,7 @@ Tensor& opt_mul_scalar_out(
219220
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
220221

221222
if (a_type == common_type && a_type == out_type &&
222-
a_type != ScalarType::Half) {
223+
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
223224
ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
224225
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
225226
CTYPE_B b_val;
@@ -235,11 +236,11 @@ Tensor& opt_mul_scalar_out(
235236
});
236237
});
237238
} else {
238-
ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
239+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
239240
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
240241
ET_SWITCH_REALB_TYPES(
241242
common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
242-
ET_SWITCH_REALHB_TYPES(
243+
ET_SWITCH_REALHBBF16_TYPES(
243244
out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
244245
CTYPE_B b_val;
245246
ET_EXTRACT_SCALAR(b, b_val);

kernels/portable/cpu/op_mul.cpp

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,11 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
7070
InvalidArgument,
7171
out);
7272

73-
ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
73+
ET_KERNEL_CHECK(
74+
ctx,
75+
executorch::runtime::tensor_is_realhbbf16_type(out),
76+
InvalidArgument,
77+
out);
7478

7579
ET_KERNEL_CHECK(
7680
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
@@ -82,12 +86,12 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
8286

8387
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
8488

85-
ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
86-
ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
89+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
90+
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
8791
using CTYPE_IN = typename torch::executor::
8892
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
8993
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
90-
ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
94+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
9195
MulInner<
9296
can_cast<CTYPE_IN, CTYPE_OUT>::value,
9397
CTYPE_A,
@@ -129,15 +133,15 @@ Tensor& mul_scalar_out(
129133

130134
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
131135

132-
if (common_type == ScalarType::Half) {
136+
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
133137
common_type = ScalarType::Float;
134138
}
135139

136-
ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
140+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
137141
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
138142
ET_SWITCH_REALB_TYPES(
139143
common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
140-
ET_SWITCH_REALHB_TYPES(
144+
ET_SWITCH_REALHBBF16_TYPES(
141145
out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
142146
CTYPE_B b_val;
143147
utils::extract_scalar(b, &b_val);

kernels/test/op_mul_test.cpp

Lines changed: 98 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ class OpMulOutTest : public OperatorTest {
7272
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
7373
test_mul_enumerate_out_types<DTYPE_A, ScalarType::dtype>();
7474

75-
ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
75+
ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
7676

7777
#undef ENUMERATE_TEST_ENTRY
7878
}
@@ -89,29 +89,99 @@ class OpMulOutTest : public OperatorTest {
8989

9090
// Multiply two tensors
9191
op_mul_out(
92-
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.ones(sizes), out);
93-
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}));
92+
tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}), tf.ones(sizes), out);
93+
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}));
9494

9595
op_mul_out(
9696
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.zeros(sizes), out);
9797
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{0.0, 0.0, 0.0, 0.0}));
9898

9999
op_mul_out(
100-
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
101-
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
100+
tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}),
101+
tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}),
102102
out);
103103
EXPECT_TENSOR_CLOSE(
104-
out, tf.make(sizes, /*data=*/{1.21, 4.84, 19.36, 77.44}));
104+
out, tf.make(sizes, /*data=*/{1.5625, 6.25, 22.5625, 78.765625}));
105105
}
106106

107107
void test_mul_enumerate_a_types() {
108108
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
109109
test_mul_enumerate_b_types<ScalarType::dtype>();
110110

111-
ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
111+
ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
112112

113113
#undef ENUMERATE_TEST_ENTRY
114114
}
115+
116+
template <ScalarType DTYPE>
117+
void test_optimized_path_ignores_leading_1_dimensions() {
118+
TensorFactory<DTYPE> tf;
119+
120+
const std::vector<int32_t> sizes1 = {1, 1, 2, 2};
121+
const std::vector<int32_t> sizes2 = {1, 2, 2};
122+
123+
// Destination for the mul.
124+
Tensor out = tf.zeros(sizes1);
125+
126+
// Multiply two tensors
127+
op_mul_out(
128+
tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.ones(sizes2), out);
129+
EXPECT_TENSOR_CLOSE(out, tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}));
130+
}
131+
132+
template <ScalarType DTYPE>
133+
void test_broadcast_a2b() {
134+
TensorFactory<DTYPE> tf_a;
135+
136+
std::vector<std::vector<int32_t>> b_sizeses = {
137+
{2},
138+
{1, 2},
139+
};
140+
for (const auto& b_sizes : b_sizeses) {
141+
// a and b of different shapes
142+
Tensor a = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});
143+
Tensor b = tf_a.make(b_sizes, /*data=*/{2, 2});
144+
145+
// Destination for output of mul.
146+
Tensor out = tf_a.zeros({2, 2});
147+
148+
// Check that it matches the expected output.
149+
EXPECT_TENSOR_CLOSE(
150+
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
151+
}
152+
}
153+
154+
template <ScalarType DTYPE>
155+
void test_broadcast_b2a() {
156+
TensorFactory<DTYPE> tf_a;
157+
// a and b of different shapes
158+
Tensor a = tf_a.make({2}, /*data=*/{2, 2});
159+
Tensor b = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});
160+
161+
// Destination for output of mul.
162+
Tensor out = tf_a.zeros({2, 2});
163+
164+
// Check that it matches the expected output.
165+
EXPECT_TENSOR_CLOSE(
166+
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
167+
}
168+
169+
template <ScalarType DTYPE>
170+
void test_scalar_input_broadcast() {
171+
TensorFactory<DTYPE> tf_a;
172+
173+
// a is a 1d tensor and b is a scalar
174+
Tensor a = tf_a.make({2}, /*data=*/{2, 2});
175+
Tensor b = tf_a.make({}, /*data=*/{2});
176+
177+
// Destination for output of mul.
178+
Tensor out = tf_a.make({2}, /*data=*/{2, 2});
179+
Tensor expected = tf_a.make({2}, /*data=*/{4, 4});
180+
181+
// Check that it matches the expected output.
182+
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
183+
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
184+
}
115185
};
116186

117187
class OpMulScalarOutTest : public OperatorTest {
@@ -141,6 +211,14 @@ TEST_F(OpMulOutTest, DoubleTensors) {
141211
test_floating_point_mul_out<ScalarType::Double>();
142212
}
143213

214+
TEST_F(OpMulOutTest, HalfTensors) {
215+
test_floating_point_mul_out<ScalarType::Half>();
216+
}
217+
218+
TEST_F(OpMulOutTest, BFloat16Tensors) {
219+
test_floating_point_mul_out<ScalarType::BFloat16>();
220+
}
221+
144222
TEST_F(OpMulOutTest, BoolTensors) {
145223
TensorFactory<ScalarType::Bool> tf;
146224

@@ -166,18 +244,12 @@ TEST_F(OpMulOutTest, BoolTensors) {
166244
}
167245

168246
TEST_F(OpMulOutTest, OptimizedPathIgnoresLeading1Dimensions) {
169-
TensorFactory<ScalarType::Float> tf;
247+
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
248+
test_optimized_path_ignores_leading_1_dimensions<ScalarType::dtype>();
170249

171-
const std::vector<int32_t> sizes1 = {1, 1, 2, 2};
172-
const std::vector<int32_t> sizes2 = {1, 2, 2};
250+
ET_FORALL_FLOATHBF16_TYPES(ENUMERATE_TEST_ENTRY);
173251

174-
// Destination for the mul.
175-
Tensor out = tf.zeros(sizes1);
176-
177-
// Multiply two tensors
178-
op_mul_out(
179-
tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.ones(sizes2), out);
180-
EXPECT_TENSOR_CLOSE(out, tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}));
252+
#undef ENUMERATE_TEST_ENTRY
181253
}
182254

183255
// Mismatched shape tests.
@@ -202,40 +274,16 @@ TEST_F(OpMulOutTest, MismatchedNonBroadcastableInputShapesDies) {
202274

203275
// Broadcast tensor b's size to tensor a's size
204276
TEST_F(OpMulOutTest, BroadcastA2BTest) {
205-
TensorFactory<ScalarType::Int> tf_a;
206-
207-
std::vector<std::vector<int32_t>> b_sizeses = {
208-
{2},
209-
{1, 2},
210-
};
211-
for (const auto& b_sizes : b_sizeses) {
212-
// a and b of different shapes
213-
Tensor a = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});
214-
Tensor b = tf_a.make(b_sizes, /*data=*/{2, 2});
215-
216-
// Destination for output of mul.
217-
Tensor out = tf_a.zeros({2, 2});
218-
219-
// Check that it matches the expected output.
220-
EXPECT_TENSOR_CLOSE(
221-
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
222-
}
277+
test_broadcast_a2b<ScalarType::Int>();
278+
test_broadcast_a2b<ScalarType::Half>();
279+
test_broadcast_a2b<ScalarType::BFloat16>();
223280
}
224281

225282
// Broadcast tensor a's size to tensor b's size
226283
TEST_F(OpMulOutTest, BroadcastB2ATest) {
227-
TensorFactory<ScalarType::Int> tf_a;
228-
229-
// a and b of different shapes
230-
Tensor a = tf_a.make({2}, /*data=*/{2, 2});
231-
Tensor b = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4});
232-
233-
// Destination for output of mul.
234-
Tensor out = tf_a.zeros({2, 2});
235-
236-
// Check that it matches the expected output.
237-
EXPECT_TENSOR_CLOSE(
238-
op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8}));
284+
test_broadcast_b2a<ScalarType::Int>();
285+
test_broadcast_b2a<ScalarType::Half>();
286+
test_broadcast_b2a<ScalarType::BFloat16>();
239287
}
240288

241289
// Broadcast tensor a and b's size to a new size c.
@@ -256,19 +304,9 @@ TEST_F(OpMulOutTest, BroadcastAB2CTest) {
256304
}
257305

258306
TEST_F(OpMulOutTest, ScalarInputBroadcastTest) {
259-
TensorFactory<ScalarType::Int> tf_a;
260-
261-
// a is a 1d tensor and b is a scalar
262-
Tensor a = tf_a.make({2}, /*data=*/{2, 2});
263-
Tensor b = tf_a.make({}, /*data=*/{2});
264-
265-
// Destination for output of mul.
266-
Tensor out = tf_a.make({2}, /*data=*/{2, 2});
267-
Tensor expected = tf_a.make({2}, /*data=*/{4, 4});
268-
269-
// Check that it matches the expected output.
270-
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
271-
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
307+
test_scalar_input_broadcast<ScalarType::Int>();
308+
test_scalar_input_broadcast<ScalarType::Half>();
309+
test_scalar_input_broadcast<ScalarType::BFloat16>();
272310
}
273311

274312
TEST_F(OpMulOutTest, MismatchedOutputShapesDies) {

runtime/core/exec_aten/testing_util/tensor_util.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -283,7 +283,7 @@ std::ostream& operator<<(std::ostream& os, const Tensor& t) {
283283
break;
284284

285285
switch (t.scalar_type()) {
286-
ET_FORALL_REAL_TYPES_AND2(Half, Bool, PRINT_CASE)
286+
ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, PRINT_CASE)
287287
default:
288288
ET_CHECK_MSG(
289289
false,

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -516,6 +516,15 @@ inline bool tensor_is_realhb_type(exec_aten::Tensor t) {
516516
return true;
517517
}
518518

519+
inline bool tensor_is_realhbbf16_type(exec_aten::Tensor t) {
520+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
521+
executorch::runtime::isRealHBBF16Type(t.scalar_type()),
522+
"Expected to find a real type, but tensor has type %s",
523+
torch::executor::toString(t.scalar_type()));
524+
525+
return true;
526+
}
527+
519528
inline bool tensor_is_complex_type(exec_aten::Tensor t) {
520529
ET_LOG_MSG_AND_RETURN_IF_FALSE(
521530
torch::executor::isComplexType(t.scalar_type()),

0 commit comments

Comments
 (0)