Skip to content

Commit 5937f4a

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Rewrite logical op pattern. Binary ops: logical_and, logical_or, logical_xor (#6015)
Summary: Pull Request resolved: #6015 Aggregate of binary logical ops: - 680 K -> 5 K ghstack-source-id: 246985134 exported-using-ghexport Reviewed By: swolchok Differential Revision: D63989999 fbshipit-source-id: e419a116b4fe7126ab2e499db64cb5e661cfd90e
1 parent addb403 commit 5937f4a

File tree

8 files changed

+87
-87
lines changed

8 files changed

+87
-87
lines changed

kernels/portable/cpu/op_logical_and.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/pattern/pattern.h>
9+
#include <executorch/kernels/portable/cpu/pattern/logical_op.h>
1010
#include <executorch/runtime/kernel/kernel_includes.h>
1111
#include <cmath>
1212

@@ -26,8 +26,9 @@ Tensor& logical_and_out(
2626
const Tensor& a,
2727
const Tensor& b,
2828
Tensor& out) {
29-
return internal::binary_ufunc_realb_realb_to_realb_logical(
30-
logical_and, ctx, a, b, out);
29+
// @lint-ignore CLANGTIDY facebook-hte-CArray
30+
static constexpr const char op_name[] = "logical_and.out";
31+
return internal::logical_tensor_out<op_name>(logical_and, ctx, a, b, out);
3132
}
3233

3334
} // namespace native

kernels/portable/cpu/op_logical_or.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/pattern/pattern.h>
9+
#include <executorch/kernels/portable/cpu/pattern/logical_op.h>
1010
#include <executorch/runtime/kernel/kernel_includes.h>
1111
#include <cmath>
1212

@@ -26,8 +26,9 @@ Tensor& logical_or_out(
2626
const Tensor& a,
2727
const Tensor& b,
2828
Tensor& out) {
29-
return internal::binary_ufunc_realb_realb_to_realb_logical(
30-
logical_or, ctx, a, b, out);
29+
// @lint-ignore CLANGTIDY facebook-hte-CArray
30+
static constexpr const char op_name[] = "logical_or.out";
31+
return internal::logical_tensor_out<op_name>(logical_or, ctx, a, b, out);
3132
}
3233

3334
} // namespace native

kernels/portable/cpu/op_logical_xor.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/pattern/pattern.h>
9+
#include <executorch/kernels/portable/cpu/pattern/logical_op.h>
1010
#include <executorch/runtime/kernel/kernel_includes.h>
1111
#include <cmath>
1212

@@ -26,8 +26,9 @@ Tensor& logical_xor_out(
2626
const Tensor& a,
2727
const Tensor& b,
2828
Tensor& out) {
29-
return internal::binary_ufunc_realb_realb_to_realb_logical(
30-
logical_xor, ctx, a, b, out);
29+
// @lint-ignore CLANGTIDY facebook-hte-CArray
30+
static constexpr const char op_name[] = "logical_xor.out";
31+
return internal::logical_tensor_out<op_name>(logical_xor, ctx, a, b, out);
3132
}
3233

3334
} // namespace native

kernels/portable/cpu/pattern/binary_ufunc_realb_realb_to_realb_logical.cpp

Lines changed: 0 additions & 61 deletions
This file was deleted.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
12+
namespace torch {
13+
namespace executor {
14+
namespace native {
15+
namespace internal {
16+
17+
/**
18+
* Implements an op pattern for ops that take two broadcastable input tensors
19+
* and performs an element-wise binary logical operation `fn`.
20+
*/
21+
template <const char* op_name>
22+
Tensor& logical_tensor_out(
23+
bool (*fn)(bool, bool),
24+
KernelRuntimeContext& ctx,
25+
const Tensor& a,
26+
const Tensor& b,
27+
Tensor& out) {
28+
ET_KERNEL_CHECK(
29+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
30+
31+
ET_KERNEL_CHECK(
32+
ctx,
33+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
34+
InvalidArgument,
35+
out);
36+
37+
utils::apply_bitensor_elementwise_fn<bool, op_name>(
38+
fn,
39+
ctx,
40+
a,
41+
utils::SupportedTensorDtypes::REALHBBF16,
42+
b,
43+
utils::SupportedTensorDtypes::REALHBBF16,
44+
out,
45+
utils::SupportedTensorDtypes::REALHBBF16);
46+
47+
return out;
48+
}
49+
50+
} // namespace internal
51+
} // namespace native
52+
} // namespace executor
53+
} // namespace torch

kernels/portable/cpu/pattern/pattern.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,6 @@ Tensor& unary_ufunc_realhbbf16_to_floathbf16(
8989
const Tensor& in,
9090
Tensor& out);
9191

92-
/**
93-
* Implements an op pattern for ops that take two broadcastable input tensors
94-
* of any realb dtype, no additional arguments, performs an element-wise binary
95-
* logical operation, and outputs a realb tensor. The function fn specifies the
96-
* binary logical operation which is applied to the input tensors element-wise.
97-
*/
98-
Tensor& binary_ufunc_realb_realb_to_realb_logical(
99-
bool (*fn)(bool, bool),
100-
KernelRuntimeContext& ctx,
101-
const Tensor& a,
102-
const Tensor& b,
103-
Tensor& out);
104-
10592
} // namespace internal
10693
} // namespace native
10794
} // namespace executor

kernels/portable/cpu/pattern/targets.bzl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def define_common_targets():
1515
"//executorch/kernels/portable/cpu/pattern:pattern",
1616
"//executorch/kernels/portable/cpu/pattern:bitwise_op",
1717
"//executorch/kernels/portable/cpu/pattern:comparison_op",
18+
"//executorch/kernels/portable/cpu/pattern:logical_op"
1819
],
1920
visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
2021
)
@@ -37,13 +38,21 @@ def define_common_targets():
3738
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
3839
)
3940

41+
runtime.cxx_library(
42+
name = "logical_op",
43+
exported_headers = [
44+
"logical_op.h",
45+
],
46+
compiler_flags = [],
47+
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
48+
)
49+
4050
runtime.cxx_library(
4151
name = "pattern",
4252
srcs = [
4353
"unary_ufunc_realhb_to_bool.cpp",
4454
"unary_ufunc_realhbbf16_to_floathbf16.cpp",
4555
"unary_ufunc_realh.cpp",
46-
"binary_ufunc_realb_realb_to_realb_logical.cpp",
4756
],
4857
exported_headers = [
4958
"pattern.h",

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,10 @@ ATEN_OPS = (
699699
op_target(
700700
name = "op_logical_and",
701701
deps = [
702-
"//executorch/kernels/portable/cpu/pattern:pattern",
702+
":scalar_utils",
703+
"//executorch/kernels/portable/cpu/pattern:logical_op",
704+
"//executorch/kernels/portable/cpu/util:broadcast_util",
705+
"//executorch/kernels/portable/cpu/util:elementwise_util",
703706
],
704707
),
705708
op_target(
@@ -712,13 +715,19 @@ ATEN_OPS = (
712715
op_target(
713716
name = "op_logical_or",
714717
deps = [
715-
"//executorch/kernels/portable/cpu/pattern:pattern",
718+
":scalar_utils",
719+
"//executorch/kernels/portable/cpu/pattern:logical_op",
720+
"//executorch/kernels/portable/cpu/util:broadcast_util",
721+
"//executorch/kernels/portable/cpu/util:elementwise_util",
716722
],
717723
),
718724
op_target(
719725
name = "op_logical_xor",
720726
deps = [
721-
"//executorch/kernels/portable/cpu/pattern:pattern",
727+
":scalar_utils",
728+
"//executorch/kernels/portable/cpu/pattern:logical_op",
729+
"//executorch/kernels/portable/cpu/util:broadcast_util",
730+
"//executorch/kernels/portable/cpu/util:elementwise_util",
722731
],
723732
),
724733
op_target(

0 commit comments

Comments
 (0)