|
6 | 6 | * LICENSE file in the root directory of this source tree.
|
7 | 7 | */
|
8 | 8 |
|
| 9 | +#include <executorch/kernels/optimized/cpu/binary_ops.h> |
9 | 10 | #include <executorch/kernels/optimized/vec/functional.h>
|
10 | 11 | #include <executorch/kernels/optimized/vec/vec.h>
|
11 | 12 | #include <executorch/kernels/portable/cpu/scalar_utils.h>
|
12 | 13 | #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
|
13 | 14 | #include <executorch/runtime/kernel/kernel_includes.h>
|
14 | 15 | #include <executorch/runtime/platform/assert.h>
|
| 16 | +#include <executorch/kernels/portable/cpu/pattern/comparison_op.h> |
15 | 17 |
|
16 | 18 | namespace torch {
|
17 | 19 | namespace executor {
|
@@ -79,52 +81,41 @@ Tensor& opt_le_tensor_out(
|
79 | 81 | return out;
|
80 | 82 | }
|
81 | 83 |
|
82 |
| - ET_KERNEL_CHECK(ctx, tensors_have_same_shape(a, b), InvalidArgument, out); |
83 |
| - |
84 |
| - // Resize for dynamic shape |
85 |
| - auto error = resize_tensor(out, a.sizes()); |
86 |
| - ET_KERNEL_CHECK_MSG( |
87 |
| - ctx, |
88 |
| - error == Error::Ok, |
89 |
| - InvalidArgument, |
90 |
| - out, |
91 |
| - "Failed to resize output tensor."); |
92 |
| - |
93 |
| - if (a_type == b_type && a_type == out_type) { |
94 |
| - ET_SWITCH_REAL_TYPES_AND( |
95 |
| - Bool, out_type, ctx, "le.Tensor_out", CTYPE, [&]() { |
96 |
| - using Vec = executorch::vec::Vectorized<CTYPE>; |
97 |
| - executorch::vec::map2<CTYPE>( |
98 |
| - [](Vec x, Vec y) { return x.le(y); }, |
99 |
| - out.mutable_data_ptr<CTYPE>(), |
100 |
| - a.const_data_ptr<CTYPE>(), |
101 |
| - b.const_data_ptr<CTYPE>(), |
102 |
| - a.numel()); |
103 |
| - }); |
| 84 | + // Check for optimized broadcast paths |
| 85 | + auto selected_optimized_path = select_optimized_path(a, b, out); |
| 86 | + printf("selected_optimized_path: %d\n", static_cast<int>(selected_optimized_path)); |
| 87 | + if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) { |
| 88 | + // Resize for dynamic shape |
| 89 | + auto error = resize_tensor(out, a.sizes()); |
| 90 | + ET_KERNEL_CHECK_MSG( |
| 91 | + ctx, |
| 92 | + error == Error::Ok, |
| 93 | + InvalidArgument, |
| 94 | + out, |
| 95 | + "Failed to resize output tensor."); |
| 96 | + |
| 97 | + ET_SWITCH_REALB_TYPES(a_type, ctx, "le.Tensor_out", CTYPE, [&]() { |
| 98 | + using Vec = executorch::vec::Vectorized<CTYPE>; |
| 99 | + executorch::vec::map2<CTYPE>( |
| 100 | + [](Vec x, Vec y) { return x.le(y); }, |
| 101 | + out.mutable_data_ptr<CTYPE>(), |
| 102 | + a.const_data_ptr<CTYPE>(), |
| 103 | + b.const_data_ptr<CTYPE>(), |
| 104 | + out.numel()); |
| 105 | + }); |
| 106 | + } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) { |
| 107 | + // Handle optimized broadcast cases |
| 108 | + ET_SWITCH_REALB_TYPES(out_type, ctx, "le.Tensor_out", CTYPE, [&]() { |
| 109 | + using Vec = executorch::vec::Vectorized<CTYPE>; |
| 110 | + auto le_lambda = [](auto x, auto y) { return x.le(y); }; |
| 111 | + return torch::executor::handle_broadcast_elementwise<CTYPE>( |
| 112 | + ctx, le_lambda, a, b, out, selected_optimized_path); |
| 113 | + }); |
104 | 114 | } else {
|
105 |
| - ET_SWITCH_REAL_TYPES_AND( |
106 |
| - Bool, a_type, ctx, "le.Tensor_out", CTYPE_A, [&]() { |
107 |
| - ET_SWITCH_REAL_TYPES_AND( |
108 |
| - Bool, b_type, ctx, "le.Tensor_out", CTYPE_B, [&]() { |
109 |
| - using CTYPE_IN = typename torch::executor:: |
110 |
| - promote_types<CTYPE_A, CTYPE_B>::type; |
111 |
| - ET_DCHECK( |
112 |
| - CppTypeToScalarType<CTYPE_IN>::value == |
113 |
| - promoteTypes(a_type, b_type)); |
114 |
| - ET_SWITCH_REAL_TYPES_AND( |
115 |
| - Bool, out_type, ctx, "le.Tensor_out", CTYPE_OUT, [&]() { |
116 |
| - const size_t n = a.numel(); |
117 |
| - const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>(); |
118 |
| - const CTYPE_B* b_data = b.const_data_ptr<CTYPE_B>(); |
119 |
| - CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>(); |
120 |
| - for (auto i = 0; i < n; ++i) { |
121 |
| - out_data[i] = static_cast<CTYPE_OUT>( |
122 |
| - static_cast<CTYPE_IN>(a_data[i]) <= |
123 |
| - static_cast<CTYPE_IN>(b_data[i])); |
124 |
| - } |
125 |
| - }); |
126 |
| - }); |
127 |
| - }); |
| 115 | + // @lint-ignore CLANGTIDY facebook-hte-CArray |
| 116 | + static constexpr const char op_name[] = "le.Tensor_out"; |
| 117 | + return internal::comparison_tensor_out<std::less_equal, op_name>( |
| 118 | + ctx, a, b, out); |
128 | 119 | }
|
129 | 120 |
|
130 | 121 | return out;
|
|
0 commit comments