@@ -35,21 +35,7 @@ Tensor& add_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
-
 #ifdef OP_ARG_CHECK
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       torch::executor::check_alpha_type(
-           torch::executor::native::utils::get_scalar_dtype(alpha),
-           common_type)),
-      InvalidArgument,
-      out);
-
   // Check Dim Order
   ET_KERNEL_CHECK(
       ctx,
@@ -65,10 +51,6 @@ Tensor& add_out(
       out);
 #endif
 
-  // Compute Dtype
-  ScalarType compute_type =
-      torch::executor::native::utils::get_compute_type(common_type);
-
   static constexpr const char op_name[] = "add.out";
 
   int kTensorDimensionLimit = 5;
@@ -77,12 +59,12 @@ Tensor& add_out(
   int inp2_shape[kTensorDimensionLimit];
   int out_shape[kTensorDimensionLimit];
 
-  bool broadcast = 0;
+  bool broadcast = false;
 
   int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
   max_dim = out.dim() > max_dim ? out.dim() : max_dim;
 
-  bool optimized = 1;
+  bool optimized = true;
 
   /* Added change to work with input dimensions more than 5 */
   for (int i = 0; i < max_dim; i++) {
@@ -109,15 +91,19 @@ Tensor& add_out(
   for (int i = 0; i < out.dim(); i++) {
     if (((inp1_shape[i]) != (out_shape[i])) ||
         ((inp2_shape[i]) != (out_shape[i]))) {
-      broadcast = 1;
+      broadcast = true;
     }
   }
 
-  if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
-    optimized = 0;
+  if (((broadcast) && (max_dim > kTensorDimensionLimit)) ||
+      (!(((a.scalar_type() == ScalarType::Int) ||
+          (a.scalar_type() == ScalarType::Float)) &&
+         (a.scalar_type() == b.scalar_type()) &&
+         (a.scalar_type() == out.scalar_type())))) {
+    optimized = false;
   }
 
-  if ((compute_type == ScalarType::Int) && (optimized)) {
+  if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     const int* const inp2_data = b.const_data_ptr<int>();
     int* const out_data = out.mutable_data_ptr<int>();
@@ -169,7 +155,7 @@ Tensor& add_out(
           alpha_val,
           out.numel());
     }
-  } else if ((compute_type == ScalarType::Float) && (optimized)) {
+  } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
     const float* const inp2_data = b.const_data_ptr<float>();
     float* const out_data = out.mutable_data_ptr<float>();
@@ -222,6 +208,23 @@ Tensor& add_out(
           out.numel());
     }
   } else {
+    // Common Dtype
+    ScalarType common_type =
+        executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
+    // Compute Dtype
+    ScalarType compute_type =
+        torch::executor::native::utils::get_compute_type(common_type);
+
+    // Check Common Dtype
+    ET_KERNEL_CHECK(
+        ctx,
+        (canCast(common_type, out.scalar_type()) &&
+         torch::executor::check_alpha_type(
+             torch::executor::native::utils::get_scalar_dtype(alpha),
+             common_type)),
+        InvalidArgument,
+        out);
+
     ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
       const CTYPE_COMPUTE val_alpha =
           torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
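Net effect for add_out: dtype promotion and the common-dtype ET_KERNEL_CHECK move off the fast paths into the generic fallback, and the Int/Float kernels are now gated by an optimized flag computed from the tensors' own dtypes. Below is a minimal standalone sketch of that gate; the names TensorInfo and can_use_fast_path are illustrative assumptions, not part of the patch or of ExecuTorch.

#include <algorithm>

// Hypothetical stand-ins for the tensor metadata the kernel consults.
enum class ScalarType { Int, Float, Other };

struct TensorInfo {
  ScalarType dtype;
  int dim;
};

// Mirrors the patched predicate: take the optimized kernel only when any
// broadcast stays within kTensorDimensionLimit dimensions and a, b, and
// out all share the same Int or Float dtype.
bool can_use_fast_path(
    const TensorInfo& a,
    const TensorInfo& b,
    const TensorInfo& out,
    bool broadcast,
    int kTensorDimensionLimit) {
  const int max_dim = std::max({a.dim, b.dim, out.dim});
  if (broadcast && max_dim > kTensorDimensionLimit) {
    return false;
  }
  return (a.dtype == ScalarType::Int || a.dtype == ScalarType::Float) &&
      a.dtype == b.dtype && a.dtype == out.dtype;
}

The design choice is to pay for promoteTypes, get_compute_type, and the canCast/check_alpha_type validation only when the ET_SWITCH_REALB_TYPES fallback actually runs, keeping the common same-dtype Int/Float case on a branch-light path.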
@@ -249,22 +252,7 @@ Tensor& add_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      torch::executor::native::utils::promote_type_with_scalar(
-          a.scalar_type(), b);
-
 #ifdef OP_ARG_CHECK
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (common_type == out.scalar_type() &&
-       torch::executor::check_alpha_type(
-           torch::executor::native::utils::get_scalar_dtype(alpha),
-           common_type)),
-      InvalidArgument,
-      out);
-
   // Check Dim Order
   ET_KERNEL_CHECK(
       ctx,
@@ -279,14 +267,23 @@ Tensor& add_scalar_out(
       InvalidArgument,
       out);
 #endif
-  // Compute Dtype
-  ScalarType compute_type =
-      torch::executor::native::utils::get_compute_type(common_type);
 
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "add.Scalar_out";
 
-  if (compute_type == ScalarType::Int) {
+  bool optimized = true;
+
+  if (!(((a.scalar_type() == ScalarType::Int) ||
+         (a.scalar_type() == ScalarType::Float)) &&
+        (a.scalar_type() == out.scalar_type()))) {
+    optimized = false;
+  }
+
+  if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) {
+    optimized = false;
+  }
+
+  if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     int inp2_val;
     torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -306,7 +303,7 @@ Tensor& add_scalar_out(
         alpha_val,
         out.numel());
 
-  } else if (compute_type == ScalarType::Float) {
+  } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
     float inp2_val;
     torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -327,6 +324,24 @@ Tensor& add_scalar_out(
         out.numel());
 
   } else {
+    // Common Dtype
+    ScalarType common_type =
+        torch::executor::native::utils::promote_type_with_scalar(
+            a.scalar_type(), b);
+    // Compute Dtype
+    ScalarType compute_type =
+        torch::executor::native::utils::get_compute_type(common_type);
+
+    // Check Common Dtype
+    ET_KERNEL_CHECK(
+        ctx,
+        (common_type == out.scalar_type() &&
+         torch::executor::check_alpha_type(
+             torch::executor::native::utils::get_scalar_dtype(alpha),
+             common_type)),
+        InvalidArgument,
+        out);
+
     ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
       torch::executor::native::utils::
          apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
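add_scalar_out gets the same restructuring, with one extra wrinkle: an Int tensor plus a floating-point Scalar b promotes past Int, so that combination must also fall through to the generic path rather than take the integer kernel. A hedged sketch of the scalar-variant gate follows; scalar_fast_path_ok is an illustrative name, not ExecuTorch API.

// Hypothetical helper mirroring the three checks the patch adds before
// the optimized branches of add_scalar_out.
enum class ScalarType { Int, Float, Other };

bool scalar_fast_path_ok(
    ScalarType a_dtype,
    ScalarType out_dtype,
    bool b_is_floating_point) {
  // The tensor input and the output must share an Int or Float dtype.
  if (!((a_dtype == ScalarType::Int || a_dtype == ScalarType::Float) &&
        a_dtype == out_dtype)) {
    return false;
  }
  // Int tensor + floating-point scalar promotes to a float type, so the
  // integer fast path would truncate; defer to the generic fallback.
  if (b_is_floating_point && a_dtype == ScalarType::Int) {
    return false;
  }
  return true;
}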