
Commit addb403

manuelcandales authored and facebook-github-bot committed
Rewrite bitwise op pattern. Binary ops: bitwise_and, bitwise_or, bitwise_xor (#6014)
Summary:
Pull Request resolved: #6014
- bitwise_or: 417 K -> 12 K
- bitwise_xor: 416 K -> 12 K
- bitwise_and: 406 K -> 12 K

ghstack-source-id: 246985126
exported-using-ghexport

Reviewed By: swolchok

Differential Revision: D63933330

fbshipit-source-id: 5227617464642e99c20d1fea275d8be2280d020a
1 parent a2cf409 commit addb403

File tree

8 files changed: +167 additions, -393 deletions

kernels/portable/cpu/op_bitwise_and.cpp

Lines changed: 6 additions & 111 deletions
@@ -6,135 +6,30 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-// patternlint-disable-next-line executorch-cpp-nostdinc
-#include <functional>
-
 #include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-
 Tensor& bitwise_and_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_INT_TYPES_AND(
-      Bool, a_type, ctx, "bitwise_and.Tensor_out", CTYPE_A, [&]() {
-        ET_SWITCH_INT_TYPES_AND(
-            Bool, b_type, ctx, "bitwise_and.Tensor_out", CTYPE_B, [&]() {
-              using CTYPE_IN = typename torch::executor::
-                  promote_types<CTYPE_A, CTYPE_B>::type;
-              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-              ET_SWITCH_REAL_TYPES_AND(
-                  Bool,
-                  out_type,
-                  ctx,
-                  "bitwise_and.Tensor_out",
-                  CTYPE_OUT,
-                  [&]() {
-                    internal::BitwiseOpInner<
-                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
-                        std::bit_and,
-                        CTYPE_A,
-                        CTYPE_B,
-                        CTYPE_IN,
-                        CTYPE_OUT>::run(a, b, out);
-                  });
-            });
-      });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "bitwise_and.Tensor_out";
+  return internal::bitwise_tensor_out<op_name>(ctx, a, b, out);
 }
 
 Tensor& bitwise_and_Scalar_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_INT_TYPES_AND(
-      Bool, a_type, ctx, "bitwise_and.Scalar_out", CTYPE_A, [&]() {
-        ET_SWITCH_SCALAR_OBJ_INTB_TYPES(
-            b_type, ctx, "bitwise_and.Scalar_out", CTYPE_B, [&]() {
-              CTYPE_B val_b = 0;
-              utils::extract_scalar(b, &val_b);
-              ET_SWITCH_INT_TYPES_AND(
-                  Bool,
-                  common_type,
-                  ctx,
-                  "bitwise_and.Scalar_out",
-                  CTYPE_IN,
-                  [&]() {
-                    ET_SWITCH_REAL_TYPES_AND(
-                        Bool,
-                        out_type,
-                        ctx,
-                        "bitwise_and.Scalar_out",
-                        CTYPE_OUT,
-                        [&]() {
-                          apply_unary_map_fn(
-                              [val_b](const CTYPE_A val_a) {
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value = std::bit_and<CTYPE_IN>()(
-                                    a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a.const_data_ptr<CTYPE_A>(),
-                              out.mutable_data_ptr<CTYPE_OUT>(),
-                              out.numel());
-                        });
-                  });
-            });
-      });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "bitwise_and.Scalar_out";
+  return internal::bitwise_scalar_out<op_name>(ctx, a, b, out);
 }
 
 } // namespace native
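The change in each op file is the same: the per-operator dtype-dispatch boilerplate (nested ET_SWITCH macros instantiated once per op) is replaced by a call into a shared helper in pattern/bitwise_op.h, parameterized only by the operator name. That header is among the 8 changed files but is not shown in this excerpt, and the figures in the summary (e.g. 417 K -> 12 K) are presumably per-operator code size. The standalone sketch below illustrates the code-sharing idea only; it is not ExecuTorch code, and every name in it other than op_name, std::bit_and, and std::bit_or is hypothetical.

// Standalone C++17 sketch of the pattern: one shared element-wise loop,
// keyed by an op-name non-type template parameter and a functor, replaces
// per-operator copies of the same dispatch and loop code.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>

// One instantiation per element type is shared by all bitwise ops; each op
// contributes only its name string and functor, so its own code stays tiny.
template <const char* op_name, template <typename> class Op, typename T>
void apply_bitwise(const T* a, const T* b, T* out, std::size_t n) {
  std::printf("running %s\n", op_name);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = Op<T>()(a[i], b[i]);
  }
}

int main() {
  const int32_t a[] = {0b1100, 0b1010};
  const int32_t b[] = {0b1010, 0b0110};
  int32_t out[2] = {};

  // Mirrors the shape of the rewritten call sites: a local static constexpr
  // char array names the op and is passed as a template argument.
  static constexpr const char and_name[] = "bitwise_and.Tensor_out";
  apply_bitwise<and_name, std::bit_and>(a, b, out, 2);  // out = {0b1000, 0b0010}

  static constexpr const char or_name[] = "bitwise_or.Tensor_out";
  apply_bitwise<or_name, std::bit_or>(a, b, out, 2);  // out = {0b1110, 0b1110}
  return 0;
}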

kernels/portable/cpu/op_bitwise_or.cpp

Lines changed: 6 additions & 111 deletions
@@ -6,135 +6,30 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-// patternlint-disable-next-line executorch-cpp-nostdinc
-#include <functional>
-
 #include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-
 Tensor& bitwise_or_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_INT_TYPES_AND(
-      Bool, a_type, ctx, "bitwise_or.Tensor_out", CTYPE_A, [&]() {
-        ET_SWITCH_INT_TYPES_AND(
-            Bool, b_type, ctx, "bitwise_or.Tensor_out", CTYPE_B, [&]() {
-              using CTYPE_IN = typename torch::executor::
-                  promote_types<CTYPE_A, CTYPE_B>::type;
-              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-              ET_SWITCH_REAL_TYPES_AND(
-                  Bool,
-                  out_type,
-                  ctx,
-                  "bitwise_or.Tensor_out",
-                  CTYPE_OUT,
-                  [&]() {
-                    internal::BitwiseOpInner<
-                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
-                        std::bit_or,
-                        CTYPE_A,
-                        CTYPE_B,
-                        CTYPE_IN,
-                        CTYPE_OUT>::run(a, b, out);
-                  });
-            });
-      });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "bitwise_or.Tensor_out";
+  return internal::bitwise_tensor_out<op_name>(ctx, a, b, out);
 }
 
 Tensor& bitwise_or_Scalar_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_INT_TYPES_AND(
-      Bool, a_type, ctx, "bitwise_or.Scalar_out", CTYPE_A, [&]() {
-        ET_SWITCH_SCALAR_OBJ_INTB_TYPES(
-            b_type, ctx, "bitwise_or.Scalar_out", CTYPE_B, [&]() {
-              CTYPE_B val_b = 0;
-              utils::extract_scalar(b, &val_b);
-              ET_SWITCH_INT_TYPES_AND(
-                  Bool,
-                  common_type,
-                  ctx,
-                  "bitwise_or.Scalar_out",
-                  CTYPE_IN,
-                  [&]() {
-                    ET_SWITCH_REAL_TYPES_AND(
-                        Bool,
-                        out_type,
-                        ctx,
-                        "bitwise_or.Scalar_out",
-                        CTYPE_OUT,
-                        [&]() {
-                          apply_unary_map_fn(
-                              [val_b](const CTYPE_A val_a) {
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value =
-                                    std::bit_or<CTYPE_IN>()(a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a.const_data_ptr<CTYPE_A>(),
-                              out.mutable_data_ptr<CTYPE_OUT>(),
-                              out.numel());
-                        });
-                  });
-            });
-      });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "bitwise_or.Scalar_out";
+  return internal::bitwise_scalar_out<op_name>(ctx, a, b, out);
 }
 
 } // namespace native
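The commit title also covers bitwise_xor, but op_bitwise_xor.cpp is one of the changed files not shown in this excerpt. Based on the two diffs above, its rewritten entry points presumably take the same form; the fragment below is a presumption about that file's contents, not part of the displayed diff.

// Presumed shape of the rewritten op_bitwise_xor.cpp entry points, mirroring
// the bitwise_and/bitwise_or call sites above (not shown in this excerpt).
Tensor& bitwise_xor_Tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "bitwise_xor.Tensor_out";
  return internal::bitwise_tensor_out<op_name>(ctx, a, b, out);
}

Tensor& bitwise_xor_Scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "bitwise_xor.Scalar_out";
  return internal::bitwise_scalar_out<op_name>(ctx, a, b, out);
}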
