 * LICENSE file in the root directory of this source tree.
 */

+#include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
@@ -48,7 +49,57 @@ Tensor& opt_div_out(
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();

-  if (a_type == b_type && a_type == out_type && a.sizes().equals(b.sizes())) {
+  if (a.numel() == 1 || b.numel() == 1) {
+    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
+      const Tensor* tensor;
+      const Tensor* scalar;
+      ScalarType tensor_type;
+      ScalarType scalar_type;
+      if (a.numel() == 1) {
+        tensor = &b;
+        tensor_type = b_type;
+        scalar = &a;
+        scalar_type = a_type;
+      } else {
+        tensor = &a;
+        tensor_type = a_type;
+        scalar = &b;
+        scalar_type = b_type;
+      }
+      auto error = resize_tensor(out, tensor->sizes());
+      ET_KERNEL_CHECK_MSG(
+          ctx,
+          error == Error::Ok,
+          InvalidArgument,
+          out,
+          "Failed to resize output tensor.");
+      ET_SWITCH_REALB_TYPES(tensor_type, ctx, "div.out", CTYPE, [&]() {
+        ET_SWITCH_REALB_TYPES(scalar_type, ctx, "div.out", CTYPE_SCALAR, [&]() {
+          CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();
+          CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);
+
+          using Vec = executorch::vec::Vectorized<CTYPE>;
+          if (a.numel() == 1) {
+            executorch::vec::map<CTYPE>(
+                [scalar_casted](Vec x) { return Vec(scalar_casted) / x; },
+                out.mutable_data_ptr<CTYPE>(),
+                tensor->const_data_ptr<CTYPE>(),
+                out.numel());
+          } else {
+            executorch::vec::map<CTYPE>(
+                [scalar_casted](Vec x) { return x / Vec(scalar_casted); },
+                out.mutable_data_ptr<CTYPE>(),
+                tensor->const_data_ptr<CTYPE>(),
+                out.numel());
+          }
+        });
+      });
+      return out;
+    }
+  }
+
+  auto selected_optimized_path = select_optimized_path(a, b, out);
+  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
     // Resize for dynamic shape
     auto error = resize_tensor(out, a.sizes());
     ET_KERNEL_CHECK_MSG(
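
Aside: a minimal standalone sketch of the scalar fast path added above, with plain std::vector and a loop standing in for ExecuTorch's tensor types and executorch::vec::map (all names here are illustrative, not the kernel's API). It shows why the branch on a.numel() == 1 matters: division is not commutative, so the lambda must put the captured scalar on the correct side.

#include <cstddef>
#include <vector>

// Illustrative stand-in for executorch::vec::map: apply op elementwise.
template <typename T, typename Op>
std::vector<T> map1d(Op op, const std::vector<T>& in) {
  std::vector<T> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[i] = op(in[i]);
  }
  return out;
}

// If a is the single-element operand, out[i] = a0 / b[i];
// otherwise out[i] = a[i] / b0.
std::vector<float> div_with_scalar(
    const std::vector<float>& a, const std::vector<float>& b) {
  if (a.size() == 1) {
    const float a0 = a[0];
    return map1d([a0](float x) { return a0 / x; }, b);
  }
  const float b0 = b[0];
  return map1d([b0](float x) { return x / b0; }, a);
}
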
@@ -67,6 +118,49 @@ Tensor& opt_div_out(
           b.const_data_ptr<CTYPE>(),
           out.numel());
     });
+  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
+    const Tensor* lhs;
+    const Tensor* rhs;
+    if (selected_optimized_path ==
+        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
+      lhs = &b;
+      rhs = &a;
+    } else {
+      // Catch failure to update logic when adding new broadcasting possibility.
+      ET_DCHECK(
+          selected_optimized_path ==
+          ElementwiseOptimizedPath::kBroadcast2dBy1d);
+      lhs = &a;
+      rhs = &b;
+    }
+    auto error = resize_tensor(out, lhs->sizes());
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        error == Error::Ok,
+        InvalidArgument,
+        out,
+        "Failed to resize output tensor.");
+    ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE, [&]() {
+      using Vec = executorch::vec::Vectorized<CTYPE>;
+      if (selected_optimized_path ==
+          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
+        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
+            [](Vec x, Vec y) { return y / x; },
+            out.mutable_data_ptr<CTYPE>(),
+            lhs->const_data_ptr<CTYPE>(),
+            rhs->const_data_ptr<CTYPE>(),
+            lhs->sizes()[lhs->dim() - 2],
+            lhs->sizes()[lhs->dim() - 1]);
+      } else {
+        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
+            [](Vec x, Vec y) { return x / y; },
+            out.mutable_data_ptr<CTYPE>(),
+            lhs->const_data_ptr<CTYPE>(),
+            rhs->const_data_ptr<CTYPE>(),
+            lhs->sizes()[lhs->dim() - 2],
+            lhs->sizes()[lhs->dim() - 1]);
+      }
+    });
   } else {
     ScalarType common_type = get_compute_type(a_type, b_type);
     ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
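
Aside: a minimal scalar sketch of what broadcasting_map_2d_by_1d computes in the branch above, assuming the lhs is contiguous with shape [m, n] and the rhs has shape [n] (plain loops stand in for the vectorized ExecuTorch implementation; names here are illustrative).

#include <cstddef>

// Apply op to each [n]-length row of lhs against the single [n]-length rhs.
// m and n correspond to lhs->sizes()[dim - 2] and lhs->sizes()[dim - 1].
template <typename T, typename Op>
void broadcast_2d_by_1d(
    Op op, T* out, const T* lhs, const T* rhs, std::size_t m, std::size_t n) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      out[i * n + j] = op(lhs[i * n + j], rhs[j]);
    }
  }
}

For the reversed-arguments case the kernel keeps the 2-d operand as lhs (here, b) and flips the lambda to y / x, so the result is still a / b elementwise.
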
@@ -77,25 +171,23 @@ Tensor& opt_div_out(
         InvalidArgument,
         out);

-    ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.out", CTYPE_A, [&]() {
-      ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out", CTYPE_B, [&]() {
-        ET_SWITCH_REAL_TYPES_AND(
-            Bool, common_type, ctx, "div.out", CTYPE_IN, [&]() {
-              ET_SWITCH_REAL_TYPES_AND(
-                  Bool, out_type, ctx, "div.out", CTYPE_OUT, [&]() {
-                    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                        [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                          CTYPE_IN value = a_casted / b_casted;
-
-                          return static_cast<CTYPE_OUT>(value);
-                        },
-                        a,
-                        b,
-                        out);
-                  });
-            });
+    ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
+      ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
+        ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
+          ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
+            apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+                [](const CTYPE_A val_a, const CTYPE_B val_b) {
+                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+                  CTYPE_IN value = a_casted / b_casted;
+
+                  return static_cast<CTYPE_OUT>(value);
+                },
+                a,
+                b,
+                out);
+          });
+        });
       });
     });
   }
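
Aside: a fixed-type sketch of the fallback path's promotion logic, with int32 and float inputs and a double output standing in for the runtime-dispatched CTYPE_A, CTYPE_B, CTYPE_IN, and CTYPE_OUT (illustrative only; the kernel selects the actual types through the nested ET_SWITCH_REALB_TYPES dispatches).

#include <cstddef>
#include <cstdint>

// Cast both inputs to the promoted compute type, divide there, then cast
// the result to the output type, mirroring the lambda in the fallback path.
void div_mixed_types(
    const std::int32_t* a, const float* b, double* out, std::size_t n) {
  using CTYPE_IN = double; // stand-in for the promoted common_type
  for (std::size_t i = 0; i < n; ++i) {
    const CTYPE_IN a_casted = static_cast<CTYPE_IN>(a[i]);
    const CTYPE_IN b_casted = static_cast<CTYPE_IN>(b[i]);
    out[i] = static_cast<double>(a_casted / b_casted);
  }
}
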