
Commit c361431

ckmadhira (cmadhira@cadence.com) and zonglinpeng authored
modified all the operators to support generic implementation in specific cases (#7936)
* modified all the operators to support a generic implementation when the input data types differ or when the input data type differs from the output data type
* boolean variables are assigned true or false instead of 1 or 0. In the Scalar operation, fixed an issue related to the data type of the second input

---------

Co-authored-by: [email protected] <[email protected]>
Co-authored-by: JP <[email protected]>
1 parent 3be1c5e commit c361431
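
The change follows one pattern across the touched operators: compute an `optimized` flag up front from the input/output dtypes (plus broadcast and dimension limits where relevant), take the fixed-type fast kernel only when the flag holds, and otherwise fall back to the generic path with its dtype promotion and argument checks. Below is a minimal standalone sketch of that gating; `TensorInfo`, `fast_add_f32`, and `generic_add` are hypothetical stand-ins for the real ExecuTorch/NNLib APIs, not the actual operator code.

#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the real ExecuTorch types; illustration only.
enum class ScalarType { Int, Float };

struct TensorInfo {
  ScalarType dtype;
  std::vector<float> data;  // simplified: values held as float regardless of dtype
};

// Hypothetical vendor kernel: float only, same dtype for inputs and output.
void fast_add_f32(const float* a, const float* b, float* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = a[i] + b[i];
  }
}

// Generic fallback, standing in for the promoted ET_SWITCH_REALB_TYPES path.
void generic_add(const TensorInfo& a, const TensorInfo& b, TensorInfo& out) {
  for (std::size_t i = 0; i < out.data.size(); ++i) {
    out.data[i] = a.data[i] + b.data[i];
  }
}

void add_out(const TensorInfo& a, const TensorInfo& b, TensorInfo& out) {
  // Same gating idea as the commit: the fast path requires matching,
  // supported dtypes; anything else routes to the generic implementation.
  bool optimized = true;
  if (!((a.dtype == ScalarType::Int || a.dtype == ScalarType::Float) &&
        a.dtype == b.dtype && a.dtype == out.dtype)) {
    optimized = false;
  }

  if (optimized && a.dtype == ScalarType::Float) {
    fast_add_f32(a.data.data(), b.data.data(), out.data.data(), out.data.size());
  } else {
    generic_add(a, b, out);
  }
}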

File tree

13 files changed: +378 -282 lines changed


backends/cadence/fusion_g3/operators/op_add.cpp

Lines changed: 60 additions & 45 deletions
@@ -35,21 +35,7 @@ Tensor& add_out(
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
-
#ifdef OP_ARG_CHECK
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       torch::executor::check_alpha_type(
-           torch::executor::native::utils::get_scalar_dtype(alpha),
-           common_type)),
-      InvalidArgument,
-      out);
-
  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx,
@@ -65,10 +51,6 @@ Tensor& add_out(
      out);
#endif

-  // Compute Dtype
-  ScalarType compute_type =
-      torch::executor::native::utils::get_compute_type(common_type);
-
  static constexpr const char op_name[] = "add.out";

  int kTensorDimensionLimit = 5;
@@ -77,12 +59,12 @@ Tensor& add_out(
  int inp2_shape[kTensorDimensionLimit];
  int out_shape[kTensorDimensionLimit];

-  bool broadcast = 0;
+  bool broadcast = false;

  int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
  max_dim = out.dim() > max_dim ? out.dim() : max_dim;

-  bool optimized = 1;
+  bool optimized = true;

  /* Added change to work with input dimensions more than 5 */
  for (int i = 0; i < max_dim; i++) {
@@ -109,15 +91,19 @@ Tensor& add_out(
  for (int i = 0; i < out.dim(); i++) {
    if (((inp1_shape[i]) != (out_shape[i])) ||
        ((inp2_shape[i]) != (out_shape[i]))) {
-      broadcast = 1;
+      broadcast = true;
    }
  }

-  if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
-    optimized = 0;
+  if (((broadcast) && (max_dim > kTensorDimensionLimit)) ||
+      (!(((a.scalar_type() == ScalarType::Int) ||
+          (a.scalar_type() == ScalarType::Float)) &&
+         (a.scalar_type() == b.scalar_type()) &&
+         (a.scalar_type() == out.scalar_type())))) {
+    optimized = false;
  }

-  if ((compute_type == ScalarType::Int) && (optimized)) {
+  if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
    const int* const inp1_data = a.const_data_ptr<int>();
    const int* const inp2_data = b.const_data_ptr<int>();
    int* const out_data = out.mutable_data_ptr<int>();
@@ -169,7 +155,7 @@ Tensor& add_out(
        alpha_val,
        out.numel());
    }
-  } else if ((compute_type == ScalarType::Float) && (optimized)) {
+  } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
    const float* const inp1_data = a.const_data_ptr<float>();
    const float* const inp2_data = b.const_data_ptr<float>();
    float* const out_data = out.mutable_data_ptr<float>();
@@ -222,6 +208,23 @@ Tensor& add_out(
        out.numel());
    }
  } else {
+    // Common Dtype
+    ScalarType common_type =
+        executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
+    // Compute Dtype
+    ScalarType compute_type =
+        torch::executor::native::utils::get_compute_type(common_type);
+
+    // Check Common Dtype
+    ET_KERNEL_CHECK(
+        ctx,
+        (canCast(common_type, out.scalar_type()) &&
+         torch::executor::check_alpha_type(
+             torch::executor::native::utils::get_scalar_dtype(alpha),
+             common_type)),
+        InvalidArgument,
+        out);
+
    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
      const CTYPE_COMPUTE val_alpha =
          torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -249,22 +252,7 @@ Tensor& add_scalar_out(
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      torch::executor::native::utils::promote_type_with_scalar(
-          a.scalar_type(), b);
-
#ifdef OP_ARG_CHECK
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (common_type == out.scalar_type() &&
-       torch::executor::check_alpha_type(
-           torch::executor::native::utils::get_scalar_dtype(alpha),
-           common_type)),
-      InvalidArgument,
-      out);
-
  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx,
@@ -279,14 +267,23 @@ Tensor& add_scalar_out(
      InvalidArgument,
      out);
#endif
-  // Compute Dtype
-  ScalarType compute_type =
-      torch::executor::native::utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "add.Scalar_out";

-  if (compute_type == ScalarType::Int) {
+  bool optimized = true;
+
+  if (!(((a.scalar_type() == ScalarType::Int) ||
+         (a.scalar_type() == ScalarType::Float)) &&
+        (a.scalar_type() == out.scalar_type()))) {
+    optimized = false;
+  }
+
+  if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) {
+    optimized = false;
+  }
+
+  if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
    const int* const inp1_data = a.const_data_ptr<int>();
    int inp2_val;
    torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -306,7 +303,7 @@ Tensor& add_scalar_out(
        alpha_val,
        out.numel());

-  } else if (compute_type == ScalarType::Float) {
+  } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
    const float* const inp1_data = a.const_data_ptr<float>();
    float inp2_val;
    torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -327,6 +324,24 @@ Tensor& add_scalar_out(
        out.numel());

  } else {
+    // Common Dtype
+    ScalarType common_type =
+        torch::executor::native::utils::promote_type_with_scalar(
+            a.scalar_type(), b);
+    // Compute Dtype
+    ScalarType compute_type =
+        torch::executor::native::utils::get_compute_type(common_type);
+
+    // Check Common Dtype
+    ET_KERNEL_CHECK(
+        ctx,
+        (common_type == out.scalar_type() &&
+         torch::executor::check_alpha_type(
+             torch::executor::native::utils::get_scalar_dtype(alpha),
+             common_type)),
+        InvalidArgument,
+        out);
+
    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
      torch::executor::native::utils::
          apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
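
The second guard added to `add_scalar_out` is the Scalar fix mentioned in the commit message: a floating-point scalar paired with an integer tensor cannot take the int fast path, since that path extracts the scalar into an `int` and a fractional value cannot be represented there, so the case now falls through to the promoted generic kernel. A rough self-contained illustration of why the guard is needed, using a hypothetical minimal `Scalar` rather than the real ExecuTorch type:

#include <cassert>
#include <variant>

// Hypothetical minimal Scalar: holds either an int or a double.
struct Scalar {
  std::variant<int, double> v;
  bool isFloatingPoint() const { return std::holds_alternative<double>(v); }
};

int main() {
  Scalar b{2.5};

  // The int fast path would read the scalar as an int and silently truncate:
  int as_int = static_cast<int>(std::get<double>(b.v));
  assert(as_int == 2);  // the 0.5 is lost -- wrong result for a + 2.5

  // Hence the added guard: a floating-point scalar combined with an int
  // tensor disables the optimized path, and the promoted (float) generic
  // kernel handles the addition instead.
  bool optimized = true;
  bool a_is_int = true;  // stand-in for a.scalar_type() == ScalarType::Int
  if (b.isFloatingPoint() && a_is_int) {
    optimized = false;
  }
  assert(!optimized);
  return 0;
}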

backends/cadence/fusion_g3/operators/op_cat.cpp

Lines changed: 16 additions & 6 deletions
@@ -46,11 +46,6 @@ Tensor& cat_out(
  int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;

#ifdef OP_ARG_CHECK
-  ET_KERNEL_CHECK(
-      ctx,
-      torch::executor::check_cat_args(tensors, dim, out),
-      InvalidArgument,
-      out);

  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
  size_t expected_out_dim = 0;
@@ -106,7 +101,16 @@ Tensor& cat_out(
    out_shapes[i] = out_size[i];
  }

-  if ((out.scalar_type() == ScalarType::Int) ||
+  bool optimized = true;
+
+  for (int i = 0; i < tensors.size(); i++) {
+    if (out.scalar_type() != tensors[i].scalar_type()) {
+      optimized = false;
+      break;
+    }
+  }
+
+  if ((optimized) && (out.scalar_type() == ScalarType::Int) ||
      (out.scalar_type() == ScalarType::Short) ||
      (out.scalar_type() == ScalarType::Char) ||
      (out.scalar_type() == ScalarType::UInt32) ||
@@ -125,6 +129,12 @@ Tensor& cat_out(
      (int)dim,
      get_element_size(out.scalar_type()));
  } else {
+    ET_KERNEL_CHECK(
+        ctx,
+        torch::executor::check_cat_args(tensors, dim, out),
+        InvalidArgument,
+        out);
+
    const size_t outer = executorch::runtime::getLeadingDims(out, dim);
    const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim);
    const size_t ninputs = tensors.size();
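
`cat_out` derives the same flag from a scan of the input list: every input must already match the output dtype for the copy-based fast path, and the `check_cat_args` validation now runs only on the generic branch. A short sketch of that dtype scan, with a hypothetical `TensorInfo` standing in for the real Tensor type:

#include <vector>

// Hypothetical minimal tensor descriptor; illustration only.
enum class ScalarType { Int, Float };
struct TensorInfo { ScalarType dtype; };

// True only when every input already matches the output dtype, mirroring the
// scan added in cat_out before the copy-based fast path is taken.
bool all_dtypes_match(const std::vector<TensorInfo>& tensors, const TensorInfo& out) {
  for (const TensorInfo& t : tensors) {
    if (t.dtype != out.dtype) {
      return false;  // mixed dtypes -> fall back to the generic concatenation
    }
  }
  return true;
}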

backends/cadence/fusion_g3/operators/op_dequantize.cpp

Lines changed: 19 additions & 13 deletions
@@ -117,16 +117,22 @@ Tensor& dequantize_impl(
        }
      }
    } else {
-      if (*zero_point_data != 0) // tesor
+      if (*zero_point_data != 0) // tensor
      {
        is_asym_dequant |= 1;
      }
    }
  }
  float* out_data = out.mutable_data_ptr<float>();

+  bool optimized = true;
+
+  if (out.scalar_type() != ScalarType::Float) {
+    optimized = false;
+  }
+
  if (is_asym_dequant) {
-    if (input.scalar_type() == ScalarType::Byte) {
+    if ((input.scalar_type() == ScalarType::Byte) && (optimized)) {
      const uint8_t* input_data = input.const_data_ptr<uint8_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -139,7 +145,7 @@ Tensor& dequantize_impl(
          axis,
          zero_point_data,
          scale_data);
-    } else if (input.scalar_type() == ScalarType::Char) {
+    } else if ((input.scalar_type() == ScalarType::Char) && (optimized)) {
      const int8_t* input_data = input.const_data_ptr<int8_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -152,7 +158,7 @@ Tensor& dequantize_impl(
          axis,
          zero_point_data,
          scale_data);
-    } else if (input.scalar_type() == ScalarType::UInt16) {
+    } else if ((input.scalar_type() == ScalarType::UInt16) && (optimized)) {
      const uint16_t* input_data = input.const_data_ptr<uint16_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -165,7 +171,7 @@ Tensor& dequantize_impl(
          axis,
          zero_point_data,
          scale_data);
-    } else if (input.scalar_type() == ScalarType::Short) {
+    } else if ((input.scalar_type() == ScalarType::Short) && (optimized)) {
      const int16_t* input_data = input.const_data_ptr<int16_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -178,7 +184,7 @@ Tensor& dequantize_impl(
          axis,
          zero_point_data,
          scale_data);
-    } else if (input.scalar_type() == (ScalarType)Bits4u) {
+    } else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
      const uint8_t* input_data = input.const_data_ptr<uint8_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -191,7 +197,7 @@ Tensor& dequantize_impl(
          axis,
          zero_point_data,
          scale_data);
-    } else if (input.scalar_type() == (ScalarType)Bits4) {
+    } else if ((input.scalar_type() == (ScalarType)Bits4) && (optimized)) {
      const int8_t* input_data = input.const_data_ptr<int8_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -338,7 +344,7 @@ Tensor& dequantize_impl(
      }
    }
  } else {
-    if (input.scalar_type() == ScalarType::Byte) {
+    if ((input.scalar_type() == ScalarType::Byte) && (optimized)) {
      const uint8_t* input_data = input.const_data_ptr<uint8_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -350,7 +356,7 @@ Tensor& dequantize_impl(
          input.dim(),
          axis,
          scale_data);
-    } else if (input.scalar_type() == ScalarType::Char) {
+    } else if ((input.scalar_type() == ScalarType::Char) && (optimized)) {
      const int8_t* input_data = input.const_data_ptr<int8_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -362,7 +368,7 @@ Tensor& dequantize_impl(
          input.dim(),
          axis,
          scale_data);
-    } else if (input.scalar_type() == ScalarType::UInt16) {
+    } else if ((input.scalar_type() == ScalarType::UInt16) && (optimized)) {
      const uint16_t* input_data = input.const_data_ptr<uint16_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -374,7 +380,7 @@ Tensor& dequantize_impl(
          input.dim(),
          axis,
          scale_data);
-    } else if (input.scalar_type() == ScalarType::Short) {
+    } else if ((input.scalar_type() == ScalarType::Short) && (optimized)) {
      const int16_t* input_data = input.const_data_ptr<int16_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -386,7 +392,7 @@ Tensor& dequantize_impl(
          input.dim(),
          axis,
          scale_data);
-    } else if (input.scalar_type() == (ScalarType)Bits4u) {
+    } else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
      const uint8_t* input_data = input.const_data_ptr<uint8_t>();
      XT_KERNEL_CHECK(
          ctx,
@@ -398,7 +404,7 @@ Tensor& dequantize_impl(
          input.dim(),
          axis,
          scale_data);
-    } else if (input.scalar_type() == (ScalarType)Bits4) {
+    } else if ((input.scalar_type() == (ScalarType)Bits4) && (optimized)) {
      const int8_t* input_data = input.const_data_ptr<int8_t>();
      XT_KERNEL_CHECK(
          ctx,
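
In `dequantize_impl` the flag reduces to a single condition: the optimized branches all write through `out.mutable_data_ptr<float>()`, so any non-float output tensor routes every input type to the generic path. A compact sketch of that gate, with a hypothetical enum standing in for the real `ScalarType`:

// Hypothetical stand-in for the real ScalarType enum; illustration only.
enum class ScalarType { Float, Int, Byte };

// The optimized dequantize branches produce float output, so the gate is a
// single dtype check on the output tensor.
bool use_optimized_dequant(ScalarType out_dtype) {
  return out_dtype == ScalarType::Float;
}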

0 commit comments
