Skip to content

Commit a9d5779

Browse files
committed
[ExecuTorch] Add broadcast support for optimized add op
Summary: This brings the add op to feature parity, with respect to broadcasting, with the mul op in the optimized kernels lib. Test Plan: tests added. Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: c544136 Pull Request resolved: #8205
1 parent 14cc773 commit a9d5779

File tree

5 files changed

+215
-62
lines changed

5 files changed

+215
-62
lines changed

kernels/optimized/cpu/binary_ops.h

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <executorch/kernels/optimized/vec/functional.h>
12+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1213
#include <executorch/runtime/kernel/kernel_includes.h>
1314

1415
namespace torch {
@@ -191,14 +192,15 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
191192
return normalized_tensor_size;
192193
}
193194

194-
template <typename Op>
195+
template <const char* op_name, typename Op>
195196
Tensor& handle_last_dim_broadcast_elementwise(
196197
KernelRuntimeContext& ctx,
197198
const Op& vec_fun,
198199
const Tensor& a,
199200
const Tensor& b,
200201
Tensor& out,
201-
const ElementwiseOptimizedPath selected_optimized_path) {
202+
const ElementwiseOptimizedPath selected_optimized_path,
203+
const executorch::aten::optional<Scalar>& alpha = {}) {
202204
ScalarType out_type = out.scalar_type();
203205
const Tensor* lhs;
204206
const Tensor* rhs;
@@ -219,9 +221,22 @@ Tensor& handle_last_dim_broadcast_elementwise(
219221
"Failed to resize output tensor.");
220222
const size_t outer_size = getLeadingDims(out, out.dim() - 1);
221223
const auto broadcast_size = out.size(out.dim() - 1);
222-
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
223-
executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
224-
vec_fun,
224+
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
225+
using Vec = executorch::vec::Vectorized<CTYPE>;
226+
Vec alpha_val_vec;
227+
if (alpha.has_value()) {
228+
CTYPE alpha_val;
229+
ET_KERNEL_CHECK(
230+
ctx,
231+
native::utils::extract_scalar(alpha.value(), &alpha_val),
232+
InvalidArgument, );
233+
alpha_val_vec = Vec(alpha_val);
234+
}
235+
auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
236+
return vec_fun(a, b, alpha_val_vec);
237+
};
238+
executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
239+
vec_fun_alpha,
225240
out.mutable_data_ptr<CTYPE>(),
226241
lhs->const_data_ptr<CTYPE>(),
227242
rhs->const_data_ptr<CTYPE>(),
@@ -231,20 +246,21 @@ Tensor& handle_last_dim_broadcast_elementwise(
231246
return out;
232247
}
233248

234-
template <typename Op>
249+
template <const char* op_name, typename Op>
235250
Tensor& handle_broadcast_elementwise(
236251
KernelRuntimeContext& ctx,
237252
const Op& vec_fun,
238253
const Tensor& a,
239254
const Tensor& b,
240255
Tensor& out,
241-
const ElementwiseOptimizedPath selected_optimized_path) {
256+
const ElementwiseOptimizedPath selected_optimized_path,
257+
const executorch::aten::optional<Scalar>& alpha = {}) {
242258
if ((selected_optimized_path ==
243259
ElementwiseOptimizedPath::kBroadcastLastDim) ||
244260
(selected_optimized_path ==
245261
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
246-
return handle_last_dim_broadcast_elementwise(
247-
ctx, vec_fun, a, b, out, selected_optimized_path);
262+
return handle_last_dim_broadcast_elementwise<op_name>(
263+
ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
248264
}
249265

250266
ScalarType out_type = out.scalar_type();
@@ -290,15 +306,29 @@ Tensor& handle_broadcast_elementwise(
290306
broadcast_size = lhs->sizes()[lhs->dim() - 2];
291307
inner_size = lhs->sizes()[lhs->dim() - 1];
292308
}
293-
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
294-
executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
295-
vec_fun,
296-
out.mutable_data_ptr<CTYPE>(),
297-
lhs->const_data_ptr<CTYPE>(),
298-
rhs->const_data_ptr<CTYPE>(),
299-
outer_size,
300-
broadcast_size,
301-
inner_size);
309+
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
310+
using Vec = executorch::vec::Vectorized<CTYPE>;
311+
Vec alpha_val_vec;
312+
if (alpha.has_value()) {
313+
CTYPE alpha_val;
314+
ET_KERNEL_CHECK(
315+
ctx,
316+
native::utils::extract_scalar(alpha.value(), &alpha_val),
317+
InvalidArgument, );
318+
alpha_val_vec = Vec(alpha_val);
319+
}
320+
auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
321+
return vec_fun(a, b, alpha_val_vec);
322+
};
323+
executorch::vec::
324+
broadcasting_map_3d_and_unsqueezed_3d<CTYPE, decltype(vec_fun_alpha)>(
325+
vec_fun_alpha,
326+
out.mutable_data_ptr<CTYPE>(),
327+
lhs->const_data_ptr<CTYPE>(),
328+
rhs->const_data_ptr<CTYPE>(),
329+
outer_size,
330+
broadcast_size,
331+
inner_size);
302332
});
303333
return out;
304334
}

kernels/optimized/cpu/op_add.cpp

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -140,41 +140,32 @@ Tensor& opt_add_out(
140140
out.numel());
141141
});
142142
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
143-
const Tensor* lhs;
144-
const Tensor* rhs;
143+
static constexpr const char op_name[] = "add.out";
145144
if (selected_optimized_path ==
146-
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
147-
lhs = &b;
148-
rhs = &a;
145+
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
146+
selected_optimized_path ==
147+
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
148+
selected_optimized_path ==
149+
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
150+
// Reason we swap out args here is because handle_broadcast_elementwise
151+
// handles this selected_optimized_path option a bit differently.
152+
// This should really be resolved in handle_broadcast_elementwise.
153+
// However, the current blocker is that handle_broadcast_elementwise tries
154+
// to be agnostic of op. This should be fixed, likely by moving lambda
155+
// creation to handle_broadcast_elementwise and making it aware of which op is
156+
// being executed.
157+
auto add_lambda = [](auto x, auto y, auto alpha_val) {
158+
return y + alpha_val * x;
159+
};
160+
return torch::executor::handle_broadcast_elementwise<op_name>(
161+
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
149162
} else {
150-
// Catch failure to update logic when adding new broadcasting possibility.
151-
ET_DCHECK(
152-
selected_optimized_path ==
153-
ElementwiseOptimizedPath::kBroadcast2dBy1d);
154-
lhs = &a;
155-
rhs = &b;
163+
auto add_lambda = [](auto x, auto y, auto alpha_val) {
164+
return x + alpha_val * y;
165+
};
166+
return torch::executor::handle_broadcast_elementwise<op_name>(
167+
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
156168
}
157-
auto error = resize_tensor(out, lhs->sizes());
158-
ET_KERNEL_CHECK_MSG(
159-
ctx,
160-
error == Error::Ok,
161-
InvalidArgument,
162-
out,
163-
"Failed to resize output tensor.");
164-
ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
165-
CTYPE alpha_val;
166-
ET_KERNEL_CHECK(
167-
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
168-
169-
using Vec = executorch::vec::Vectorized<CTYPE>;
170-
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
171-
[alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
172-
out.mutable_data_ptr<CTYPE>(),
173-
lhs->const_data_ptr<CTYPE>(),
174-
rhs->const_data_ptr<CTYPE>(),
175-
lhs->sizes()[lhs->dim() - 2],
176-
lhs->sizes()[lhs->dim() - 1]);
177-
});
178169
} else {
179170
ScalarType common_type =
180171
promoteTypes(a_type, b_type, /*half_to_float*/ true);

kernels/optimized/cpu/op_mul.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,14 @@ Tensor& opt_mul_out(
130130
out.numel());
131131
});
132132
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
133-
auto mul_lambda = [](auto x, auto y) { return x * y; };
134-
return torch::executor::handle_broadcast_elementwise(
133+
// Reason for using alpha even when used for mul is because
134+
// handle_broadcast_elementwise is used for add and sub as well
135+
// and it uses alpha.
136+
auto mul_lambda = [](auto x, auto y, [[maybe_unused]] auto alpha) {
137+
return x * y;
138+
};
139+
static constexpr const char op_name[] = "mul.out";
140+
return torch::executor::handle_broadcast_elementwise<op_name>(
135141
ctx, mul_lambda, a, b, out, selected_optimized_path);
136142
} else {
137143
ScalarType common_type =

kernels/test/op_add_test.cpp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,125 @@ class OpAddOutKernelTest : public OperatorTest {
112112
// tests.
113113
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.5, 3.5, 5.75, 10.125}));
114114
}
115+
116+
template <ScalarType DTYPE>
117+
void test_broadcast_3D() {
118+
TensorFactory<DTYPE> tf_a;
119+
120+
Tensor a =
121+
tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
122+
Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
123+
124+
// Destination for output of add.
125+
Tensor out =
126+
tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
127+
Tensor expected = tf_a.make(
128+
{2, 2, 3}, /*data=*/{3, 5, 7, 6, 8, 10, 12, 14, 16, 15, 17, 19});
129+
130+
// Check that it matches the expected output.
131+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
132+
expected = tf_a.make(
133+
{2, 2, 3},
134+
/*data=*/{3.5, 6, 8.5, 8, 10.5, 13, 15.5, 18, 20.5, 20, 22.5, 25});
135+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.5, out), expected);
136+
}
137+
138+
template <ScalarType DTYPE>
139+
void test_broadcast_4D() {
140+
TensorFactory<DTYPE> tf_a;
141+
142+
Tensor a = tf_a.make(
143+
{2, 2, 3, 5},
144+
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
145+
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
146+
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
147+
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
148+
Tensor b = tf_a.make(
149+
{2, 1, 3, 5},
150+
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
151+
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
152+
153+
// Destination for output of add.
154+
Tensor out = tf_a.zeros({2, 2, 3, 5});
155+
Tensor expected = tf_a.make(
156+
{2, 2, 3, 5},
157+
/*data=*/{2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
158+
17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45,
159+
47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75,
160+
62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90});
161+
162+
// Check that it matches the expected output.
163+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
164+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
165+
166+
b = tf_a.make(
167+
{2, 2, 1, 5}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
168+
11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
169+
out = tf_a.zeros({2, 2, 3, 5});
170+
expected = tf_a.make(
171+
{2, 2, 3, 5},
172+
/*data=*/{2, 4, 6, 8, 10, 7, 9, 11, 13, 15, 12, 14, 16, 18, 20,
173+
22, 24, 26, 28, 30, 27, 29, 31, 33, 35, 32, 34, 36, 38, 40,
174+
42, 44, 46, 48, 50, 47, 49, 51, 53, 55, 52, 54, 56, 58, 60,
175+
62, 64, 66, 68, 70, 67, 69, 71, 73, 75, 72, 74, 76, 78, 80});
176+
177+
// Check that it matches the expected output.
178+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
179+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
180+
}
181+
182+
template <ScalarType DTYPE>
183+
void test_broadcast_last_dim() {
184+
TensorFactory<DTYPE> tf_a;
185+
186+
Tensor a =
187+
tf_a.make({4, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
188+
Tensor b = tf_a.make({4, 1}, /*data=*/{2, 3, 4, 5});
189+
190+
// Destination for output of add.
191+
Tensor out = tf_a.zeros({4, 3});
192+
Tensor expected =
193+
tf_a.make({4, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17});
194+
195+
// Check that it matches the expected output.
196+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
197+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
198+
199+
a = tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
200+
b = tf_a.make({2, 2, 1}, /*data=*/{2, 3, 4, 5});
201+
202+
// Destination for output of add.
203+
out = tf_a.zeros({2, 2, 3});
204+
expected = tf_a.make(
205+
{2, 2, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17});
206+
207+
// Check that it matches the expected output.
208+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
209+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
210+
211+
a = tf_a.make(
212+
{2, 2, 3, 5},
213+
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
214+
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
215+
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
216+
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
217+
b = tf_a.make(
218+
{2, 2, 3, 1},
219+
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
220+
221+
// Destination for output of add.
222+
out = tf_a.zeros({2, 2, 3, 5});
223+
expected = tf_a.make(
224+
{2, 2, 3, 5},
225+
/*data=*/{2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18,
226+
20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36,
227+
38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54,
228+
56, 57, 58, 59, 60, 62, 63, 64, 65, 66, 68, 69, 70, 71, 72});
229+
230+
// Check that it matches the expected output.
231+
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
232+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
233+
}
115234
};
116235

117236
class OpAddScalarOutKernelTest : public OperatorTest {
@@ -371,6 +490,23 @@ TEST_F(OpAddOutKernelTest, BroadcastOneElementRank0Tensor) {
371490
EXPECT_TENSOR_EQ(out, ret);
372491
}
373492

493+
TEST_F(OpAddOutKernelTest, BroadcastNDTest) {
494+
// Test 3D tensors
495+
test_broadcast_3D<ScalarType::Float>();
496+
test_broadcast_3D<ScalarType::Half>();
497+
test_broadcast_3D<ScalarType::BFloat16>();
498+
499+
// Test 4D tensors
500+
test_broadcast_4D<ScalarType::Float>();
501+
test_broadcast_4D<ScalarType::Half>();
502+
test_broadcast_4D<ScalarType::BFloat16>();
503+
504+
// Test broadcasting on the last dimension
505+
test_broadcast_last_dim<ScalarType::Float>();
506+
test_broadcast_last_dim<ScalarType::Half>();
507+
test_broadcast_last_dim<ScalarType::BFloat16>();
508+
}
509+
374510
//
375511
// Death Tests
376512
//

kernels/test/op_mul_test.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -417,16 +417,6 @@ TEST_F(OpMulOutTest, BroadcastA2BTest) {
417417
test_broadcast_a2b<ScalarType::Int>();
418418
test_broadcast_a2b<ScalarType::Half>();
419419
test_broadcast_a2b<ScalarType::BFloat16>();
420-
421-
// Test 3D tensors
422-
test_broadcast_3D<ScalarType::Float>();
423-
test_broadcast_3D<ScalarType::Half>();
424-
test_broadcast_3D<ScalarType::BFloat16>();
425-
426-
// Test 4D tensors
427-
test_broadcast_4D<ScalarType::Float>();
428-
test_broadcast_4D<ScalarType::Half>();
429-
test_broadcast_4D<ScalarType::BFloat16>();
430420
}
431421

432422
// Broadcast tensor a's size to tensor b's size

0 commit comments

Comments
 (0)