Commit 6ef10c2

[Executorch] Refactor op_mul's broadcasting utils
Summary: Refactoring the broadcast-handling utils that were added for op_mul, in preparation for using these utils to handle broadcasting in other ops such as add, sub, and div. Also removes a redundant test.

Test Plan: optimized_kernels_test in CI

ghstack-source-id: 4336169
Pull Request resolved: #8204
1 parent e8c9ccc commit 6ef10c2

File tree: 3 files changed (+113 −116 lines)

kernels/optimized/cpu/binary_ops.h

Lines changed: 107 additions & 0 deletions
@@ -8,6 +8,7 @@

 #pragma once

+#include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

 namespace torch {
@@ -190,5 +191,111 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
   return normalized_tensor_size;
 }

+template <typename CTYPE, typename Op>
+Tensor& handle_last_dim_broadcast_elementwise(
+    KernelRuntimeContext& ctx,
+    const Op& vec_fun,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if (selected_optimized_path ==
+      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
+  const auto broadcast_size = out.size(out.dim() - 1);
+  executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
+      vec_fun,
+      out.mutable_data_ptr<CTYPE>(),
+      lhs->const_data_ptr<CTYPE>(),
+      rhs->const_data_ptr<CTYPE>(),
+      outer_size,
+      broadcast_size);
+  return out;
+}
+
+template <typename CTYPE, typename Op>
+Tensor& handle_broadcast_elementwise(
+    KernelRuntimeContext& ctx,
+    const Op& vec_fun,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDim) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
+    return handle_last_dim_broadcast_elementwise<CTYPE>(
+        ctx, vec_fun, a, b, out, selected_optimized_path);
+  }
+
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    // Catch failure to update logic when adding new broadcasting possibility.
+    ET_DCHECK(
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd));
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  int64_t outer_size = 1;
+  int64_t broadcast_size;
+  int64_t inner_size;
+  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+    auto normalized_tensor_size_lhs =
+        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+    outer_size = normalized_tensor_size_lhs[0];
+    broadcast_size = normalized_tensor_size_lhs[1];
+    inner_size = normalized_tensor_size_lhs[2];
+  } else {
+    broadcast_size = lhs->sizes()[lhs->dim() - 2];
+    inner_size = lhs->sizes()[lhs->dim() - 1];
+  }
+  executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
+      vec_fun,
+      out.mutable_data_ptr<CTYPE>(),
+      lhs->const_data_ptr<CTYPE>(),
+      rhs->const_data_ptr<CTYPE>(),
+      outer_size,
+      broadcast_size,
+      inner_size);
+  return out;
+}
 } // namespace executor
 } // namespace torch
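
The vec_fun functor keeps the contract of the lambdas it replaces in op_mul.cpp (shown below): the broadcasting_map_* helpers from functional.h invoke it on executorch::vec::Vectorized<CTYPE> chunks of the two inputs. A minimal sketch of conforming functors, assuming they are created inside an ET_SWITCH_REALB_TYPES block that defines CTYPE, as the updated op_mul call site does:

// Sketch only: functor shapes accepted by handle_broadcast_elementwise.
// Explicitly typed, matching the lambda op_mul previously defined inline:
using Vec = executorch::vec::Vectorized<CTYPE>;
auto vec_mul = [](Vec x, Vec y) { return x * y; };
// Or the generic-lambda form used by the refactored op_mul call site:
auto vec_mul_generic = [](auto x, auto y) { return x * y; };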

kernels/optimized/cpu/op_mul.cpp

Lines changed: 6 additions & 109 deletions
@@ -68,114 +68,6 @@ template <
 struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
     : public ReportCanCastBug {};

-Tensor& handle_last_dim_broadcast(
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    const Tensor& b,
-    Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
-  ScalarType out_type = out.scalar_type();
-  const Tensor* lhs;
-  const Tensor* rhs;
-  if (selected_optimized_path ==
-      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
-    lhs = &b;
-    rhs = &a;
-  } else {
-    lhs = &a;
-    rhs = &b;
-  }
-  auto error = resize_tensor(out, lhs->sizes());
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      error == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
-  const auto broadcast_size = out.size(out.dim() - 1);
-  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
-        [](Vec x, Vec y) { return x * y; },
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size);
-  });
-  return out;
-}
-
-Tensor& handle_broadcast_mul(
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    const Tensor& b,
-    Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
-  if ((selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastLastDim) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
-    return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
-  }
-
-  ScalarType out_type = out.scalar_type();
-  const Tensor* lhs;
-  const Tensor* rhs;
-  if ((selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-    lhs = &b;
-    rhs = &a;
-  } else {
-    // Catch failure to update logic when adding new broadcasting possibility.
-    ET_DCHECK(
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNd));
-    lhs = &a;
-    rhs = &b;
-  }
-  auto error = resize_tensor(out, lhs->sizes());
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      error == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-  int64_t outer_size = 1;
-  int64_t broadcast_size;
-  int64_t inner_size;
-  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-    auto normalized_tensor_size_lhs =
-        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-    outer_size = normalized_tensor_size_lhs[0];
-    broadcast_size = normalized_tensor_size_lhs[1];
-    inner_size = normalized_tensor_size_lhs[2];
-  } else {
-    broadcast_size = lhs->sizes()[lhs->dim() - 2];
-    inner_size = lhs->sizes()[lhs->dim() - 1];
-  }
-  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
-        [](Vec x, Vec y) { return x * y; },
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size,
-        inner_size);
-  });
-  return out;
-}
 } // namespace

 Tensor& opt_mul_out(
@@ -238,7 +130,12 @@ Tensor& opt_mul_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
+    ScalarType out_type = out.scalar_type();
+    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      auto mul_lambda = [](auto x, auto y) { return x * y; };
+      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+          ctx, mul_lambda, a, b, out, selected_optimized_path);
+    });
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
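
Per the commit summary, the point of moving these helpers into binary_ops.h is to let other binary ops reuse them. A minimal sketch of how a hypothetical broadcast branch in op_add.cpp could mirror the updated op_mul call site above; the add_lambda name and the "add.out" switch label are illustrative assumptions, add.out's alpha argument is ignored, and this commit itself only migrates op_mul:

// Hypothetical, not part of this commit: broadcast branch for an optimized add.
// Assumes ctx, a, b, out, and selected_optimized_path are set up as in opt_mul_out.
if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
  ScalarType out_type = out.scalar_type();
  ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
    auto add_lambda = [](auto x, auto y) { return x + y; };
    return torch::executor::handle_broadcast_elementwise<CTYPE>(
        ctx, add_lambda, a, b, out, selected_optimized_path);
  });
}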

kernels/test/op_mul_test.cpp

Lines changed: 0 additions & 7 deletions
@@ -453,13 +453,6 @@ TEST_F(OpMulOutTest, BroadcastNDTest) {
453453
test_broadcast_last_dim<ScalarType::BFloat16>();
454454
}
455455

456-
TEST_F(OpMulOutTest, BroadcastLastDimTest) {
457-
// Test broadcasting on the last dimension
458-
test_broadcast_last_dim<ScalarType::Float>();
459-
test_broadcast_last_dim<ScalarType::Half>();
460-
test_broadcast_last_dim<ScalarType::BFloat16>();
461-
}
462-
463456
// Broadcast tensor a and b's size to a new size c.
464457
TEST_F(OpMulOutTest, BroadcastAB2CTest) {
465458
TensorFactory<ScalarType::Int> tf_a;
