@@ -1,9 +1,9 @@
+from typing import Optional, Tuple
+
 import torch
-from torch.library import Library, impl
-from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
-from typing import Tuple
 from torch._refs import _unsqueeze_multiple
-
+from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
+from torch.library import impl, Library
 
 # Note: decomposed means decomposed quantized tensor, using decomposed so that the
 # name is not too long
@@ -13,7 +13,7 @@
     torch.uint8: (0, 255),
     torch.int8: (-128, 127),
     torch.int16: (-(2**15), 2**15 - 1),
-    torch.int32: (-(2**31), 2**31 - 1)
+    torch.int32: (-(2**31), 2**31 - 1),
 }
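These bounds line up with torch.iinfo for each integer dtype; a quick self-check sketch (an illustration, not part of the patch) under that assumption:

    for dt in (torch.uint8, torch.int8, torch.int16, torch.int32):
        info = torch.iinfo(dt)
        # e.g. iinfo(int8) gives min=-128, max=127, matching the table above
        assert _DTYPE_TO_QVALUE_BOUNDS[dt] == (info.min, info.max)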
 
 # Helper to check the passed in quant min and max are valid for the dtype
@@ -60,13 +60,26 @@ def quantize_per_tensor(
     """
     if input.dtype == torch.bfloat16:
         input = input.to(torch.float32)
-
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
 
     inv_scale = 1.0 / scale
     return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
 
+@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta")
+def quantize_per_tensor_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype
+) -> torch.Tensor:
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    return torch.empty_like(input, dtype=dtype)
+
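For intuition, the affine mapping these ops implement can be reproduced in a few lines of eager PyTorch; the tensor values and qparams below are made up for illustration:

    x = torch.tensor([-1.0, 0.0, 0.5, 1.0])
    scale, zero_point = 0.01, 128
    # q = clamp(round(x / scale) + zero_point, quant_min, quant_max), then cast
    q = torch.clamp(torch.round(x * (1.0 / scale)) + zero_point, 0, 255).to(torch.uint8)
    # -> tensor([ 28, 128, 178, 228], dtype=torch.uint8)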
 quantized_decomposed_lib.define(
     "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
     "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
@@ -90,7 +103,14 @@ def quantize_per_tensor_tensor(
     return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
 
 @impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
-def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+def quantize_per_tensor_tensor_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype
+) -> torch.Tensor:
     if input.dtype == torch.bfloat16:
         input = input.to(torch.float32)
     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
@@ -122,7 +142,14 @@ def quantize_per_tensor_tensor2(
     return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
 
 @impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta")
-def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+def quantize_per_tensor_tensor2_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: torch.Tensor,
+    quant_max: torch.Tensor,
+    dtype: torch.dtype
+) -> torch.Tensor:
     return quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
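The Meta implementations above never touch data: they only propagate shape and dtype, which is what tracing and fake-tensor execution need. A rough sketch of that behavior, using a plain meta-device tensor rather than the registered op:

    m = torch.empty(8, device="meta", dtype=torch.float32)
    q_meta = torch.empty_like(m, dtype=torch.uint8)  # mirrors quantize_per_tensor_meta
    # q_meta.shape == torch.Size([8]), q_meta.dtype == torch.uint8, no storage allocated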
 
 # Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
@@ -131,7 +158,7 @@ def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_
 # We will revisit this later if we found there are no use cases for it
 quantized_decomposed_lib.define(
     "dequantize_per_tensor(Tensor input, float scale, int zero_point, "
-    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
 
 @impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
 def dequantize_per_tensor(
@@ -140,7 +167,9 @@ def dequantize_per_tensor(
     zero_point: int,
     quant_min: int,
     quant_max: int,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     """ Affine dequantization for the Tensor using the same quantization parameters to map
     from quantized values to floating point values
@@ -163,22 +192,40 @@ def dequantize_per_tensor(
        dtype (torch.dtype): dtype for input Tensor (not used in computation,
        reserved for pattern matching)
 
+       out_dtype (torch.dtype?): optional dtype for output Tensor
+
     Returns:
        dequantized float32 Tensor
     """
     assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
+    if out_dtype is None:
+        out_dtype = torch.float32
     if dtype in _DTYPE_TO_QVALUE_BOUNDS:
         # TODO: investigate why
         # (input - zero_point).to(torch.float32) * scale
         # failed the test
-        return (input.to(torch.float32) - zero_point) * scale
+        return (input.to(out_dtype) - zero_point) * scale
     else:
         raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
 
+@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta")
+def dequantize_per_tensor_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
+    return torch.empty_like(input, dtype=out_dtype)
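A minimal sketch of what the new out_dtype keyword changes, with illustrative values: the default (None) keeps the existing float32 behavior, while passing e.g. torch.bfloat16 lands the same affine formula in that dtype:

    q = torch.tensor([28, 128, 228], dtype=torch.uint8)
    scale, zero_point = 0.01, 128
    x_fp32 = (q.to(torch.float32) - zero_point) * scale    # out_dtype=None path
    x_bf16 = (q.to(torch.bfloat16) - zero_point) * scale   # out_dtype=torch.bfloat16 path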
 
 quantized_decomposed_lib.define(
     "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
-    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
 
 @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd")
 def dequantize_per_tensor_tensor(
@@ -187,7 +234,9 @@ def dequantize_per_tensor_tensor(
     zero_point: torch.Tensor,
     quant_min: int,
     quant_max: int,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     """ Affine dequantization for the Tensor using the same quantization parameters to map
     from quantized values to floating point values
@@ -196,22 +245,33 @@ def dequantize_per_tensor_tensor(
     """
     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
+    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype, out_dtype=out_dtype)
 
 @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
-def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+def dequantize_per_tensor_tensor_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
     assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
     if dtype in _DTYPE_TO_QVALUE_BOUNDS:
-        return torch.empty_like(input, dtype=torch.float32)
+        return torch.empty_like(input, dtype=out_dtype)
     else:
         raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
 
 # TODO: remove other variants and keep this one
 quantized_decomposed_lib.define(
     "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
-    "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor")
+    "Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
 
 @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "CompositeExplicitAutograd")
 def dequantize_per_tensor_tensor2(
@@ -220,7 +280,9 @@ def dequantize_per_tensor_tensor2(
     zero_point: torch.Tensor,
     quant_min: torch.Tensor,
     quant_max: torch.Tensor,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     """ Affine dequantization for the Tensor using the same quantization parameters to map
     from quantized values to floating point values
@@ -229,11 +291,21 @@ def dequantize_per_tensor_tensor2(
     """
     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
+    return dequantize_per_tensor(
+        input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype, out_dtype=out_dtype)
 
 @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta")
-def dequantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
-    return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
+def dequantize_per_tensor_tensor2_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: torch.Tensor,
+    quant_max: torch.Tensor,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype)
 
 quantized_decomposed_lib.define(
     "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
@@ -415,7 +487,7 @@ def quantize_per_channel_meta(
 # We will revisit this later if we found there are no use cases for it
 quantized_decomposed_lib.define(
     "dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
-    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
 
 @impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
 def dequantize_per_channel(
@@ -425,7 +497,9 @@ def dequantize_per_channel(
     axis: int,
     quant_min: int,
     quant_max: int,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     """ Affine per channel dequantization for the Tensor using the same quantization
     parameters for each channel/axis to map from quantized values to floating point values
@@ -450,20 +524,24 @@ def dequantize_per_channel(
        dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
        reserved for pattern matching)
 
+       out_dtype (torch.dtype?): optional dtype for output Tensor
+
     Returns:
        dequantized float32 Tensor
     """
     assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
+    if out_dtype is None:
+        out_dtype = torch.float32
     assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     input, permute_axis_list = _permute_to_axis_zero(input, axis)
-    res = torch.zeros_like(input, dtype=torch.float32)
+    res = torch.zeros_like(input, dtype=out_dtype)
 
     for i in range(input.size(0)):
         # TODO: investigate why
-        # (input[i] - zero_points[i]).to(torch.float32) * scales[i]
+        # (input[i] - zero_points[i]).to(out_dtype) * scales[i]
         # failed the test
-        res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i]
+        res[i] = (input[i].to(out_dtype) - zero_points[i]) * scales[i]
 
     out = res.permute(tuple(permute_axis_list))
     return out
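The per-axis loop above is equivalent to broadcasting the per-channel qparams along axis; a vectorized sketch with made-up shapes, for intuition only (the loop is what the op actually runs):

    x_q = torch.randint(0, 256, (2, 3, 4), dtype=torch.uint8)  # hypothetical input
    scales = torch.tensor([0.1, 0.2, 0.3])
    zero_points = torch.tensor([128, 128, 128])
    axis = 1
    # reshape qparams to (1, C, 1) so they broadcast along `axis`
    shape = [1] * x_q.dim()
    shape[axis] = -1
    x = (x_q.to(torch.float32) - zero_points.view(shape)) * scales.view(shape)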
@@ -476,12 +554,16 @@ def dequantize_per_channel_meta(
     axis: int,
     quant_min: int,
     quant_max: int,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
+    if out_dtype is None:
+        out_dtype = torch.float32
     assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
-    return torch.empty_like(input, dtype=torch.float32)
+    return torch.empty_like(input, dtype=out_dtype)
 
 quantized_decomposed_lib.define(
     "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "