|
9 | 9 | #pragma once
|
10 | 10 |
|
11 | 11 | #include <c10/util/irange.h>
|
| 12 | +#include <executorch/kernels/portable/cpu/selective_build.h> |
12 | 13 | #include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
|
13 | 14 | #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
|
14 | 15 | #include <executorch/kernels/portable/cpu/util/dtype_util.h>
|
@@ -345,20 +346,22 @@ inline void apply_elementwise_fn(
|
345 | 346 | }
|
346 | 347 |
|
347 | 348 | constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
|
348 |
| - const bool all_inputs_compute_dtype = |
349 |
| - ((inputs.first->scalar_type() == compute_type) && ...); |
350 |
| - |
351 |
| - constexpr ScalarType out_specialized_scalar_type = |
352 |
| - specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes); |
353 |
| - if (all_inputs_compute_dtype && |
354 |
| - out.scalar_type() == out_specialized_scalar_type) { |
355 |
| - using CTYPE_OUT = |
356 |
| - typename ScalarTypeToCppType<out_specialized_scalar_type>::type; |
357 |
| - dtype_specialized_elementwise_fn_impl< |
358 |
| - CTYPE_COMPUTE, |
359 |
| - CTYPE_OUT, |
360 |
| - support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); |
361 |
| - return; |
| 349 | + if constexpr (should_include_kernel_dtype(op_name, compute_type)) { |
| 350 | + const bool all_inputs_compute_dtype = |
| 351 | + ((inputs.first->scalar_type() == compute_type) && ...); |
| 352 | + |
| 353 | + constexpr ScalarType out_specialized_scalar_type = |
| 354 | + specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes); |
| 355 | + if (all_inputs_compute_dtype && |
| 356 | + out.scalar_type() == out_specialized_scalar_type) { |
| 357 | + using CTYPE_OUT = |
| 358 | + typename ScalarTypeToCppType<out_specialized_scalar_type>::type; |
| 359 | + dtype_specialized_elementwise_fn_impl< |
| 360 | + CTYPE_COMPUTE, |
| 361 | + CTYPE_OUT, |
| 362 | + support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); |
| 363 | + return; |
| 364 | + } |
362 | 365 | }
|
363 | 366 |
|
364 | 367 | apply_elementwise_fn_generic_impl<
|
|
0 commit comments