Skip to content

Commit 57ef834

Browse files
authored
Generic impl for different input datatypes
Differential Revision: D68579949 Pull Request resolved: #7835
1 parent 6fa4ae5 commit 57ef834

File tree

1 file changed

+33
-22
lines changed

1 file changed

+33
-22
lines changed

backends/cadence/fusion_g3/operators/op_sub.cpp

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,9 @@ Tensor& sub_out(
3434
const Tensor& b,
3535
const Scalar& alpha,
3636
Tensor& out) {
37-
// Common Dtype
38-
ScalarType common_type =
39-
executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
4037
#ifdef OP_ARG_CHECK
4138
ScalarType alpha_type =
4239
torch::executor::native::utils::get_scalar_dtype(alpha);
43-
4440
// Check alpha type
4541
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
4642

@@ -67,10 +63,6 @@ Tensor& sub_out(
6763
out);
6864
#endif
6965

70-
// Compute Dtype
71-
ScalarType compute_type =
72-
torch::executor::native::utils::get_compute_type(common_type);
73-
7466
// @lint-ignore CLANGTIDY facebook-hte-CArray
7567
static constexpr const char op_name[] = "sub.out";
7668

@@ -115,11 +107,15 @@ Tensor& sub_out(
115107
}
116108
}
117109

118-
if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
110+
if (((broadcast == 1) && (max_dim > kTensorDimensionLimit)) ||
111+
(!(((a.scalar_type() == ScalarType::Int) ||
112+
(a.scalar_type() == ScalarType::Float)) &&
113+
(a.scalar_type() == b.scalar_type()) &&
114+
(a.scalar_type() == out.scalar_type())))) {
119115
optimized = 0;
120116
}
121117

122-
if ((compute_type == ScalarType::Int) && (optimized)) {
118+
if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
123119
const int* const inp1_data = a.const_data_ptr<int>();
124120
const int* const inp2_data = b.const_data_ptr<int>();
125121
int* const out_data = out.mutable_data_ptr<int>();
@@ -161,7 +157,7 @@ Tensor& sub_out(
161157
alpha_val,
162158
out.numel());
163159
}
164-
} else if ((compute_type == ScalarType::Float) && (optimized)) {
160+
} else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
165161
const float* const inp1_data = a.const_data_ptr<float>();
166162
const float* const inp2_data = b.const_data_ptr<float>();
167163
float* const out_data = out.mutable_data_ptr<float>();
@@ -204,6 +200,13 @@ Tensor& sub_out(
204200
out.numel());
205201
}
206202
} else {
203+
// Common Dtype
204+
ScalarType common_type =
205+
executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
206+
// Compute Dtype
207+
ScalarType compute_type =
208+
torch::executor::native::utils::get_compute_type(common_type);
209+
207210
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
208211
const CTYPE_COMPUTE val_alpha =
209212
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -232,14 +235,9 @@ Tensor& sub_scalar_out(
232235
const Scalar& b,
233236
const Scalar& alpha,
234237
Tensor& out) {
235-
// Common Dtype
236-
ScalarType common_type =
237-
torch::executor::native::utils::promote_type_with_scalar(
238-
a.scalar_type(), b);
239238
#ifdef OP_ARG_CHECK
240239
ScalarType alpha_type =
241240
torch::executor::native::utils::get_scalar_dtype(alpha);
242-
243241
// Check alpha type
244242
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
245243

@@ -265,14 +263,20 @@ Tensor& sub_scalar_out(
265263
out);
266264
#endif
267265

268-
// Compute Dtype
269-
ScalarType compute_type =
270-
torch::executor::native::utils::get_compute_type(common_type);
271-
272266
// @lint-ignore CLANGTIDY facebook-hte-CArray
273267
static constexpr const char op_name[] = "sub.Scalar_out";
274268

275-
if (compute_type == ScalarType::Int) {
269+
bool optimized = 1;
270+
ScalarType b_type = torch::executor::native::utils::get_scalar_dtype(b);
271+
272+
if (!(((a.scalar_type() == ScalarType::Int) ||
273+
(a.scalar_type() == ScalarType::Float)) &&
274+
(a.scalar_type() == b_type) &&
275+
(a.scalar_type() == out.scalar_type()))) {
276+
optimized = 0;
277+
}
278+
279+
if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
276280
const int* const inp1_data = a.const_data_ptr<int>();
277281
int inp2_val;
278282
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -291,7 +295,7 @@ Tensor& sub_scalar_out(
291295
inp2_val,
292296
alpha_val,
293297
out.numel());
294-
} else if (compute_type == ScalarType::Float) {
298+
} else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
295299
const float* const inp1_data = a.const_data_ptr<float>();
296300
float inp2_val;
297301
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -311,6 +315,13 @@ Tensor& sub_scalar_out(
311315
alpha_val,
312316
out.numel());
313317
} else {
318+
// Common Dtype
319+
ScalarType common_type =
320+
torch::executor::native::utils::promote_type_with_scalar(
321+
a.scalar_type(), b);
322+
// Compute Dtype
323+
ScalarType compute_type =
324+
torch::executor::native::utils::get_compute_type(common_type);
314325
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
315326
const CTYPE_COMPUTE val_b =
316327
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);

0 commit comments

Comments (0)