Skip to content

Commit 8bc99d3

Browse files
authored
Respect selective build for dtype_specialized_elementwise_fn_impl in elementwise_util (#11975)
This fancy fast path I added didn't respect selective build. Now it should.
1 parent 52b008c commit 8bc99d3

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <c10/util/irange.h>
12+
#include <executorch/kernels/portable/cpu/selective_build.h>
1213
#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
1314
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1415
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
@@ -345,20 +346,22 @@ inline void apply_elementwise_fn(
345346
}
346347

347348
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
348-
const bool all_inputs_compute_dtype =
349-
((inputs.first->scalar_type() == compute_type) && ...);
350-
351-
constexpr ScalarType out_specialized_scalar_type =
352-
specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
353-
if (all_inputs_compute_dtype &&
354-
out.scalar_type() == out_specialized_scalar_type) {
355-
using CTYPE_OUT =
356-
typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
357-
dtype_specialized_elementwise_fn_impl<
358-
CTYPE_COMPUTE,
359-
CTYPE_OUT,
360-
support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...);
361-
return;
349+
if constexpr (should_include_kernel_dtype(op_name, compute_type)) {
350+
const bool all_inputs_compute_dtype =
351+
((inputs.first->scalar_type() == compute_type) && ...);
352+
353+
constexpr ScalarType out_specialized_scalar_type =
354+
specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
355+
if (all_inputs_compute_dtype &&
356+
out.scalar_type() == out_specialized_scalar_type) {
357+
using CTYPE_OUT =
358+
typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
359+
dtype_specialized_elementwise_fn_impl<
360+
CTYPE_COMPUTE,
361+
CTYPE_OUT,
362+
support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...);
363+
return;
364+
}
362365
}
363366

364367
apply_elementwise_fn_generic_impl<

kernels/portable/cpu/util/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ def define_common_targets():
115115
":vectorized_math",
116116
"//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
117117
"//executorch/runtime/kernel:kernel_runtime_context",
118+
"//executorch/kernels/portable/cpu:scalar_utils",
118119
"//executorch/extension/threadpool:threadpool",
119120
],
120121
deps = [
121-
"//executorch/kernels/portable/cpu:scalar_utils",
122122
"//executorch/runtime/kernel:kernel_includes",
123123
],
124124
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],

0 commit comments

Comments
 (0)