#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

+ #include <executorch/kernels/optimized/cpu/op_add_sub_impl.h>
+
namespace torch {
namespace executor {
namespace native {
- namespace {
-
- template <
-     bool can_cast,
-     typename CTYPE_A,
-     typename CTYPE_B,
-     typename CTYPE_IN,
-     typename CTYPE_OUT>
- struct AddInner;
-
- template <
-     typename CTYPE_A,
-     typename CTYPE_B,
-     typename CTYPE_IN,
-     typename CTYPE_OUT>
- struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-   static void
-   run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-     apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-         // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-         [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-           CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-           CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-           CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-           return static_cast<CTYPE_OUT>(value);
-         },
-         a,
-         b,
-         out);
-   }
- };
-
- template <typename CTYPE_IN>
- struct ReportCanCastBug {
-   static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-     ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-   }
- };
-
- template <
-     typename CTYPE_A,
-     typename CTYPE_B,
-     typename CTYPE_IN,
-     typename CTYPE_OUT>
- struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-     : public ReportCanCastBug<CTYPE_IN> {};
-
- } // namespace
-
using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
@@ -76,8 +28,6 @@ Tensor& opt_add_out(
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
-   (void)ctx;
-
  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();
@@ -95,7 +45,9 @@ Tensor& opt_add_out(
      ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
        CTYPE alpha_val;
        ET_KERNEL_CHECK(
-           ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
+           ctx,
+           torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+           InvalidArgument, );
        CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
        CTYPE b_casted = static_cast<CTYPE>(b_val);
@@ -115,100 +67,9 @@ Tensor& opt_add_out(
    return opt_add_out(ctx, b, a, alpha, out);
  }

-   auto selected_optimized_path = select_optimized_path(a, b, out);
-   if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-     // Resize for dynamic shape
-     auto error = resize_tensor(out, a.sizes());
-     ET_KERNEL_CHECK_MSG(
-         ctx,
-         error == Error::Ok,
-         InvalidArgument,
-         out,
-         "Failed to resize output tensor.");
-
-     ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
-       CTYPE alpha_val;
-       ET_KERNEL_CHECK(
-           ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-       using Vec = executorch::vec::Vectorized<CTYPE>;
-       executorch::vec::map2<CTYPE>(
-           [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
-           out.mutable_data_ptr<CTYPE>(),
-           a.const_data_ptr<CTYPE>(),
-           b.const_data_ptr<CTYPE>(),
-           out.numel());
-     });
-   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-     ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
-       CTYPE alpha_val;
-       ET_KERNEL_CHECK_MSG(
-           ctx,
-           utils::extract_scalar(alpha, &alpha_val),
-           InvalidArgument,
-           out,
-           "Failed to extract scalar alpha.");
-       using Vec = executorch::vec::Vectorized<CTYPE>;
-       Vec alpha_val_vec(alpha_val);
-       if (selected_optimized_path ==
-               ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-           selected_optimized_path ==
-               ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-           selected_optimized_path ==
-               ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-         // Reason we swap out args here is because handle_broadcast_elementwise
-         // handles this selected_optimized_path option a bit differently.
-         // This should really be resolved in handle_broadcast_elementwise.
-         // However, the current blocker is that handle_broadcast_elementwise
-         // tries to be agnostic of op. This should be fixed, likely by moving
-         // lambda creation to handle_broadcast_elementwise and it be aware of
-         // which op is being executed.
-         auto add_lambda = [&alpha_val_vec](auto x, auto y) {
-           return y + alpha_val_vec * x;
-         };
-         return torch::executor::handle_broadcast_elementwise<CTYPE>(
-             ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-       } else {
-         auto add_lambda = [&alpha_val_vec](auto x, auto y) {
-           return x + alpha_val_vec * y;
-         };
-         return torch::executor::handle_broadcast_elementwise<CTYPE>(
-             ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-       }
-     });
-   } else {
-     ScalarType common_type =
-         promoteTypes(a_type, b_type, /*half_to_float*/ true);
-     ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-     ET_KERNEL_CHECK(
-         ctx,
-         resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-         InvalidArgument,
-         out);
-
-     ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
-       ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
-         using CTYPE_IN = typename torch::executor::
-             promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-         ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-         ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
-           CTYPE_IN alpha_val;
-           ET_KERNEL_CHECK(
-               ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-           AddInner<
-               can_cast<CTYPE_IN, CTYPE_OUT>::value,
-               CTYPE_A,
-               CTYPE_B,
-               CTYPE_IN,
-               CTYPE_OUT>::run(a, b, alpha_val, out);
-         });
-       });
-     });
-   }
-
-   return out;
+   static constexpr const char op_name[] = "add.out";
+   return torch::executor::kernels::impl::opt_add_sub_out_impl<false, op_name>(
+       ctx, a, b, alpha, out);
}

Tensor& opt_add_scalar_out(
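
The rewritten opt_add_out above now forwards to a shared add/sub implementation selected by compile-time template parameters. As a rough sketch of that consolidation pattern — the simplified free-function signature below is an assumption for illustration, not the real opt_add_sub_out_impl API (the diff only confirms the <false, op_name> instantiation, which presumably selects addition):

#include <cstdio>

namespace sketch {

// One implementation shared by add.out and sub.out. The `is_sub` flag is
// resolved at compile time, so selecting the op costs nothing at runtime.
// (Hypothetical simplified signature; the real impl operates on Tensors.)
template <bool is_sub>
float add_sub_impl(float a, float b, float alpha) {
  if constexpr (is_sub) {
    return a - alpha * b; // sub.out computes a - alpha * b
  } else {
    return a + alpha * b; // add.out computes a + alpha * b
  }
}

// Thin entry points pin down the flag, mirroring how opt_add_out now
// forwards to opt_add_sub_out_impl<false, op_name>.
float add_out(float a, float b, float alpha) {
  return add_sub_impl<false>(a, b, alpha);
}

float sub_out(float a, float b, float alpha) {
  return add_sub_impl<true>(a, b, alpha);
}

} // namespace sketch

int main() {
  std::printf("%f %f\n",
              sketch::add_out(3.f, 2.f, 1.f),  // 5.0
              sketch::sub_out(3.f, 2.f, 1.f)); // 1.0
  return 0;
}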
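
For reference, the deleted AddInner/ReportCanCastBug machinery (presumably now living behind op_add_sub_impl.h) is an instance of a common dispatch pattern: partially specialize on a compile-time can_cast flag so that dtype combinations already rejected by canCast never instantiate the real kernel and instead trip a debug assertion. A minimal self-contained sketch of that pattern, with illustrative names rather than the ExecuTorch API:

#include <cassert>
#include <cstdio>

// Primary template, selected by a compile-time `can_cast` flag.
template <bool can_cast, typename CTYPE_IN, typename CTYPE_OUT>
struct AddInnerSketch;

// Castable case: do the arithmetic in CTYPE_IN, then cast to CTYPE_OUT.
template <typename CTYPE_IN, typename CTYPE_OUT>
struct AddInnerSketch<true, CTYPE_IN, CTYPE_OUT> {
  static CTYPE_OUT run(CTYPE_IN a, CTYPE_IN b, CTYPE_IN alpha) {
    return static_cast<CTYPE_OUT>(a + alpha * b);
  }
};

// Non-castable case: the caller is expected to have rejected this dtype
// combination already, so reaching this specialization is a bug.
template <typename CTYPE_IN, typename CTYPE_OUT>
struct AddInnerSketch<false, CTYPE_IN, CTYPE_OUT> {
  static CTYPE_OUT run(CTYPE_IN, CTYPE_IN, CTYPE_IN) {
    assert(false && "BUG: canCast should have been checked above");
    return CTYPE_OUT{};
  }
};

int main() {
  // The flag would normally come from a trait like can_cast<IN, OUT>::value;
  // it is written out directly here for brevity.
  std::printf("%f\n", AddInnerSketch<true, float, float>::run(1.f, 2.f, 0.5f));
  return 0;
}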