@@ -92,7 +92,8 @@ void check_dequantize_per_tensor_args(
} // namespace

/* Local function which calls the kernels based on the input datatype */
-Tensor& dequantize_impl(KernelRuntimeContext& ctx,
+Tensor& dequantize_impl(
+    KernelRuntimeContext& ctx,
    Tensor& out,
    const Tensor& input,
    float* scale_data,
@@ -132,82 +133,82 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
  if (is_asym_dequant) {
    if (input.scalar_type() == ScalarType::Byte) {
      const uint8_t* input_data = input.const_data_ptr<uint8_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_asym8u_f32,
-          out_data,
-          input_data,
-          inp_shape,
-          input.dim(),
-          axis,
-          zero_point_data,
-          scale_data);
+          xa_nn_elm_dequantize_asym8u_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
    } else if (input.scalar_type() == ScalarType::Char) {
      const int8_t* input_data = input.const_data_ptr<int8_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_asym8_f32,
-          out_data,
-          input_data,
-          inp_shape,
-          input.dim(),
-          axis,
-          zero_point_data,
-          scale_data);
+          xa_nn_elm_dequantize_asym8_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
    } else if (input.scalar_type() == (ScalarType)Ushort) {
      const uint16_t* input_data = input.const_data_ptr<uint16_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_asym16u_f32,
-          out_data,
-          input_data,
-          inp_shape,
-          input.dim(),
-          axis,
-          zero_point_data,
-          scale_data);
+          xa_nn_elm_dequantize_asym16u_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
    } else if (input.scalar_type() == ScalarType::Short) {
      const int16_t* input_data = input.const_data_ptr<int16_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_asym16_f32,
-          out_data,
-          input_data,
-          inp_shape,
-          input.dim(),
-          axis,
-          zero_point_data,
-          scale_data);
+          xa_nn_elm_dequantize_asym16_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
    } else if (input.scalar_type() == (ScalarType)Bits4u) {
      const uint8_t* input_data = input.const_data_ptr<uint8_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_asym4u_f32,
-          out_data,
-          input_data,
-          inp_shape,
-          input.dim(),
-          axis,
-          zero_point_data,
-          scale_data);
+          xa_nn_elm_dequantize_asym4u_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
    } else if (input.scalar_type() == (ScalarType)Bits4) {
      const int8_t* input_data = input.const_data_ptr<int8_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_asym4_f32,
-          out_data,
-          input_data,
-          inp_shape,
-          input.dim(),
-          axis,
-          zero_point_data,
-          scale_data);
+          xa_nn_elm_dequantize_asym4_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
    } else {
      if (axis == NULL) {
        // calculate the dequantized output, cast scale to float to match fbgemm
@@ -343,10 +344,10 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
  } else {
    if (input.scalar_type() == ScalarType::Byte) {
      const uint8_t* input_data = input.const_data_ptr<uint8_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_sym8u_f32,
+          xa_nn_elm_dequantize_sym8u_f32,
          out_data,
          input_data,
          inp_shape,
@@ -358,19 +359,19 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_sym8_f32,
-          out_data,
-          input_data,
-          inp_shape,
-          input.dim(),
-          axis,
-          scale_data);
+          xa_nn_elm_dequantize_sym8_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          scale_data);
    } else if (input.scalar_type() == (ScalarType)Ushort) {
      const uint16_t* input_data = input.const_data_ptr<uint16_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_sym16u_f32,
+          xa_nn_elm_dequantize_sym16u_f32,
          out_data,
          input_data,
          inp_shape,
@@ -379,10 +380,10 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
          scale_data);
    } else if (input.scalar_type() == ScalarType::Short) {
      const int16_t* input_data = input.const_data_ptr<int16_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_sym16_f32,
+          xa_nn_elm_dequantize_sym16_f32,
          out_data,
          input_data,
          inp_shape,
@@ -391,10 +392,10 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
          scale_data);
    } else if (input.scalar_type() == (ScalarType)Bits4u) {
      const uint8_t* input_data = input.const_data_ptr<uint8_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_sym4u_f32,
+          xa_nn_elm_dequantize_sym4u_f32,
          out_data,
          input_data,
          inp_shape,
@@ -403,10 +404,10 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
          scale_data);
    } else if (input.scalar_type() == (ScalarType)Bits4) {
      const int8_t* input_data = input.const_data_ptr<int8_t>();
-      XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
          out,
-          xa_nn_elm_dequantize_sym4_f32,
+          xa_nn_elm_dequantize_sym4_f32,
          out_data,
          input_data,
          inp_shape,
@@ -558,7 +559,8 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
 * info.
 */
-Tensor& dequantize_per_tensor_out(KernelRuntimeContext& context,
+Tensor& dequantize_per_tensor_out(
+    KernelRuntimeContext& context,
    const Tensor& input,
    double scale,
    int64_t zero_point,
@@ -572,20 +574,22 @@ Tensor& dequantize_per_tensor_out(KernelRuntimeContext& context,
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in dequantize_per_tensor_out");
-
+
  check_dequantize_per_tensor_args(
      input, quant_min, quant_max, dtype, out_dtype, out);
#endif

  float scale_data = (float)scale;
  int zero_point_data = (int)zero_point;

-  dequantize_impl(context, out, input, &scale_data, &zero_point_data, NULL, out_dtype);
+  dequantize_impl(
+      context, out, input, &scale_data, &zero_point_data, NULL, out_dtype);

  return out;
}

-Tensor& dequantize_per_tensor_tensor_args_out(KernelRuntimeContext& context,
+Tensor& dequantize_per_tensor_tensor_args_out(
+    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
@@ -613,7 +617,8 @@ Tensor& dequantize_per_tensor_tensor_args_out(KernelRuntimeContext& context,
      ssize_t(zero_point.numel()));
#endif

-  dequantize_per_tensor_out(context,
+  dequantize_per_tensor_out(
+      context,
      input,
      scale.const_data_ptr<double>()[0],
      zero_point.const_data_ptr<int64_t>()[0],
@@ -626,7 +631,8 @@ Tensor& dequantize_per_tensor_tensor_args_out(KernelRuntimeContext& context,
  return out;
}

-Tensor& dequantize_per_channel_out(KernelRuntimeContext& context,
+Tensor& dequantize_per_channel_out(
+    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const exec_aten::optional<Tensor>& opt_zero_points,
@@ -636,14 +642,13 @@ Tensor& dequantize_per_channel_out(KernelRuntimeContext& context,
    ScalarType dtype,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
-
  if (axis < 0) {
    axis += executorch::runtime::nonzero_dim(input);
  }
-  /* if the arguments are passed properly to the operator disable the Macro - "OP_ARG_CHECK"
-   * if not the case, enable the Macro - "OP_ARG_CHECK", to have the checks only in
-   * operator level(As there are no checks in kernel).
-   */
+  /* if the arguments are passed properly to the operator disable the Macro -
+   * "OP_ARG_CHECK" if not the case, enable the Macro - "OP_ARG_CHECK", to have
+   * the checks only in operator level(As there are no checks in kernel).
+   */
#ifdef OP_ARG_CHECK
  torch::executor::Error err = resize_tensor(out, input.sizes());

@@ -705,12 +710,14 @@ Tensor& dequantize_per_channel_out(KernelRuntimeContext& context,
  for (int i = 0; i < scale.numel(); i++) {
    scale_data[i] = (float)scale_dt[i];
  }
-  dequantize_impl(context, out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
+  dequantize_impl(
+      context, out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);

  return out;
}

-Tensor& dequantize_per_token_out(KernelRuntimeContext& context,
+Tensor& dequantize_per_token_out(
+    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_points,
@@ -757,7 +764,8 @@ Tensor& dequantize_per_token_out(KernelRuntimeContext& context,
      "Failed to resize out Tensor in dequantize_per_channel_out");
#endif

-  return dequantize_per_channel_out(context,
+  return dequantize_per_channel_out(
+      context,
      reshaped_input,
      scale,
      zero_points,
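For reference, every dtype branch in dequantize_impl above computes the same affine dequantization, out[i] = (in[i] - zero_point) * scale, with the per-dtype xa_nn_elm_dequantize_*_f32 NNLib kernels providing the accelerated implementations (the symmetric branches simply omit the zero point). A minimal, self-contained C++ sketch of that math follows; the function name dequantize_per_tensor_ref and the std::vector signature are illustrative assumptions, not part of the ExecuTorch or NNLib API.

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference-only sketch: affine (asymmetric) per-tensor dequantization,
// out[i] = (q[i] - zero_point) * scale. Hypothetical helper, not the
// kernel API exercised by the diff above.
std::vector<float> dequantize_per_tensor_ref(
    const std::vector<uint8_t>& q,
    float scale,
    int32_t zero_point) {
  std::vector<float> out(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    out[i] = (static_cast<int32_t>(q[i]) - zero_point) * scale;
  }
  return out;
}

For example, with scale = 0.5 and zero_point = 128, a quantized value of 130 maps to (130 - 128) * 0.5 = 1.0f.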