Skip to content

Commit 7d9494f

Browse files
committed
Update on "[Executorch] Refactor op_mul's broadcasting utils"
Summary: Refactoring broadcast handling utils that were added for op_mul. This is in preparation to use these utils to handle broadcast for other ops such as add, sub, div. Plus remove a redundant test Test Plan: optimized_kernels_test in CI Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
2 parents e814bb7 + ed79e8c commit 7d9494f

File tree

2 files changed

+24
-26
lines changed

2 files changed

+24
-26
lines changed

kernels/optimized/cpu/binary_ops.h

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
191191
return normalized_tensor_size;
192192
}
193193

194-
template <typename Op>
194+
template <typename CTYPE, typename Op>
195195
Tensor& handle_last_dim_broadcast_elementwise(
196196
KernelRuntimeContext& ctx,
197197
const Op& vec_fun,
@@ -219,19 +219,17 @@ Tensor& handle_last_dim_broadcast_elementwise(
219219
"Failed to resize output tensor.");
220220
const size_t outer_size = getLeadingDims(out, out.dim() - 1);
221221
const auto broadcast_size = out.size(out.dim() - 1);
222-
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
223-
executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
224-
vec_fun,
225-
out.mutable_data_ptr<CTYPE>(),
226-
lhs->const_data_ptr<CTYPE>(),
227-
rhs->const_data_ptr<CTYPE>(),
228-
outer_size,
229-
broadcast_size);
230-
});
222+
executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
223+
vec_fun,
224+
out.mutable_data_ptr<CTYPE>(),
225+
lhs->const_data_ptr<CTYPE>(),
226+
rhs->const_data_ptr<CTYPE>(),
227+
outer_size,
228+
broadcast_size);
231229
return out;
232230
}
233231

234-
template <typename Op>
232+
template <typename CTYPE, typename Op>
235233
Tensor& handle_broadcast_elementwise(
236234
KernelRuntimeContext& ctx,
237235
const Op& vec_fun,
@@ -243,11 +241,10 @@ Tensor& handle_broadcast_elementwise(
243241
ElementwiseOptimizedPath::kBroadcastLastDim) ||
244242
(selected_optimized_path ==
245243
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
246-
return handle_last_dim_broadcast_elementwise(
244+
return handle_last_dim_broadcast_elementwise<CTYPE>(
247245
ctx, vec_fun, a, b, out, selected_optimized_path);
248246
}
249247

250-
ScalarType out_type = out.scalar_type();
251248
const Tensor* lhs;
252249
const Tensor* rhs;
253250
if ((selected_optimized_path ==
@@ -290,16 +287,14 @@ Tensor& handle_broadcast_elementwise(
290287
broadcast_size = lhs->sizes()[lhs->dim() - 2];
291288
inner_size = lhs->sizes()[lhs->dim() - 1];
292289
}
293-
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
294-
executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
295-
vec_fun,
296-
out.mutable_data_ptr<CTYPE>(),
297-
lhs->const_data_ptr<CTYPE>(),
298-
rhs->const_data_ptr<CTYPE>(),
299-
outer_size,
300-
broadcast_size,
301-
inner_size);
302-
});
290+
executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
291+
vec_fun,
292+
out.mutable_data_ptr<CTYPE>(),
293+
lhs->const_data_ptr<CTYPE>(),
294+
rhs->const_data_ptr<CTYPE>(),
295+
outer_size,
296+
broadcast_size,
297+
inner_size);
303298
return out;
304299
}
305300
} // namespace executor

kernels/optimized/cpu/op_mul.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,12 @@ Tensor& opt_mul_out(
130130
out.numel());
131131
});
132132
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
133-
auto mul_lambda = [](auto x, auto y) { return x * y; };
134-
return torch::executor::handle_broadcast_elementwise(
135-
ctx, mul_lambda, a, b, out, selected_optimized_path);
133+
ScalarType out_type = out.scalar_type();
134+
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
135+
auto mul_lambda = [](auto x, auto y) { return x * y; };
136+
return torch::executor::handle_broadcast_elementwise<CTYPE>(
137+
ctx, mul_lambda, a, b, out, selected_optimized_path);
138+
});
136139
} else {
137140
ScalarType common_type =
138141
promoteTypes(a_type, b_type, /*half_to_float*/ true);

0 commit comments

Comments
 (0)