Skip to content

Commit fd6a590

Browse files
authored
[ExecuTorch] Handle rank 0 tensors correctly in optimized add/sub/div/mul
Differential Revision: D62310838. Pull Request resolved: #5140.
1 parent 2863536 commit fd6a590

File tree

8 files changed

+92
-20
lines changed

8 files changed

+92
-20
lines changed

kernels/optimized/cpu/op_add.cpp

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -85,13 +85,12 @@ Tensor& opt_add_out(
8585
if (b.numel() == 1) {
8686
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
8787
a_type != ScalarType::BFloat16) {
88-
auto error = resize_tensor(out, a.sizes());
89-
ET_KERNEL_CHECK_MSG(
88+
ET_KERNEL_CHECK(
9089
ctx,
91-
error == Error::Ok,
90+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
9291
InvalidArgument,
93-
out,
94-
"Failed to resize output tensor.");
92+
out);
93+
9594
ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
9695
ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
9796
CTYPE alpha_val;

kernels/optimized/cpu/op_div.cpp

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -66,13 +66,11 @@ Tensor& opt_div_out(
6666
scalar = &b;
6767
scalar_type = b_type;
6868
}
69-
auto error = resize_tensor(out, tensor->sizes());
70-
ET_KERNEL_CHECK_MSG(
69+
ET_KERNEL_CHECK(
7170
ctx,
72-
error == Error::Ok,
71+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
7372
InvalidArgument,
74-
out,
75-
"Failed to resize output tensor.");
73+
out);
7674
ET_SWITCH_REALB_TYPES(tensor_type, ctx, "div.out", CTYPE, [&]() {
7775
ET_SWITCH_REALB_TYPES(scalar_type, ctx, "div.out", CTYPE_SCALAR, [&]() {
7876
CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();

kernels/optimized/cpu/op_mul.cpp

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -82,13 +82,12 @@ Tensor& opt_mul_out(
8282
if (b.numel() == 1) {
8383
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
8484
a_type != ScalarType::BFloat16) {
85-
auto error = resize_tensor(out, a.sizes());
86-
ET_KERNEL_CHECK_MSG(
85+
ET_KERNEL_CHECK(
8786
ctx,
88-
error == Error::Ok,
87+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
8988
InvalidArgument,
90-
out,
91-
"Failed to resize output tensor.");
89+
out);
90+
9291
ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.out", CTYPE, [&]() {
9392
ET_SWITCH_REALB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
9493
CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();

kernels/optimized/cpu/op_sub.cpp

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -101,13 +101,11 @@ Tensor& opt_sub_out(
101101
scalar = &b;
102102
scalar_type = b_type;
103103
}
104-
auto error = resize_tensor(out, tensor->sizes());
105-
ET_KERNEL_CHECK_MSG(
104+
ET_KERNEL_CHECK(
106105
ctx,
107-
error == Error::Ok,
106+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
108107
InvalidArgument,
109-
out,
110-
"Failed to resize output tensor.");
108+
out);
111109
ET_SWITCH_REAL_TYPES(tensor_type, ctx, "sub.out", CTYPE, [&]() {
112110
ET_SWITCH_REAL_TYPES(scalar_type, ctx, "sub.out", CTYPE_SCALAR, [&]() {
113111
CTYPE alpha_val;

kernels/test/op_add_test.cpp

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -352,6 +352,23 @@ TEST_F(OpAddOutKernelTest, BroadcastOneElementTensorTypePromotion) {
352352
EXPECT_TENSOR_EQ(out, expected);
353353
}
354354

355+
TEST_F(OpAddOutKernelTest, BroadcastOneElementRank0Tensor) {
356+
TensorFactory<ScalarType::Float> tf;
357+
358+
Tensor a = tf.make({1}, {5});
359+
Tensor b = tf.make({}, {2});
360+
361+
Tensor out = tf.zeros({1});
362+
363+
op_add_out(a, b, 1, out);
364+
365+
Tensor ret = tf.make({1}, {7});
366+
EXPECT_TENSOR_EQ(out, ret);
367+
368+
op_add_out(b, a, 1, out);
369+
EXPECT_TENSOR_EQ(out, ret);
370+
}
371+
355372
//
356373
// Death Tests
357374
//

kernels/test/op_div_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -237,6 +237,25 @@ TEST_F(OpDivOutTest, BroadcastScalarSupported2) {
237237
EXPECT_TENSOR_EQ(out, ret);
238238
}
239239

240+
TEST_F(OpDivOutTest, BroadcastScalarRank0Supported) {
241+
TensorFactory<ScalarType::Float> tf;
242+
243+
Tensor a = tf.make({1}, {8});
244+
Tensor b = tf.make({}, {2});
245+
246+
Tensor out = tf.zeros({1});
247+
248+
op_div_out(a, b, out);
249+
250+
Tensor ret = tf.make({1}, {4});
251+
EXPECT_TENSOR_EQ(out, ret);
252+
253+
op_div_out(b, a, out);
254+
255+
ret = tf.make({1}, {0.25});
256+
EXPECT_TENSOR_EQ(out, ret);
257+
}
258+
240259
TEST_F(OpDivOutTest, BroadcastDimSizeIsOneAB) {
241260
TensorFactory<ScalarType::Float> tf;
242261

kernels/test/op_mul_test.cpp

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,23 @@ class OpMulOutTest : public OperatorTest {
182182
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
183183
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
184184
}
185+
186+
template <ScalarType DTYPE>
187+
void test_both_scalar_input_broadcast() {
188+
TensorFactory<DTYPE> tf_a;
189+
190+
// a is a rank-1 scalar and b is a rank-0 scalar
191+
Tensor a = tf_a.make({1}, /*data=*/{2});
192+
Tensor b = tf_a.make({}, /*data=*/{2});
193+
194+
// Destination for output of mul.
195+
Tensor out = tf_a.make({1}, /*data=*/{2});
196+
Tensor expected = tf_a.make({1}, /*data=*/{4});
197+
198+
// Check that it matches the expected output.
199+
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
200+
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
201+
}
185202
};
186203

187204
class OpMulScalarOutTest : public OperatorTest {
@@ -309,6 +326,12 @@ TEST_F(OpMulOutTest, ScalarInputBroadcastTest) {
309326
test_scalar_input_broadcast<ScalarType::BFloat16>();
310327
}
311328

329+
TEST_F(OpMulOutTest, BothScalarInputBroadcastTest) {
330+
test_both_scalar_input_broadcast<ScalarType::Int>();
331+
test_both_scalar_input_broadcast<ScalarType::Half>();
332+
test_both_scalar_input_broadcast<ScalarType::BFloat16>();
333+
}
334+
312335
TEST_F(OpMulOutTest, MismatchedOutputShapesDies) {
313336
if (SupportedFeatures::get()->is_aten) {
314337
GTEST_SKIP() << "ATen currently supports mismatched shapes";

kernels/test/op_sub_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -206,6 +206,25 @@ TEST_F(OpSubOutTest, BroadcastScalarSupported2) {
206206
EXPECT_TENSOR_EQ(out, ret);
207207
}
208208

209+
TEST_F(OpSubOutTest, BroadcastScalarRank0Supported) {
210+
TensorFactory<ScalarType::Float> tf;
211+
212+
Tensor a = tf.make({1}, {5});
213+
Tensor b = tf.make({}, {2});
214+
215+
Tensor out = tf.zeros({1});
216+
217+
op_sub_out(a, b, 1, out);
218+
219+
Tensor ret = tf.make({1}, {3});
220+
EXPECT_TENSOR_EQ(out, ret);
221+
222+
op_sub_out(b, a, 1, out);
223+
224+
ret = tf.make({1}, {-3});
225+
EXPECT_TENSOR_EQ(out, ret);
226+
}
227+
209228
//
210229
// Death Tests
211230
//

0 commit comments

Comments (0)