@@ -120,48 +120,37 @@ Tensor& opt_div_out(
120
120
out.numel ());
121
121
});
122
122
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone ) {
123
- const Tensor* lhs;
124
- const Tensor* rhs;
123
+ // Reason for using alpha is becasuse handle_broadcast_elementwise
124
+ // is used for add and sub as well:
125
+ static constexpr const char op_name[] = " mul.out" ;
125
126
if (selected_optimized_path ==
126
- ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ) {
127
- lhs = &b;
128
- rhs = &a;
127
+ ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
128
+ selected_optimized_path ==
129
+ ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
130
+ selected_optimized_path ==
131
+ ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments ) {
132
+ // This behavior is a bit confusing.
133
+ // Reason we swap out args here is because handle_broadcast_elementwise
134
+ // handles this selected_optimized_path option a bit differently.
135
+ // This should really be resoled in handle_broadcast_elementwise.
136
+ // However, the current blocker is that handle_broadcast_elementwise tries
137
+ // to be agnostic of op. This should be fixed, likely by moving lambda
138
+ // creation to handle_broadcast_elementwise and it be aware of which op is
139
+ // being executed.
140
+ auto div_lambda = [](auto x, auto y, auto alpha) {
141
+ [[maybe_unused]] alpha;
142
+ return y / x;
143
+ };
144
+ return torch::executor::handle_broadcast_elementwise<op_name>(
145
+ ctx, div_lambda, a, b, out, selected_optimized_path);
129
146
} else {
130
- // Catch failure to update logic when subing new broadcasting possibility.
131
- ET_DCHECK (
132
- selected_optimized_path ==
133
- ElementwiseOptimizedPath:: kBroadcast2dBy1d ) ;
134
- lhs = &a;
135
- rhs = &b ;
147
+ auto div_lambda = []( auto x, auto y, auto alpha) {
148
+ [[maybe_unused]] alpha;
149
+ return x / y;
150
+ } ;
151
+ return torch::executor::handle_broadcast_elementwise<op_name>(
152
+ ctx, div_lambda, a, b, out, selected_optimized_path) ;
136
153
}
137
- auto error = resize_tensor (out, lhs->sizes ());
138
- ET_KERNEL_CHECK_MSG (
139
- ctx,
140
- error == Error::Ok,
141
- InvalidArgument,
142
- out,
143
- " Failed to resize output tensor." );
144
- ET_SWITCH_REALB_TYPES (out_type, ctx, " sub.out" , CTYPE, [&]() {
145
- using Vec = executorch::vec::Vectorized<CTYPE>;
146
- if (selected_optimized_path ==
147
- ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ) {
148
- executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
149
- [](Vec x, Vec y) { return y / x; },
150
- out.mutable_data_ptr <CTYPE>(),
151
- lhs->const_data_ptr <CTYPE>(),
152
- rhs->const_data_ptr <CTYPE>(),
153
- lhs->sizes ()[lhs->dim () - 2 ],
154
- lhs->sizes ()[lhs->dim () - 1 ]);
155
- } else {
156
- executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
157
- [](Vec x, Vec y) { return x / y; },
158
- out.mutable_data_ptr <CTYPE>(),
159
- lhs->const_data_ptr <CTYPE>(),
160
- rhs->const_data_ptr <CTYPE>(),
161
- lhs->sizes ()[lhs->dim () - 2 ],
162
- lhs->sizes ()[lhs->dim () - 1 ]);
163
- }
164
- });
165
154
} else {
166
155
ScalarType common_type = get_compute_type (a_type, b_type);
167
156
ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
0 commit comments