
Commit e170312

kausikmaiti authored and facebook-github-bot committed
Modify signature of dequantize ops for decomposed quantized Tensor (pytorch#121450)
Summary:
X-link: pytorch/executorch#2308

Note: the initial purpose of this PR is to draw suggestions and feedback regarding better alternatives, if any.

At present, the dequantize ops for the decomposed quantized Tensor representation, e.g. dequantize_per_tensor(), assume the output dtype is torch.float and hence do not carry an output dtype in their operator argument lists. This signature becomes unusable when that assumption breaks: if the output dtype is different from torch.float, there is no way to specify it during dequantization.

This change generalizes the signature of the dequantize ops, such as dequantize_per_tensor(), for wider use cases where the output dtype can differ from torch.float and needs to be passed during dequantization. The proposal is to add a keyword argument named 'out_dtype' to solve the problem. We would still welcome suggestions and feedback on any better alternative.

cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen leslie-fang-intel

Reviewed By: digantdesai

Differential Revision: D53590486

Pulled By: manuelcandales
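For context, here is a minimal sketch (not part of this commit; the tensor values, scale, and zero_point are made up for illustration) of how a caller might use the generalized signature once the decomposed ops are registered:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- importing registers the quantized_decomposed ops

# Hypothetical int8-quantized tensor with made-up qparams.
q = torch.tensor([-128, 0, 127], dtype=torch.int8)
scale, zero_point = 0.05, 10

# Existing behavior, unchanged: output defaults to torch.float32.
fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
    q, scale, zero_point, -128, 127, torch.int8)

# New in this change: request a different output dtype explicitly.
bf16 = torch.ops.quantized_decomposed.dequantize_per_tensor(
    q, scale, zero_point, -128, 127, torch.int8, out_dtype=torch.bfloat16)
assert bf16.dtype == torch.bfloat16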
1 parent 5b5d423 commit e170312


torch/ao/quantization/fx/_decomposed.py

Lines changed: 110 additions & 28 deletions
@@ -1,9 +1,9 @@
+from typing import Optional, Tuple
+
 import torch
-from torch.library import Library, impl
-from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
-from typing import Tuple
 from torch._refs import _unsqueeze_multiple
-
+from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
+from torch.library import impl, Library

 # Note: decomposed means decomposed quantized tensor, using decomposed so that the
 # name is not too long
@@ -13,7 +13,7 @@
     torch.uint8: (0, 255),
     torch.int8: (-128, 127),
     torch.int16: (-(2**15), 2**15 - 1),
-    torch.int32: (-(2**31), 2**31 - 1)
+    torch.int32: (-(2**31), 2**31 - 1),
 }

 # Helper to check the passed in quant min and max are valid for the dtype
@@ -60,13 +60,26 @@ def quantize_per_tensor(
     """
     if input.dtype == torch.bfloat16:
         input = input.to(torch.float32)
-
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)

     inv_scale = 1.0 / scale
     return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)

+@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta")
+def quantize_per_tensor_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype
+) -> torch.Tensor:
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    return torch.empty_like(input, dtype=dtype)
+
 quantized_decomposed_lib.define(
     "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
     "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
@@ -90,7 +103,14 @@ def quantize_per_tensor_tensor(
     return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)

 @impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
-def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+def quantize_per_tensor_tensor_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype
+) -> torch.Tensor:
     if input.dtype == torch.bfloat16:
         input = input.to(torch.float32)
     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
@@ -122,7 +142,14 @@ def quantize_per_tensor_tensor2(
     return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)

 @impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta")
-def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+def quantize_per_tensor_tensor2_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: torch.Tensor,
+    quant_max: torch.Tensor,
+    dtype: torch.dtype
+) -> torch.Tensor:
     return quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)

 # Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
@@ -131,7 +158,7 @@ def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_
 # We will revisit this later if we found there are no use cases for it
 quantized_decomposed_lib.define(
     "dequantize_per_tensor(Tensor input, float scale, int zero_point, "
-    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")

 @impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
 def dequantize_per_tensor(
@@ -140,7 +167,9 @@ def dequantize_per_tensor(
     zero_point: int,
     quant_min: int,
     quant_max: int,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     """ Affine dequantization for the Tensor using the same quantization parameters to map
     from quantized values to floating point values
@@ -163,22 +192,40 @@
        dtype (torch.dtype): dtype for input Tensor (not used in computation,
        reserved for pattern matching)

+       out_dtype (torch.dtype?): optional dtype for output Tensor
+
     Returns:
        dequantized float32 Tensor
     """
     assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
+    if out_dtype is None:
+        out_dtype = torch.float32
     if dtype in _DTYPE_TO_QVALUE_BOUNDS:
         # TODO: investigate why
         # (input - zero_point).to(torch.float32) * scale
         # failed the test
-        return (input.to(torch.float32) - zero_point) * scale
+        return (input.to(out_dtype) - zero_point) * scale
     else:
         raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")

+@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta")
+def dequantize_per_tensor_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
+    return torch.empty_like(input, dtype=out_dtype)

 quantized_decomposed_lib.define(
     "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
-    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")

 @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd")
 def dequantize_per_tensor_tensor(
@@ -187,7 +234,9 @@ def dequantize_per_tensor_tensor(
     zero_point: torch.Tensor,
     quant_min: int,
     quant_max: int,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     """ Affine dequantization for the Tensor using the same quantization parameters to map
     from quantized values to floating point values
@@ -196,22 +245,33 @@
     """
     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
+    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype, out_dtype=out_dtype)

 @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
-def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+def dequantize_per_tensor_tensor_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
     assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
     if dtype in _DTYPE_TO_QVALUE_BOUNDS:
-        return torch.empty_like(input, dtype=torch.float32)
+        return torch.empty_like(input, dtype=out_dtype)
     else:
         raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")

 # TODO: remove other variants and keep this one
 quantized_decomposed_lib.define(
     "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
-    "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor")
+    "Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")

 @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "CompositeExplicitAutograd")
 def dequantize_per_tensor_tensor2(
@@ -220,7 +280,9 @@ def dequantize_per_tensor_tensor2(
     zero_point: torch.Tensor,
     quant_min: torch.Tensor,
     quant_max: torch.Tensor,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     """ Affine dequantization for the Tensor using the same quantization parameters to map
     from quantized values to floating point values
@@ -229,11 +291,21 @@
     """
     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
+    return dequantize_per_tensor(
+        input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype, out_dtype=out_dtype)

 @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta")
-def dequantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
-    return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
+def dequantize_per_tensor_tensor2_meta(
+    input,
+    scale,
+    zero_point,
+    quant_min,
+    quant_max,
+    dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype)

 quantized_decomposed_lib.define(
     "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
@@ -415,7 +487,7 @@ def quantize_per_channel_meta(
 # We will revisit this later if we found there are no use cases for it
 quantized_decomposed_lib.define(
     "dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
-    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")

 @impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
 def dequantize_per_channel(
@@ -425,7 +497,9 @@ def dequantize_per_channel(
     axis: int,
     quant_min: int,
     quant_max: int,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     """ Affine per channel dequantization for the Tensor using the same quantization
     parameters for each channel/axis to map from quantized values to floating point values
@@ -450,20 +524,24 @@
        dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
        reserved for pattern matching)

+       out_dtype (torch.dtype?): optional dtype for output Tensor
+
     Returns:
        dequantized float32 Tensor
     """
     assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
+    if out_dtype is None:
+        out_dtype = torch.float32
     assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     input, permute_axis_list = _permute_to_axis_zero(input, axis)
-    res = torch.zeros_like(input, dtype=torch.float32)
+    res = torch.zeros_like(input, dtype=out_dtype)

     for i in range(input.size(0)):
         # TODO: investigate why
-        # (input[i] - zero_points[i]).to(torch.float32) * scales[i]
+        # (input[i] - zero_points[i]).to(out_dtype) * scales[i]
         # failed the test
-        res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i]
+        res[i] = (input[i].to(out_dtype) - zero_points[i]) * scales[i]

     out = res.permute(tuple(permute_axis_list))
     return out
@@ -476,12 +554,16 @@ def dequantize_per_channel_meta(
     axis: int,
     quant_min: int,
     quant_max: int,
-    dtype: torch.dtype
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
+    if out_dtype is None:
+        out_dtype = torch.float32
     assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
-    return torch.empty_like(input, dtype=torch.float32)
+    return torch.empty_like(input, dtype=out_dtype)

 quantized_decomposed_lib.define(
     "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
