Skip to content

Commit 5ef9025

Browse files
committed
broadcast in op_le
Differential Revision: [D76456398](https://our.internmc.facebook.com/intern/diff/D76456398/) [ghstack-poisoned]
1 parent 2dda7a2 commit 5ef9025

File tree

3 files changed

+851
-45
lines changed

3 files changed

+851
-45
lines changed

kernels/optimized/cpu/op_le.cpp

Lines changed: 36 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/kernels/optimized/cpu/binary_ops.h>
910
#include <executorch/kernels/optimized/vec/functional.h>
1011
#include <executorch/kernels/optimized/vec/vec.h>
1112
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1213
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1314
#include <executorch/runtime/kernel/kernel_includes.h>
1415
#include <executorch/runtime/platform/assert.h>
16+
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1517

1618
namespace torch {
1719
namespace executor {
@@ -79,52 +81,41 @@ Tensor& opt_le_tensor_out(
7981
return out;
8082
}
8183

82-
ET_KERNEL_CHECK(ctx, tensors_have_same_shape(a, b), InvalidArgument, out);
83-
84-
// Resize for dynamic shape
85-
auto error = resize_tensor(out, a.sizes());
86-
ET_KERNEL_CHECK_MSG(
87-
ctx,
88-
error == Error::Ok,
89-
InvalidArgument,
90-
out,
91-
"Failed to resize output tensor.");
92-
93-
if (a_type == b_type && a_type == out_type) {
94-
ET_SWITCH_REAL_TYPES_AND(
95-
Bool, out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
96-
using Vec = executorch::vec::Vectorized<CTYPE>;
97-
executorch::vec::map2<CTYPE>(
98-
[](Vec x, Vec y) { return x.le(y); },
99-
out.mutable_data_ptr<CTYPE>(),
100-
a.const_data_ptr<CTYPE>(),
101-
b.const_data_ptr<CTYPE>(),
102-
a.numel());
103-
});
84+
// Check for optimized broadcast paths
85+
auto selected_optimized_path = select_optimized_path(a, b, out);
86+
printf("selected_optimized_path: %d\n", static_cast<int>(selected_optimized_path));
87+
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
88+
// Resize for dynamic shape
89+
auto error = resize_tensor(out, a.sizes());
90+
ET_KERNEL_CHECK_MSG(
91+
ctx,
92+
error == Error::Ok,
93+
InvalidArgument,
94+
out,
95+
"Failed to resize output tensor.");
96+
97+
ET_SWITCH_REALB_TYPES(a_type, ctx, "le.Tensor_out", CTYPE, [&]() {
98+
using Vec = executorch::vec::Vectorized<CTYPE>;
99+
executorch::vec::map2<CTYPE>(
100+
[](Vec x, Vec y) { return x.le(y); },
101+
out.mutable_data_ptr<CTYPE>(),
102+
a.const_data_ptr<CTYPE>(),
103+
b.const_data_ptr<CTYPE>(),
104+
out.numel());
105+
});
106+
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
107+
// Handle optimized broadcast cases
108+
ET_SWITCH_REALB_TYPES(out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
109+
using Vec = executorch::vec::Vectorized<CTYPE>;
110+
auto le_lambda = [](auto x, auto y) { return x.le(y); };
111+
return torch::executor::handle_broadcast_elementwise<CTYPE>(
112+
ctx, le_lambda, a, b, out, selected_optimized_path);
113+
});
104114
} else {
105-
ET_SWITCH_REAL_TYPES_AND(
106-
Bool, a_type, ctx, "le.Tensor_out", CTYPE_A, [&]() {
107-
ET_SWITCH_REAL_TYPES_AND(
108-
Bool, b_type, ctx, "le.Tensor_out", CTYPE_B, [&]() {
109-
using CTYPE_IN = typename torch::executor::
110-
promote_types<CTYPE_A, CTYPE_B>::type;
111-
ET_DCHECK(
112-
CppTypeToScalarType<CTYPE_IN>::value ==
113-
promoteTypes(a_type, b_type));
114-
ET_SWITCH_REAL_TYPES_AND(
115-
Bool, out_type, ctx, "le.Tensor_out", CTYPE_OUT, [&]() {
116-
const size_t n = a.numel();
117-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
118-
const CTYPE_B* b_data = b.const_data_ptr<CTYPE_B>();
119-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
120-
for (auto i = 0; i < n; ++i) {
121-
out_data[i] = static_cast<CTYPE_OUT>(
122-
static_cast<CTYPE_IN>(a_data[i]) <=
123-
static_cast<CTYPE_IN>(b_data[i]));
124-
}
125-
});
126-
});
127-
});
115+
// @lint-ignore CLANGTIDY facebook-hte-CArray
116+
static constexpr const char op_name[] = "le.Tensor_out";
117+
return internal::comparison_tensor_out<std::less_equal, op_name>(
118+
ctx, a, b, out);
128119
}
129120

130121
return out;

0 commit comments

Comments
 (0)