@@ -14,59 +14,11 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
+#include <executorch/kernels/optimized/cpu/op_add_sub_impl.h>
+
 namespace torch {
 namespace executor {
 namespace native {
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void
-  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-template <typename CTYPE_IN>
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug<CTYPE_IN> {};
-
-} // namespace
-
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
@@ -76,8 +28,6 @@ Tensor& opt_add_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
@@ -95,7 +45,9 @@ Tensor& opt_add_out(
         ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
           CTYPE alpha_val;
           ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
+              ctx,
+              torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+              InvalidArgument, );
           CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
           CTYPE b_casted = static_cast<CTYPE>(b_val);
 
@@ -115,101 +67,9 @@ Tensor& opt_add_out(
     return opt_add_out(ctx, b, a, alpha, out);
   }
 
-  auto selected_optimized_path = select_optimized_path(a, b, out);
-  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    auto error = resize_tensor(out, a.sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::map2<CTYPE>(
-          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          a.const_data_ptr<CTYPE>(),
-          b.const_data_ptr<CTYPE>(),
-          out.numel());
-    });
-  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    ScalarType out_type = out.scalar_type();
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK_MSG(
-          ctx,
-          utils::extract_scalar(alpha, &alpha_val),
-          InvalidArgument,
-          out,
-          "Failed to extract scalar alpha.");
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      Vec alpha_val_vec(alpha_val);
-      if (selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-        // Reason we swap out args here is because handle_broadcast_elementwise
-        // handles this selected_optimized_path option a bit differently.
-        // This should really be resolved in handle_broadcast_elementwise.
-        // However, the current blocker is that handle_broadcast_elementwise
-        // tries to be agnostic of op. This should be fixed, likely by moving
-        // lambda creation to handle_broadcast_elementwise and it be aware of
-        // which op is being executed.
-        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
-          return y + alpha_val_vec * x;
-        };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      } else {
-        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
-          return x + alpha_val_vec * y;
-        };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      }
-    });
-  } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          AddInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
-  }
-
-  return out;
+  static constexpr const char op_name[] = "add.out";
+  return torch::executor::kernels::impl::opt_add_sub_out_impl<false, op_name>(
+      ctx, a, b, alpha, out);
 }
 
 Tensor& opt_add_scalar_out(
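The diff consolidates the add kernel into the shared add/sub implementation included above (op_add_sub_impl.h). The `false` template argument at the call site presumably selects the "add" variant of that template, and the `op_name` array supplies the name stamped into error messages. Below is a minimal hypothetical sketch of the shape such a shared template could take, assuming the first parameter means "is subtraction" and showing only a plain scalar loop; the real implementation must also cover the vectorized, broadcast, and type-promotion paths deleted above.

// Hypothetical sketch only -- NOT the actual contents of
// executorch/kernels/optimized/cpu/op_add_sub_impl.h.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch::executor::kernels::impl {

// One template serves both add.out and sub.out: the only difference
// between the two ops is the sign applied to the alpha * b term.
template <bool is_sub, const char* op_name>
executorch::aten::Tensor& opt_add_sub_out_impl(
    KernelRuntimeContext& ctx,
    const executorch::aten::Tensor& a,
    const executorch::aten::Tensor& b,
    const executorch::aten::Scalar& alpha,
    executorch::aten::Tensor& out) {
  ET_SWITCH_REALB_TYPES(a.scalar_type(), ctx, op_name, CTYPE, [&]() {
    CTYPE alpha_val;
    ET_KERNEL_CHECK(
        ctx,
        torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
        InvalidArgument, );
    const CTYPE* a_data = a.const_data_ptr<CTYPE>();
    const CTYPE* b_data = b.const_data_ptr<CTYPE>();
    CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
    for (size_t i = 0; i < static_cast<size_t>(out.numel()); ++i) {
      // Branch is resolved at compile time, so each instantiation
      // contains only one of the two arithmetic forms.
      if constexpr (is_sub) {
        out_data[i] = a_data[i] - alpha_val * b_data[i];
      } else {
        out_data[i] = a_data[i] + alpha_val * b_data[i];
      }
    }
  });
  return out;
}

} // namespace torch::executor::kernels::impl

One detail worth noting in the added lines: the call site declares `static constexpr const char op_name[] = "add.out";` rather than passing the string literal directly, because a string literal cannot be a non-type template argument. An array with static storage duration can bind to the `const char*` template parameter, which lets the shared template emit the correct op name in diagnostics at compile time.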