@@ -68,114 +68,6 @@ template <
68
68
struct MulInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
69
69
: public ReportCanCastBug {};
70
70
71
- Tensor& handle_last_dim_broadcast (
72
- KernelRuntimeContext& ctx,
73
- const Tensor& a,
74
- const Tensor& b,
75
- Tensor& out,
76
- const ElementwiseOptimizedPath selected_optimized_path) {
77
- ScalarType out_type = out.scalar_type ();
78
- const Tensor* lhs;
79
- const Tensor* rhs;
80
- if (selected_optimized_path ==
81
- ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ) {
82
- lhs = &b;
83
- rhs = &a;
84
- } else {
85
- lhs = &a;
86
- rhs = &b;
87
- }
88
- auto error = resize_tensor (out, lhs->sizes ());
89
- ET_KERNEL_CHECK_MSG (
90
- ctx,
91
- error == Error::Ok,
92
- InvalidArgument,
93
- out,
94
- " Failed to resize output tensor." );
95
- const size_t outer_size = getLeadingDims (out, out.dim () - 1 );
96
- const auto broadcast_size = out.size (out.dim () - 1 );
97
- ET_SWITCH_REALB_TYPES (out_type, ctx, " mul.out" , CTYPE, [&]() {
98
- using Vec = executorch::vec::Vectorized<CTYPE>;
99
- executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
100
- [](Vec x, Vec y) { return x * y; },
101
- out.mutable_data_ptr <CTYPE>(),
102
- lhs->const_data_ptr <CTYPE>(),
103
- rhs->const_data_ptr <CTYPE>(),
104
- outer_size,
105
- broadcast_size);
106
- });
107
- return out;
108
- }
109
-
110
- Tensor& handle_broadcast_mul (
111
- KernelRuntimeContext& ctx,
112
- const Tensor& a,
113
- const Tensor& b,
114
- Tensor& out,
115
- const ElementwiseOptimizedPath selected_optimized_path) {
116
- if ((selected_optimized_path ==
117
- ElementwiseOptimizedPath::kBroadcastLastDim ) ||
118
- (selected_optimized_path ==
119
- ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments )) {
120
- return handle_last_dim_broadcast (ctx, a, b, out, selected_optimized_path);
121
- }
122
-
123
- ScalarType out_type = out.scalar_type ();
124
- const Tensor* lhs;
125
- const Tensor* rhs;
126
- if ((selected_optimized_path ==
127
- ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ) ||
128
- (selected_optimized_path ==
129
- ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments )) {
130
- lhs = &b;
131
- rhs = &a;
132
- } else {
133
- // Catch failure to update logic when adding new broadcasting possibility.
134
- ET_DCHECK (
135
- (selected_optimized_path ==
136
- ElementwiseOptimizedPath::kBroadcast2dBy1d ) ||
137
- (selected_optimized_path ==
138
- ElementwiseOptimizedPath::kBroadcastNdByNd ));
139
- lhs = &a;
140
- rhs = &b;
141
- }
142
- auto error = resize_tensor (out, lhs->sizes ());
143
- ET_KERNEL_CHECK_MSG (
144
- ctx,
145
- error == Error::Ok,
146
- InvalidArgument,
147
- out,
148
- " Failed to resize output tensor." );
149
- int64_t outer_size = 1 ;
150
- int64_t broadcast_size;
151
- int64_t inner_size;
152
- if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd ) ||
153
- (selected_optimized_path ==
154
- ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments )) {
155
- int32_t broadcast_dim = internal::get_broadcast_dim (*lhs, *rhs);
156
- int32_t broadcast_dim_lhs = lhs->dim () + broadcast_dim;
157
- auto normalized_tensor_size_lhs =
158
- get_normalized_tensor_size (*lhs, broadcast_dim_lhs);
159
- outer_size = normalized_tensor_size_lhs[0 ];
160
- broadcast_size = normalized_tensor_size_lhs[1 ];
161
- inner_size = normalized_tensor_size_lhs[2 ];
162
- } else {
163
- broadcast_size = lhs->sizes ()[lhs->dim () - 2 ];
164
- inner_size = lhs->sizes ()[lhs->dim () - 1 ];
165
- }
166
- ET_SWITCH_REALB_TYPES (out_type, ctx, " mul.out" , CTYPE, [&]() {
167
- using Vec = executorch::vec::Vectorized<CTYPE>;
168
- executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
169
- [](Vec x, Vec y) { return x * y; },
170
- out.mutable_data_ptr <CTYPE>(),
171
- lhs->const_data_ptr <CTYPE>(),
172
- rhs->const_data_ptr <CTYPE>(),
173
- outer_size,
174
- broadcast_size,
175
- inner_size);
176
- });
177
- return out;
178
- }
179
71
} // namespace
180
72
181
73
Tensor& opt_mul_out (
@@ -238,7 +130,9 @@ Tensor& opt_mul_out(
238
130
out.numel ());
239
131
});
240
132
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone ) {
241
- return handle_broadcast_mul (ctx, a, b, out, selected_optimized_path);
133
+ auto mul_lambda = [](auto x, auto y) { return x * y; };
134
+ return torch::executor::handle_broadcast_elementwise (
135
+ ctx, mul_lambda, a, b, out, selected_optimized_path);
242
136
} else {
243
137
ScalarType common_type =
244
138
promoteTypes (a_type, b_type, /* half_to_float*/ true );
0 commit comments