@@ -34,13 +34,9 @@ Tensor& sub_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
 #ifdef OP_ARG_CHECK
   ScalarType alpha_type =
       torch::executor::native::utils::get_scalar_dtype(alpha);
-
   // Check alpha type
   ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
 
@@ -67,10 +63,6 @@ Tensor& sub_out(
       out);
 #endif
 
-  // Compute Dtype
-  ScalarType compute_type =
-      torch::executor::native::utils::get_compute_type(common_type);
-
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "sub.out";
 
@@ -115,11 +107,15 @@ Tensor& sub_out(
     }
   }
 
-  if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
+  if (((broadcast == 1) && (max_dim > kTensorDimensionLimit)) ||
+      (!(((a.scalar_type() == ScalarType::Int) ||
+          (a.scalar_type() == ScalarType::Float)) &&
+         (a.scalar_type() == b.scalar_type()) &&
+         (a.scalar_type() == out.scalar_type())))) {
     optimized = 0;
   }
 
-  if ((compute_type == ScalarType::Int) && (optimized)) {
+  if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     const int* const inp2_data = b.const_data_ptr<int>();
     int* const out_data = out.mutable_data_ptr<int>();
@@ -161,7 +157,7 @@ Tensor& sub_out(
           alpha_val,
           out.numel());
     }
-  } else if ((compute_type == ScalarType::Float) && (optimized)) {
+  } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
     const float* const inp2_data = b.const_data_ptr<float>();
     float* const out_data = out.mutable_data_ptr<float>();
@@ -204,6 +200,13 @@ Tensor& sub_out(
         out.numel());
     }
   } else {
+    // Common Dtype
+    ScalarType common_type =
+        executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
+    // Compute Dtype
+    ScalarType compute_type =
+        torch::executor::native::utils::get_compute_type(common_type);
+
     ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
       const CTYPE_COMPUTE val_alpha =
           torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
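Note on the sub_out hunks above: the fast int/float paths no longer key off the promoted compute_type. They fire only when a, b, and out already share a supported dtype, and type promotion is now computed lazily in the portable fallback branch. A minimal sketch of the new gate, assuming the Tensor/ScalarType types from the runtime headers this file already includes (the helper name can_use_fast_path is illustrative, not part of the patch):

// Sketch only: mirrors the fast-path gate introduced in sub_out.
// can_use_fast_path is a hypothetical name for illustration.
static bool can_use_fast_path(
    const Tensor& a,
    const Tensor& b,
    const Tensor& out) {
  const ScalarType t = a.scalar_type();
  // Only Int and Float have optimized paths here, and no promotion may be
  // needed: all three tensors must already agree on that dtype.
  return ((t == ScalarType::Int) || (t == ScalarType::Float)) &&
      (t == b.scalar_type()) && (t == out.scalar_type());
}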
@@ -232,14 +235,9 @@ Tensor& sub_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      torch::executor::native::utils::promote_type_with_scalar(
-          a.scalar_type(), b);
 #ifdef OP_ARG_CHECK
   ScalarType alpha_type =
       torch::executor::native::utils::get_scalar_dtype(alpha);
-
   // Check alpha type
   ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
 
@@ -265,14 +263,20 @@ Tensor& sub_scalar_out(
       out);
 #endif
 
-  // Compute Dtype
-  ScalarType compute_type =
-      torch::executor::native::utils::get_compute_type(common_type);
-
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "sub.Scalar_out";
 
-  if (compute_type == ScalarType::Int) {
+  bool optimized = 1;
+  ScalarType b_type = torch::executor::native::utils::get_scalar_dtype(b);
+
+  if (!(((a.scalar_type() == ScalarType::Int) ||
+         (a.scalar_type() == ScalarType::Float)) &&
+        (a.scalar_type() == b_type) &&
+        (a.scalar_type() == out.scalar_type()))) {
+    optimized = 0;
+  }
+
+  if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     int inp2_val;
     torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -291,7 +295,7 @@ Tensor& sub_scalar_out(
         inp2_val,
         alpha_val,
         out.numel());
-  } else if (compute_type == ScalarType::Float) {
+  } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
     float inp2_val;
     torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -311,6 +315,13 @@ Tensor& sub_scalar_out(
         alpha_val,
         out.numel());
   } else {
+    // Common Dtype
+    ScalarType common_type =
+        torch::executor::native::utils::promote_type_with_scalar(
+            a.scalar_type(), b);
+    // Compute Dtype
+    ScalarType compute_type =
+        torch::executor::native::utils::get_compute_type(common_type);
     ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
       const CTYPE_COMPUTE val_b =
           torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);
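The Scalar variant gets the symmetric treatment: a new optimized flag is derived from a's dtype, the scalar b's dtype (via get_scalar_dtype), and out's dtype, and promote_type_with_scalar/get_compute_type now run only in the fallback else branch. A sketch of the analogous gate, again with a hypothetical helper name (fast_path_ok is not part of the patch):

// Sketch only: the sub_scalar_out gate, using the same utils helper the
// diff itself calls to classify the Scalar operand.
static bool fast_path_ok(const Tensor& a, const Scalar& b, const Tensor& out) {
  const ScalarType t = a.scalar_type();
  const ScalarType b_type =
      torch::executor::native::utils::get_scalar_dtype(b);
  return ((t == ScalarType::Int) || (t == ScalarType::Float)) &&
      (t == b_type) && (t == out.scalar_type());
}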