Commit 50b6ce9

[ExecuTorch] Add broadcast support for optimized add op
Summary: This brings the add op to feature parity with the mul op, with respect to broadcasting, in the optimized kernels lib.

Test Plan: tests added

ghstack-source-id: e4dea30
Pull Request resolved: #8205
1 parent: 87d2f86
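The heart of the change: handle_broadcast_elementwise in binary_ops.h now threads an optional alpha Scalar through to the element-wise functor, so the functor contract becomes ternary (both vectorized operands plus a splatted alpha vector). A condensed sketch of how the two ops now plug into the shared helper — illustrative, pulled together from the diffs below:

    // add supplies alpha and uses it: out = a + alpha * b
    auto add_lambda = [](auto x, auto y, auto alpha_val) {
      return x + alpha_val * y;
    };
    return torch::executor::handle_broadcast_elementwise(
        ctx, add_lambda, a, b, out, selected_optimized_path, alpha);

    // mul matches the ternary shape but ignores alpha
    auto mul_lambda = [](auto x, auto y, auto alpha) {
      (void)alpha;
      return x * y;
    };
    return torch::executor::handle_broadcast_elementwise(
        ctx, mul_lambda, a, b, out, selected_optimized_path);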

File tree: 4 files changed, +186 −49 lines

kernels/optimized/cpu/binary_ops.h

Lines changed: 43 additions & 13 deletions
@@ -9,6 +9,7 @@
 #pragma once
 
 #include <executorch/kernels/optimized/vec/functional.h>
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -198,7 +199,8 @@ Tensor& handle_last_dim_broadcast_elementwise(
     const Tensor& a,
     const Tensor& b,
     Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
+    const ElementwiseOptimizedPath selected_optimized_path,
+    executorch::aten::optional<Scalar>& alpha = {}) {
   ScalarType out_type = out.scalar_type();
   const Tensor* lhs;
   const Tensor* rhs;
@@ -220,8 +222,21 @@ Tensor& handle_last_dim_broadcast_elementwise(
   const size_t outer_size = getLeadingDims(out, out.dim() - 1);
   const auto broadcast_size = out.size(out.dim() - 1);
   ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
-        vec_fun,
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    CTYPE alpha_val;
+    Vec alpha_val_vec(alpha_val);
+    if (alpha.has_value()) {
+      ET_KERNEL_CHECK(
+          ctx,
+          native::utils::extract_scalar(alpha.value(), &alpha_val),
+          InvalidArgument, );
+      alpha_val_vec = Vec(alpha_val);
+    }
+    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
+      return vec_fun(a, b, alpha_val_vec);
+    };
+    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
+        vec_fun_alpha,
         out.mutable_data_ptr<CTYPE>(),
         lhs->const_data_ptr<CTYPE>(),
         rhs->const_data_ptr<CTYPE>(),
@@ -238,13 +253,14 @@ Tensor& handle_broadcast_elementwise(
     const Tensor& a,
     const Tensor& b,
     Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
+    const ElementwiseOptimizedPath selected_optimized_path,
+    executorch::aten::optional<Scalar> alpha = {}) {
   if ((selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcastLastDim) ||
       (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
     return handle_last_dim_broadcast_elementwise(
-        ctx, vec_fun, a, b, out, selected_optimized_path);
+        ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
   }
 
   ScalarType out_type = out.scalar_type();
@@ -291,14 +307,28 @@ Tensor& handle_broadcast_elementwise(
     inner_size = lhs->sizes()[lhs->dim() - 1];
   }
   ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
-        vec_fun,
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size,
-        inner_size);
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    CTYPE alpha_val;
+    Vec alpha_val_vec;
+    if (alpha.has_value()) {
+      ET_KERNEL_CHECK(
+          ctx,
+          native::utils::extract_scalar(alpha.value(), &alpha_val),
+          InvalidArgument, );
+      alpha_val_vec = Vec(alpha_val);
+    }
+    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
+      return vec_fun(a, b, alpha_val_vec);
+    };
+    executorch::vec::
+        broadcasting_map_3d_and_unsqueezed_3d<CTYPE, decltype(vec_fun_alpha)>(
+            vec_fun_alpha,
+            out.mutable_data_ptr<CTYPE>(),
+            lhs->const_data_ptr<CTYPE>(),
+            rhs->const_data_ptr<CTYPE>(),
+            outer_size,
+            broadcast_size,
+            inner_size);
   });
   return out;
 }
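Both broadcast paths above follow the same pattern: when alpha is present, extract the scalar once, splat it into a Vectorized<CTYPE> register outside the hot loop, and capture that register in a wrapper lambda so the underlying map routine still sees a plain binary functor:

    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
      return vec_fun(a, b, alpha_val_vec);  // alpha baked in, loop stays binary
    };

This way the map routines themselves (broadcasting_map_broadcast_last_dim, broadcasting_map_3d_and_unsqueezed_3d) need no alpha awareness; only the functor they are handed differs.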

kernels/optimized/cpu/op_add.cpp

Lines changed: 5 additions & 35 deletions
@@ -140,41 +140,11 @@ Tensor& opt_add_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-      lhs = &b;
-      rhs = &a;
-    } else {
-      // Catch failure to update logic when adding new broadcasting possibility.
-      ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
-      lhs = &a;
-      rhs = &b;
-    }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          lhs->const_data_ptr<CTYPE>(),
-          rhs->const_data_ptr<CTYPE>(),
-          lhs->sizes()[lhs->dim() - 2],
-          lhs->sizes()[lhs->dim() - 1]);
-    });
+    auto add_lambda = [](auto x, auto y, auto alpha_val) {
+      return x + alpha_val * y;
+    };
+    return torch::executor::handle_broadcast_elementwise(
+        ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
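For reference, add.out computes out = a + alpha * b elementwise. A quick worked example in the style of the new tests (shapes and values illustrative, not from the test file):

    // a: {2, 2} = {1, 2, 3, 4}, b: {1, 2} = {10, 20}, alpha = 2
    // b broadcasts across rows: out = {1 + 20, 2 + 40, 3 + 20, 4 + 40}
    //                               = {21, 42, 23, 44}
    op_add_out(a, b, 2.0, out);

The explicit resize_tensor call and the kBroadcast2dBy1d argument-swapping logic deleted here are now the shared helper's responsibility, as they already were for mul.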

kernels/optimized/cpu/op_mul.cpp

Lines changed: 5 additions & 1 deletion
@@ -130,7 +130,11 @@ Tensor& opt_mul_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    auto mul_lambda = [](auto x, auto y) { return x * y; };
+    // Reason for taking alpha: handle_broadcast_elementwise is shared with alpha-taking ops like add.
+    auto mul_lambda = [](auto x, auto y, auto alpha) {
+      (void)alpha;
+      return x * y;
+    };
     return torch::executor::handle_broadcast_elementwise(
         ctx, mul_lambda, a, b, out, selected_optimized_path);
   } else {
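Since mul has no alpha of its own, the extra lambda parameter exists purely to satisfy the shared ternary contract, and the `(void)alpha;` cast suppresses the unused-parameter warning. Note the call site still omits the optional argument, so `alpha.has_value()` is false inside the helper and the functor receives a vector it never reads.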

kernels/test/op_add_test.cpp

Lines changed: 133 additions & 0 deletions
@@ -112,6 +112,122 @@ class OpAddOutKernelTest : public OperatorTest {
     // tests.
     EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.5, 3.5, 5.75, 10.125}));
   }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_3D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
+
+    // Destination for output of add.
+    Tensor out =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor expected = tf_a.make(
+        {2, 2, 3}, /*data=*/{3, 5, 7, 6, 8, 10, 12, 14, 16, 15, 17, 19});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+  }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_4D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    Tensor b = tf_a.make(
+        {2, 1, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+    // Destination for output of add.
+    Tensor out = tf_a.zeros({2, 2, 3, 5});
+    Tensor expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{2,  4,  6,  8,  10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
+                  17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45,
+                  47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75,
+                  62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+
+    b = tf_a.make(
+        {2, 2, 1, 5}, /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10,
+                                11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
+    out = tf_a.zeros({2, 2, 3, 5});
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{2,  4,  6,  8,  10, 7,  9,  11, 13, 15, 12, 14, 16, 18, 20,
+                  22, 24, 26, 28, 30, 27, 29, 31, 33, 35, 32, 34, 36, 38, 40,
+                  42, 44, 46, 48, 50, 47, 49, 51, 53, 55, 52, 54, 56, 58, 60,
+                  62, 64, 66, 68, 70, 67, 69, 71, 73, 75, 72, 74, 76, 78, 80});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+  }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_last_dim() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({4, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({4, 1}, /*data=*/{2, 3, 4, 5});
+
+    // Destination for output of add.
+    Tensor out = tf_a.zeros({4, 3});
+    Tensor expected =
+        tf_a.make({4, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+
+    a = tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    b = tf_a.make({2, 2, 1}, /*data=*/{2, 3, 4, 5});
+
+    // Destination for output of add.
+    out = tf_a.zeros({2, 2, 3});
+    expected = tf_a.make(
+        {2, 2, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+
+    a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    b = tf_a.make(
+        {2, 2, 3, 1},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+
+    // Destination for output of add.
+    out = tf_a.zeros({2, 2, 3, 5});
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{2,  3,  4,  5,  6,  8,  9,  10, 11, 12, 14, 15, 16, 17, 18,
+                  20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36,
+                  38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54,
+                  56, 57, 58, 59, 60, 62, 63, 64, 65, 66, 68, 69, 70, 71, 72});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
+    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
+  }
 };
 
 class OpAddScalarOutKernelTest : public OperatorTest {
@@ -371,6 +487,23 @@ TEST_F(OpAddOutKernelTest, BroadcastOneElementRank0Tensor) {
   EXPECT_TENSOR_EQ(out, ret);
 }
 
+TEST_F(OpAddOutKernelTest, BroadcastNDTest) {
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Float>();
+  test_broadcast_3D<ScalarType::Half>();
+  test_broadcast_3D<ScalarType::BFloat16>();
+
+  // Test 4D tensors
+  test_broadcast_4D<ScalarType::Float>();
+  test_broadcast_4D<ScalarType::Half>();
+  test_broadcast_4D<ScalarType::BFloat16>();
+
+  // Test broadcasting on the last dimension
+  test_broadcast_last_dim<ScalarType::Float>();
+  test_broadcast_last_dim<ScalarType::Half>();
+  test_broadcast_last_dim<ScalarType::BFloat16>();
+}
+
 //
 // Death Tests
 //
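Each test helper asserts both argument orders, op_add_out(a, b, 1.0, out) and op_add_out(b, a, 1.0, out), against the same expected tensor: with alpha = 1.0 addition is commutative, and the swapped call exercises the ReverseArguments variants of the broadcast paths.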
