|
6 | 6 | * LICENSE file in the root directory of this source tree.
|
7 | 7 | */
|
8 | 8 |
|
9 |
| -// patternlint-disable-next-line executorch-cpp-nostdinc |
10 |
| -#include <functional> |
11 |
| - |
12 | 9 | #include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
|
13 |
| -#include <executorch/kernels/portable/cpu/scalar_utils.h> |
14 |
| -#include <executorch/kernels/portable/cpu/util/broadcast_util.h> |
15 |
| -#include <executorch/kernels/portable/cpu/util/functional_util.h> |
16 |
| -#include <executorch/runtime/kernel/kernel_includes.h> |
17 | 10 |
|
18 | 11 | namespace torch {
|
19 | 12 | namespace executor {
|
20 | 13 | namespace native {
|
21 | 14 |
|
22 |
| -using Tensor = exec_aten::Tensor; |
23 |
| - |
24 | 15 | Tensor& bitwise_or_Tensor_out(
|
25 | 16 | KernelRuntimeContext& ctx,
|
26 | 17 | const Tensor& a,
|
27 | 18 | const Tensor& b,
|
28 | 19 | Tensor& out) {
|
29 |
| - ET_KERNEL_CHECK( |
30 |
| - ctx, |
31 |
| - resize_to_broadcast_target_size(a, b, out) == Error::Ok, |
32 |
| - InvalidArgument, |
33 |
| - out); |
34 |
| - |
35 |
| - ET_KERNEL_CHECK( |
36 |
| - ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); |
37 |
| - |
38 |
| - ScalarType a_type = a.scalar_type(); |
39 |
| - ScalarType b_type = b.scalar_type(); |
40 |
| - ScalarType common_type = promoteTypes(a_type, b_type); |
41 |
| - ScalarType out_type = out.scalar_type(); |
42 |
| - |
43 |
| - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); |
44 |
| - |
45 |
| - ET_SWITCH_INT_TYPES_AND( |
46 |
| - Bool, a_type, ctx, "bitwise_or.Tensor_out", CTYPE_A, [&]() { |
47 |
| - ET_SWITCH_INT_TYPES_AND( |
48 |
| - Bool, b_type, ctx, "bitwise_or.Tensor_out", CTYPE_B, [&]() { |
49 |
| - using CTYPE_IN = typename torch::executor:: |
50 |
| - promote_types<CTYPE_A, CTYPE_B>::type; |
51 |
| - ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type); |
52 |
| - ET_SWITCH_REAL_TYPES_AND( |
53 |
| - Bool, |
54 |
| - out_type, |
55 |
| - ctx, |
56 |
| - "bitwise_or.Tensor_out", |
57 |
| - CTYPE_OUT, |
58 |
| - [&]() { |
59 |
| - internal::BitwiseOpInner< |
60 |
| - can_cast<CTYPE_IN, CTYPE_OUT>::value, |
61 |
| - std::bit_or, |
62 |
| - CTYPE_A, |
63 |
| - CTYPE_B, |
64 |
| - CTYPE_IN, |
65 |
| - CTYPE_OUT>::run(a, b, out); |
66 |
| - }); |
67 |
| - }); |
68 |
| - }); |
69 |
| - |
70 |
| - return out; |
| 20 | + // @lint-ignore CLANGTIDY facebook-hte-CArray |
| 21 | + static constexpr const char op_name[] = "bitwise_or.Tensor_out"; |
| 22 | + return internal::bitwise_tensor_out<op_name>(ctx, a, b, out); |
71 | 23 | }
|
72 | 24 |
|
73 | 25 | Tensor& bitwise_or_Scalar_out(
|
74 | 26 | KernelRuntimeContext& ctx,
|
75 | 27 | const Tensor& a,
|
76 | 28 | const Scalar& b,
|
77 | 29 | Tensor& out) {
|
78 |
| - (void)ctx; |
79 |
| - |
80 |
| - ET_KERNEL_CHECK( |
81 |
| - ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); |
82 |
| - |
83 |
| - // Resize for dynamic shape |
84 |
| - ET_KERNEL_CHECK_MSG( |
85 |
| - ctx, |
86 |
| - resize_tensor(out, a.sizes()) == Error::Ok, |
87 |
| - InvalidArgument, |
88 |
| - out, |
89 |
| - "Failed to resize output tensor."); |
90 |
| - |
91 |
| - ScalarType a_type = a.scalar_type(); |
92 |
| - ScalarType b_type = utils::get_scalar_dtype(b); |
93 |
| - ScalarType common_type = utils::promote_type_with_scalar(a_type, b); |
94 |
| - ScalarType out_type = out.scalar_type(); |
95 |
| - |
96 |
| - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); |
97 |
| - |
98 |
| - ET_SWITCH_INT_TYPES_AND( |
99 |
| - Bool, a_type, ctx, "bitwise_or.Scalar_out", CTYPE_A, [&]() { |
100 |
| - ET_SWITCH_SCALAR_OBJ_INTB_TYPES( |
101 |
| - b_type, ctx, "bitwise_or.Scalar_out", CTYPE_B, [&]() { |
102 |
| - CTYPE_B val_b = 0; |
103 |
| - utils::extract_scalar(b, &val_b); |
104 |
| - ET_SWITCH_INT_TYPES_AND( |
105 |
| - Bool, |
106 |
| - common_type, |
107 |
| - ctx, |
108 |
| - "bitwise_or.Scalar_out", |
109 |
| - CTYPE_IN, |
110 |
| - [&]() { |
111 |
| - ET_SWITCH_REAL_TYPES_AND( |
112 |
| - Bool, |
113 |
| - out_type, |
114 |
| - ctx, |
115 |
| - "bitwise_or.Scalar_out", |
116 |
| - CTYPE_OUT, |
117 |
| - [&]() { |
118 |
| - apply_unary_map_fn( |
119 |
| - [val_b](const CTYPE_A val_a) { |
120 |
| - CTYPE_IN a_casted = |
121 |
| - static_cast<CTYPE_IN>(val_a); |
122 |
| - CTYPE_IN b_casted = |
123 |
| - static_cast<CTYPE_IN>(val_b); |
124 |
| - CTYPE_IN value = |
125 |
| - std::bit_or<CTYPE_IN>()(a_casted, b_casted); |
126 |
| - |
127 |
| - return static_cast<CTYPE_OUT>(value); |
128 |
| - }, |
129 |
| - a.const_data_ptr<CTYPE_A>(), |
130 |
| - out.mutable_data_ptr<CTYPE_OUT>(), |
131 |
| - out.numel()); |
132 |
| - }); |
133 |
| - }); |
134 |
| - }); |
135 |
| - }); |
136 |
| - |
137 |
| - return out; |
| 30 | + // @lint-ignore CLANGTIDY facebook-hte-CArray |
| 31 | + static constexpr const char op_name[] = "bitwise_or.Scalar_out"; |
| 32 | + return internal::bitwise_scalar_out<op_name>(ctx, a, b, out); |
138 | 33 | }
|
139 | 34 |
|
140 | 35 | } // namespace native
|
|
0 commit comments