Skip to content

Commit 140cc14

Browse files
[Executorch] Add broadcast support for le op (#11635)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #11569 by @kimishpatel ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/kimishpatel/191/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/kimishpatel/191/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/kimishpatel/191/orig @diff-train-skip-merge Co-authored-by: Kimish Patel <[email protected]>
1 parent 2a30250 commit 140cc14

File tree

3 files changed

+963
-45
lines changed

3 files changed

+963
-45
lines changed

kernels/optimized/cpu/op_le.cpp

Lines changed: 34 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <ATen/cpu/vec/functional.h>
1010
#include <ATen/cpu/vec/vec.h>
11+
#include <executorch/kernels/optimized/cpu/binary_ops.h>
12+
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1113
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1214
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1315
#include <executorch/runtime/kernel/kernel_includes.h>
@@ -79,52 +81,39 @@ Tensor& opt_le_tensor_out(
7981
return out;
8082
}
8183

82-
ET_KERNEL_CHECK(ctx, tensors_have_same_shape(a, b), InvalidArgument, out);
83-
84-
// Resize for dynamic shape
85-
auto error = resize_tensor(out, a.sizes());
86-
ET_KERNEL_CHECK_MSG(
87-
ctx,
88-
error == Error::Ok,
89-
InvalidArgument,
90-
out,
91-
"Failed to resize output tensor.");
92-
93-
if (a_type == b_type && a_type == out_type) {
94-
ET_SWITCH_REAL_TYPES_AND(
95-
Bool, out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
96-
using Vec = at::vec::Vectorized<CTYPE>;
97-
at::vec::map2<CTYPE>(
98-
[](Vec x, Vec y) { return x.le(y); },
99-
out.mutable_data_ptr<CTYPE>(),
100-
a.const_data_ptr<CTYPE>(),
101-
b.const_data_ptr<CTYPE>(),
102-
a.numel());
103-
});
84+
// Check for optimized broadcast paths
85+
auto selected_optimized_path = select_optimized_path(a, b, out);
86+
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
87+
// Resize for dynamic shape
88+
auto error = resize_to_broadcast_target_size(a, b, out);
89+
ET_KERNEL_CHECK_MSG(
90+
ctx,
91+
error == Error::Ok,
92+
InvalidArgument,
93+
out,
94+
"Failed to resize output tensor.");
95+
96+
ET_SWITCH_REALB_TYPES(a_type, ctx, "le.Tensor_out", CTYPE, [&]() {
97+
using Vec = at::vec::Vectorized<CTYPE>;
98+
at::vec::map2<CTYPE>(
99+
[](Vec x, Vec y) { return x.le(y); },
100+
out.mutable_data_ptr<CTYPE>(),
101+
a.const_data_ptr<CTYPE>(),
102+
b.const_data_ptr<CTYPE>(),
103+
out.numel());
104+
});
105+
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
106+
// Handle optimized broadcast cases
107+
ET_SWITCH_REALB_TYPES(out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
108+
auto le_lambda = [](auto x, auto y) { return x.le(y); };
109+
return torch::executor::handle_broadcast_elementwise<CTYPE>(
110+
ctx, le_lambda, a, b, out, selected_optimized_path);
111+
});
104112
} else {
105-
ET_SWITCH_REAL_TYPES_AND(
106-
Bool, a_type, ctx, "le.Tensor_out", CTYPE_A, [&]() {
107-
ET_SWITCH_REAL_TYPES_AND(
108-
Bool, b_type, ctx, "le.Tensor_out", CTYPE_B, [&]() {
109-
using CTYPE_IN = typename torch::executor::
110-
promote_types<CTYPE_A, CTYPE_B>::type;
111-
ET_DCHECK(
112-
CppTypeToScalarType<CTYPE_IN>::value ==
113-
promoteTypes(a_type, b_type));
114-
ET_SWITCH_REAL_TYPES_AND(
115-
Bool, out_type, ctx, "le.Tensor_out", CTYPE_OUT, [&]() {
116-
const size_t n = a.numel();
117-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
118-
const CTYPE_B* b_data = b.const_data_ptr<CTYPE_B>();
119-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
120-
for (auto i = 0; i < n; ++i) {
121-
out_data[i] = static_cast<CTYPE_OUT>(
122-
static_cast<CTYPE_IN>(a_data[i]) <=
123-
static_cast<CTYPE_IN>(b_data[i]));
124-
}
125-
});
126-
});
127-
});
113+
// @lint-ignore CLANGTIDY facebook-hte-CArray
114+
static constexpr const char op_name[] = "le.Tensor_out";
115+
return internal::comparison_tensor_out<std::less_equal, op_name>(
116+
ctx, a, b, out);
128117
}
129118

130119
return out;

0 commit comments

Comments
 (0)