@@ -191,7 +191,7 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
191
191
return normalized_tensor_size;
192
192
}
193
193
194
- template <typename Op>
194
+ template <typename CTYPE, typename Op>
195
195
Tensor& handle_last_dim_broadcast_elementwise (
196
196
KernelRuntimeContext& ctx,
197
197
const Op& vec_fun,
@@ -219,19 +219,17 @@ Tensor& handle_last_dim_broadcast_elementwise(
219
219
" Failed to resize output tensor." );
220
220
const size_t outer_size = getLeadingDims (out, out.dim () - 1 );
221
221
const auto broadcast_size = out.size (out.dim () - 1 );
222
- ET_SWITCH_REALB_TYPES (out_type, ctx, " mul.out" , CTYPE, [&]() {
223
- executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
224
- vec_fun,
225
- out.mutable_data_ptr <CTYPE>(),
226
- lhs->const_data_ptr <CTYPE>(),
227
- rhs->const_data_ptr <CTYPE>(),
228
- outer_size,
229
- broadcast_size);
230
- });
222
+ executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
223
+ vec_fun,
224
+ out.mutable_data_ptr <CTYPE>(),
225
+ lhs->const_data_ptr <CTYPE>(),
226
+ rhs->const_data_ptr <CTYPE>(),
227
+ outer_size,
228
+ broadcast_size);
231
229
return out;
232
230
}
233
231
234
- template <typename Op>
232
+ template <typename CTYPE, typename Op>
235
233
Tensor& handle_broadcast_elementwise (
236
234
KernelRuntimeContext& ctx,
237
235
const Op& vec_fun,
@@ -243,11 +241,10 @@ Tensor& handle_broadcast_elementwise(
243
241
ElementwiseOptimizedPath::kBroadcastLastDim ) ||
244
242
(selected_optimized_path ==
245
243
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments )) {
246
- return handle_last_dim_broadcast_elementwise (
244
+ return handle_last_dim_broadcast_elementwise<CTYPE> (
247
245
ctx, vec_fun, a, b, out, selected_optimized_path);
248
246
}
249
247
250
- ScalarType out_type = out.scalar_type ();
251
248
const Tensor* lhs;
252
249
const Tensor* rhs;
253
250
if ((selected_optimized_path ==
@@ -290,16 +287,14 @@ Tensor& handle_broadcast_elementwise(
290
287
broadcast_size = lhs->sizes ()[lhs->dim () - 2 ];
291
288
inner_size = lhs->sizes ()[lhs->dim () - 1 ];
292
289
}
293
- ET_SWITCH_REALB_TYPES (out_type, ctx, " mul.out" , CTYPE, [&]() {
294
- executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
295
- vec_fun,
296
- out.mutable_data_ptr <CTYPE>(),
297
- lhs->const_data_ptr <CTYPE>(),
298
- rhs->const_data_ptr <CTYPE>(),
299
- outer_size,
300
- broadcast_size,
301
- inner_size);
302
- });
290
+ executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
291
+ vec_fun,
292
+ out.mutable_data_ptr <CTYPE>(),
293
+ lhs->const_data_ptr <CTYPE>(),
294
+ rhs->const_data_ptr <CTYPE>(),
295
+ outer_size,
296
+ broadcast_size,
297
+ inner_size);
303
298
return out;
304
299
}
305
300
} // namespace executor
0 commit comments