
Commit 6d3bf82

[Executorch] Refactor op_mul's broadcasting utils (#8204)
Summary: Refactor the broadcast-handling utils that were added for op_mul, in preparation for reusing them to handle broadcasting in other ops such as add, sub, and div. Also removes a redundant test.

Test Plan: optimized_kernels_test in CI.

cc larryliu0820 manuelcandales

Differential Revision: [D69491816](https://our.internmc.facebook.com/intern/diff/D69491816)
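As a sketch of the intended reuse (not part of this commit), a follow-up op such as add could call the refactored helper exactly the way the new op_mul call site below does; the "add.out" label, add_lambda name, and surrounding opt_add_out context are illustrative assumptions:

    // Hypothetical broadcast branch of an opt_add_out kernel (illustrative
    // only): reuse handle_broadcast_elementwise as the new op_mul call site
    // does, swapping in an addition lambda.
    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
      auto add_lambda = [](auto x, auto y) { return x + y; };
      return torch::executor::handle_broadcast_elementwise<CTYPE>(
          ctx, add_lambda, a, b, out, selected_optimized_path);
    });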
1 parent 8ad15f3 commit 6d3bf82

File tree

3 files changed: +111 −116 lines

kernels/optimized/cpu/binary_ops.h

Lines changed: 106 additions & 0 deletions

@@ -8,6 +8,7 @@
 
 #pragma once
 
+#include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -190,5 +191,110 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
   return normalized_tensor_size;
 }
 
+template <typename CTYPE, typename Op>
+Tensor& handle_last_dim_broadcast_elementwise(
+    KernelRuntimeContext& ctx,
+    const Op& vec_fun,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if (selected_optimized_path ==
+      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
+  const auto broadcast_size = out.size(out.dim() - 1);
+  executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
+      vec_fun,
+      out.mutable_data_ptr<CTYPE>(),
+      lhs->const_data_ptr<CTYPE>(),
+      rhs->const_data_ptr<CTYPE>(),
+      outer_size,
+      broadcast_size);
+  return out;
+}
+
+template <typename CTYPE, typename Op>
+Tensor& handle_broadcast_elementwise(
+    KernelRuntimeContext& ctx,
+    const Op& vec_fun,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDim) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
+    return handle_last_dim_broadcast_elementwise<CTYPE>(
+        ctx, vec_fun, a, b, out, selected_optimized_path);
+  }
+
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    // Catch failure to update logic when adding new broadcasting possibility.
+    ET_DCHECK(
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd));
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  int64_t outer_size = 1;
+  int64_t broadcast_size;
+  int64_t inner_size;
+  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+    auto normalized_tensor_size_lhs =
+        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+    outer_size = normalized_tensor_size_lhs[0];
+    broadcast_size = normalized_tensor_size_lhs[1];
+    inner_size = normalized_tensor_size_lhs[2];
+  } else {
+    broadcast_size = lhs->sizes()[lhs->dim() - 2];
+    inner_size = lhs->sizes()[lhs->dim() - 1];
+  }
+  executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
+      vec_fun,
+      out.mutable_data_ptr<CTYPE>(),
+      lhs->const_data_ptr<CTYPE>(),
+      rhs->const_data_ptr<CTYPE>(),
+      outer_size,
+      broadcast_size,
+      inner_size);
+  return out;
+}
 } // namespace executor
 } // namespace torch
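For readers following the Nd-by-Nd path above, a worked example of the collapsed [outer, broadcast, inner] view. The shapes are illustrative, and it assumes get_broadcast_dim reports the mismatched dimension as a negative index and get_normalized_tensor_size collapses the dims on either side of it:

    // Illustrative only: assumed behavior of the Nd-by-Nd normalization.
    // Suppose lhs has sizes [2, 3, 4, 5] and rhs has sizes [2, 1, 4, 5].
    // get_broadcast_dim(*lhs, *rhs) would report the mismatched dimension
    // as a negative offset (-3 here), so
    //   broadcast_dim_lhs = lhs->dim() + broadcast_dim = 4 + (-3) = 1.
    // get_normalized_tensor_size then collapses lhs into a 3D view:
    //   outer_size     = 2      // product of dims before the broadcast dim
    //   broadcast_size = 3      // the broadcast dim itself
    //   inner_size     = 4 * 5  // product of dims after the broadcast dim
    // broadcasting_map_3d_and_unsqueezed_3d applies vec_fun over lhs viewed
    // as [2, 3, 20] against rhs viewed as [2, 1, 20].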

kernels/optimized/cpu/op_mul.cpp

Lines changed: 5 additions & 109 deletions

@@ -68,114 +68,6 @@ template <
 struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
     : public ReportCanCastBug {};
 
-Tensor& handle_last_dim_broadcast(
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    const Tensor& b,
-    Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
-  ScalarType out_type = out.scalar_type();
-  const Tensor* lhs;
-  const Tensor* rhs;
-  if (selected_optimized_path ==
-      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
-    lhs = &b;
-    rhs = &a;
-  } else {
-    lhs = &a;
-    rhs = &b;
-  }
-  auto error = resize_tensor(out, lhs->sizes());
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      error == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
-  const auto broadcast_size = out.size(out.dim() - 1);
-  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
-        [](Vec x, Vec y) { return x * y; },
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size);
-  });
-  return out;
-}
-
-Tensor& handle_broadcast_mul(
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    const Tensor& b,
-    Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
-  if ((selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastLastDim) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
-    return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
-  }
-
-  ScalarType out_type = out.scalar_type();
-  const Tensor* lhs;
-  const Tensor* rhs;
-  if ((selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-    lhs = &b;
-    rhs = &a;
-  } else {
-    // Catch failure to update logic when adding new broadcasting possibility.
-    ET_DCHECK(
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNd));
-    lhs = &a;
-    rhs = &b;
-  }
-  auto error = resize_tensor(out, lhs->sizes());
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      error == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-  int64_t outer_size = 1;
-  int64_t broadcast_size;
-  int64_t inner_size;
-  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-    auto normalized_tensor_size_lhs =
-        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-    outer_size = normalized_tensor_size_lhs[0];
-    broadcast_size = normalized_tensor_size_lhs[1];
-    inner_size = normalized_tensor_size_lhs[2];
-  } else {
-    broadcast_size = lhs->sizes()[lhs->dim() - 2];
-    inner_size = lhs->sizes()[lhs->dim() - 1];
-  }
-  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
-        [](Vec x, Vec y) { return x * y; },
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size,
-        inner_size);
-  });
-  return out;
-}
 } // namespace
 
 Tensor& opt_mul_out(
@@ -238,7 +130,11 @@ Tensor& opt_mul_out(
         out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
+    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      auto mul_lambda = [](auto x, auto y) { return x * y; };
+      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+          ctx, mul_lambda, a, b, out, selected_optimized_path);
+    });
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
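Worth noting for the add/sub/div follow-ups the summary promises: the *ReverseArguments paths swap lhs and rhs before vec_fun is applied. That is harmless for a commutative op like mul (or add), but it presumably means a non-commutative op like sub or div will need its lambda, or the helper itself, to compensate for the swap.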

kernels/test/op_mul_test.cpp

Lines changed: 0 additions & 7 deletions

@@ -453,13 +453,6 @@ TEST_F(OpMulOutTest, BroadcastNDTest) {
   test_broadcast_last_dim<ScalarType::BFloat16>();
 }
 
-TEST_F(OpMulOutTest, BroadcastLastDimTest) {
-  // Test broadcasting on the last dimension
-  test_broadcast_last_dim<ScalarType::Float>();
-  test_broadcast_last_dim<ScalarType::Half>();
-  test_broadcast_last_dim<ScalarType::BFloat16>();
-}
-
 // Broadcast tensor a and b's size to a new size c.
 TEST_F(OpMulOutTest, BroadcastAB2CTest) {
   TensorFactory<ScalarType::Int> tf_a;
