Skip to content

Commit fae9914

Browse files
committed
[ExecuTorch] Add broadcast support for optimized add op
Summary: This brings the add op to feature parity, with respect to broadcasting, with the mul op in the optimized kernels lib. Test Plan: tests added. Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 69d90ed Pull Request resolved: #8205
1 parent d729176 commit fae9914

File tree

4 files changed

+171
-42
lines changed

4 files changed

+171
-42
lines changed

kernels/optimized/cpu/binary_ops.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <executorch/kernels/optimized/vec/functional.h>
12+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1213
#include <executorch/runtime/kernel/kernel_includes.h>
1314

1415
namespace torch {
@@ -235,7 +236,8 @@ Tensor& handle_broadcast_elementwise(
235236
const Tensor& a,
236237
const Tensor& b,
237238
Tensor& out,
238-
const ElementwiseOptimizedPath selected_optimized_path) {
239+
const ElementwiseOptimizedPath selected_optimized_path,
240+
const executorch::aten::optional<Scalar>& alpha = {}) {
239241
if ((selected_optimized_path ==
240242
ElementwiseOptimizedPath::kBroadcastLastDim) ||
241243
(selected_optimized_path ==

kernels/optimized/cpu/op_add.cpp

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -140,40 +140,41 @@ Tensor& opt_add_out(
140140
out.numel());
141141
});
142142
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
143-
const Tensor* lhs;
144-
const Tensor* rhs;
145-
if (selected_optimized_path ==
146-
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
147-
lhs = &b;
148-
rhs = &a;
149-
} else {
150-
// Catch failure to update logic when adding new broadcasting possibility.
151-
ET_DCHECK(
152-
selected_optimized_path ==
153-
ElementwiseOptimizedPath::kBroadcast2dBy1d);
154-
lhs = &a;
155-
rhs = &b;
156-
}
157-
auto error = resize_tensor(out, lhs->sizes());
158-
ET_KERNEL_CHECK_MSG(
159-
ctx,
160-
error == Error::Ok,
161-
InvalidArgument,
162-
out,
163-
"Failed to resize output tensor.");
164143
ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
165144
CTYPE alpha_val;
166-
ET_KERNEL_CHECK(
167-
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
168-
145+
ET_KERNEL_CHECK_MSG(
146+
ctx,
147+
utils::extract_scalar(alpha, &alpha_val),
148+
InvalidArgument,
149+
out,
150+
"Failed to extract scalar alpha.");
169151
using Vec = executorch::vec::Vectorized<CTYPE>;
170-
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
171-
[alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
172-
out.mutable_data_ptr<CTYPE>(),
173-
lhs->const_data_ptr<CTYPE>(),
174-
rhs->const_data_ptr<CTYPE>(),
175-
lhs->sizes()[lhs->dim() - 2],
176-
lhs->sizes()[lhs->dim() - 1]);
152+
Vec alpha_val_vec(alpha_val);
153+
if (selected_optimized_path ==
154+
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
155+
selected_optimized_path ==
156+
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
157+
selected_optimized_path ==
158+
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
159+
// Reason we swap out args here is because handle_broadcast_elementwise
160+
// handles this selected_optimized_path option a bit differently.
161+
// This should really be resolved in handle_broadcast_elementwise.
162+
// However, the current blocker is that handle_broadcast_elementwise
163+
// tries to be agnostic of op. This should be fixed, likely by moving
164+
// lambda creation to handle_broadcast_elementwise and it be aware of
165+
// which op is being executed.
166+
auto add_lambda = [&alpha_val_vec](auto x, auto y) {
167+
return y + alpha_val_vec * x;
168+
};
169+
return torch::executor::handle_broadcast_elementwise<CTYPE>(
170+
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
171+
} else {
172+
auto add_lambda = [&alpha_val_vec](auto x, auto y) {
173+
return x + alpha_val_vec * y;
174+
};
175+
return torch::executor::handle_broadcast_elementwise<CTYPE>(
176+
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
177+
}
177178
});
178179
} else {
179180
ScalarType common_type =

kernels/test/op_add_test.cpp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,125 @@ class OpAddOutKernelTest : public OperatorTest {
112112
// tests.
113113
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.5, 3.5, 5.75, 10.125}));
114114
}
115+
116+
template <ScalarType DTYPE>
117+
void test_broadcast_3D() {
118+
TensorFactory<DTYPE> tf_a;
119+
120+
Tensor a =
121+
tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
122+
Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
123+
124+
// Destination for output of add.
125+
Tensor out =
126+
tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
127+
Tensor expected = tf_a.make(
128+
{2, 2, 3}, /*data=*/{3, 5, 7, 6, 8, 10, 12, 14, 16, 15, 17, 19});
129+
130+
// Check that it matches the expected output.
131+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
132+
expected = tf_a.make(
133+
{2, 2, 3},
134+
/*data=*/{3.5, 6, 8.5, 8, 10.5, 13, 15.5, 18, 20.5, 20, 22.5, 25});
135+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.5, out), expected);
136+
}
137+
138+
template <ScalarType DTYPE>
139+
void test_broadcast_4D() {
140+
TensorFactory<DTYPE> tf_a;
141+
142+
Tensor a = tf_a.make(
143+
{2, 2, 3, 5},
144+
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
145+
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
146+
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
147+
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
148+
Tensor b = tf_a.make(
149+
{2, 1, 3, 5},
150+
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
151+
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
152+
153+
// Destination for output of add.
154+
Tensor out = tf_a.zeros({2, 2, 3, 5});
155+
Tensor expected = tf_a.make(
156+
{2, 2, 3, 5},
157+
/*data=*/{2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
158+
17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45,
159+
47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75,
160+
62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90});
161+
162+
// Check that it matches the expected output.
163+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
164+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
165+
166+
b = tf_a.make(
167+
{2, 2, 1, 5}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
168+
11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
169+
out = tf_a.zeros({2, 2, 3, 5});
170+
expected = tf_a.make(
171+
{2, 2, 3, 5},
172+
/*data=*/{2, 4, 6, 8, 10, 7, 9, 11, 13, 15, 12, 14, 16, 18, 20,
173+
22, 24, 26, 28, 30, 27, 29, 31, 33, 35, 32, 34, 36, 38, 40,
174+
42, 44, 46, 48, 50, 47, 49, 51, 53, 55, 52, 54, 56, 58, 60,
175+
62, 64, 66, 68, 70, 67, 69, 71, 73, 75, 72, 74, 76, 78, 80});
176+
177+
// Check that it matches the expected output.
178+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
179+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
180+
}
181+
182+
template <ScalarType DTYPE>
183+
void test_broadcast_last_dim() {
184+
TensorFactory<DTYPE> tf_a;
185+
186+
Tensor a =
187+
tf_a.make({4, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
188+
Tensor b = tf_a.make({4, 1}, /*data=*/{2, 3, 4, 5});
189+
190+
// Destination for output of add.
191+
Tensor out = tf_a.zeros({4, 3});
192+
Tensor expected =
193+
tf_a.make({4, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17});
194+
195+
// Check that it matches the expected output.
196+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
197+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
198+
199+
a = tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
200+
b = tf_a.make({2, 2, 1}, /*data=*/{2, 3, 4, 5});
201+
202+
// Destination for output of add.
203+
out = tf_a.zeros({2, 2, 3});
204+
expected = tf_a.make(
205+
{2, 2, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17});
206+
207+
// Check that it matches the expected output.
208+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
209+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
210+
211+
a = tf_a.make(
212+
{2, 2, 3, 5},
213+
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
214+
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
215+
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
216+
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
217+
b = tf_a.make(
218+
{2, 2, 3, 1},
219+
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
220+
221+
// Destination for output of add.
222+
out = tf_a.zeros({2, 2, 3, 5});
223+
expected = tf_a.make(
224+
{2, 2, 3, 5},
225+
/*data=*/{2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18,
226+
20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36,
227+
38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54,
228+
56, 57, 58, 59, 60, 62, 63, 64, 65, 66, 68, 69, 70, 71, 72});
229+
230+
// Check that it matches the expected output.
231+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
232+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
233+
}
115234
};
116235

117236
class OpAddScalarOutKernelTest : public OperatorTest {
@@ -371,6 +490,23 @@ TEST_F(OpAddOutKernelTest, BroadcastOneElementRank0Tensor) {
371490
EXPECT_TENSOR_EQ(out, ret);
372491
}
373492

493+
TEST_F(OpAddOutKernelTest, BroadcastNDTest) {
494+
// Test 3D tensors
495+
test_broadcast_3D<ScalarType::Float>();
496+
test_broadcast_3D<ScalarType::Half>();
497+
test_broadcast_3D<ScalarType::BFloat16>();
498+
499+
// Test 4D tensors
500+
test_broadcast_4D<ScalarType::Float>();
501+
test_broadcast_4D<ScalarType::Half>();
502+
test_broadcast_4D<ScalarType::BFloat16>();
503+
504+
// Test broadcasting on the last dimension
505+
test_broadcast_last_dim<ScalarType::Float>();
506+
test_broadcast_last_dim<ScalarType::Half>();
507+
test_broadcast_last_dim<ScalarType::BFloat16>();
508+
}
509+
374510
//
375511
// Death Tests
376512
//

kernels/test/op_mul_test.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -417,16 +417,6 @@ TEST_F(OpMulOutTest, BroadcastA2BTest) {
417417
test_broadcast_a2b<ScalarType::Int>();
418418
test_broadcast_a2b<ScalarType::Half>();
419419
test_broadcast_a2b<ScalarType::BFloat16>();
420-
421-
// Test 3D tensors
422-
test_broadcast_3D<ScalarType::Float>();
423-
test_broadcast_3D<ScalarType::Half>();
424-
test_broadcast_3D<ScalarType::BFloat16>();
425-
426-
// Test 4D tensors
427-
test_broadcast_4D<ScalarType::Float>();
428-
test_broadcast_4D<ScalarType::Half>();
429-
test_broadcast_4D<ScalarType::BFloat16>();
430420
}
431421

432422
// Broadcast tensor a's size to tensor b's size

0 commit comments

Comments
 (0)