
Commit 2a42737

andrewor14 authored and facebook-github-bot committed
Add choose_qparams_per_token_asymmetric for llama on XNNPACK
Summary: XNNPACK uses asymmetric activation quantization, but the existing `choose_qparams_per_token` assumed symmetric quantization (the zero point is always 0). This caused significant numerical discrepancies between eager and lowered models. This commit adds an asymmetric version of `choose_qparams_per_token` for this purpose.

Reviewed By: digantdesai

Differential Revision: D54323650

fbshipit-source-id: afd1e8f8b582bc8c07d4b03752ab71caa30c2bb0
1 parent 3ff0f77 commit 2a42737
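To see the discrepancy the summary describes, consider a non-negative activation token (e.g. a post-ReLU output): symmetric int8 quantization pins the zero point at 0 and must cover [-max, max], wasting half the levels, while an asymmetric scheme fits all 256 levels to the observed range. A minimal sketch of the difference (illustrative only, not code from this commit):

import torch

x = torch.rand(8) * 6.0  # non-negative "activation" token in [0, 6)

# Symmetric int8: zero point fixed at 0, so the grid must span [-max, max]
scale_sym = x.abs().max() / 127
dq_sym = torch.clamp((x / scale_sym).round(), -128, 127) * scale_sym

# Asymmetric int8: scale and zero point fitted to [min(x, 0), max(x, 0)]
min_val = torch.clamp(x.min(), max=0.0)
max_val = torch.clamp(x.max(), min=0.0)
scale_asym = (max_val - min_val) / 255
zero_point = (-128 - min_val / scale_asym).round()
q = torch.clamp((x / scale_asym + zero_point).round(), -128, 127)
dq_asym = (q - zero_point) * scale_asym

print((x - dq_sym).abs().max().item())   # worst-case error ~ scale_sym / 2
print((x - dq_asym).abs().max().item())  # roughly half the symmetric error

With min(x) = 0 the asymmetric step size is max/255 instead of 2*max/255, so the rounding error roughly halves; per-token, skewed ranges make the gap larger still.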

1 file changed: +68 −0 lines changed

examples/models/llama2/quantize.py

Lines changed: 68 additions & 0 deletions
@@ -189,6 +189,74 @@ def choose_qparams_per_token_meta(
)


# TODO: move this to https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
quantized_decomposed_lib.define(
    "choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token_asymmetric",
    "CompositeExplicitAutograd",
)
def choose_qparams_per_token_asymmetric(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose quantization parameters for per-token quantization. For an N-dimensional
    Tensor of shape (M1, M2, ..., Mn, N), we compute one scale/zero_point pair per
    group of N elements and quantize each group with its own parameters, so the
    scales/zero_points hold (M1 * M2 * ... * Mn) values.

    Args:
        input (torch.Tensor): original float32/float16 Tensor
        dtype (torch.dtype): dtype (e.g. torch.uint8) for the quantized Tensor

    Returns:
        scales and zero_points, both float32 Tensors
    """
    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
    qmin, qmax = -128, 127
    min_val, max_val = torch.aminmax(input, dim=-1, keepdim=True)
    # extend each token's range to include 0 so that 0.0 is exactly representable
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?

    # scale
    scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
    scale = scale.clamp(min=eps)

    # zero point
    descaled_min = min_val_neg / scale
    descaled_max = max_val_pos / scale
    zero_point_from_min_error = qmin + descaled_min
    zero_point_from_max_error = qmax + descaled_max
    zero_point = torch.where(
        zero_point_from_min_error + zero_point_from_max_error > 0,
        qmin - descaled_min,
        qmax - descaled_max,
    )
    zero_point = torch.clamp(zero_point, qmin, qmax).round()

    return scale.to(torch.float32), zero_point.to(torch.float32)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token_asymmetric",
    "Meta",
)
def choose_qparams_per_token_asymmetric_meta(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    size = (1, input.size(-1))
    return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
        size, dtype=torch.int64, device=input.device
    )


def _per_token_quant_qparam_dim_check(input, scales, zero_points):
    num_tokens = math.prod(list(input.size())[:-1])
    assert (
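Once the file above is imported, its define()/impl() calls register the op under the quantized_decomposed namespace. A hedged usage sketch follows; the import path and shapes are illustrative, and the round-trip at the end uses plain tensor math rather than the library's quantize ops:

import torch
# Illustrative import path; adjust to wherever quantize.py lives in your checkout.
# Importing the module runs the define()/impl() calls above and registers the op.
from executorch.examples.models.llama2.quantize import quantized_decomposed_lib  # noqa: F401

x = torch.randn(2, 4, 16)  # (batch, seq, hidden): one qparam pair per token

scale, zero_point = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
    x, torch.int8
)
print(scale.shape, zero_point.shape)  # torch.Size([2, 4, 1]) for both

# Round-trip with the chosen qparams: the torch.where in the eager kernel anchors
# the zero point at whichever end of each token's range gives the smaller error.
q = torch.clamp((x / scale + zero_point).round(), -128, 127)
dq = (q - zero_point) * scale
print((x - dq).abs().max().item())  # reconstruction error on the order of scale

The Meta kernel exists so that tracing/export with fake tensors can propagate shapes without running real math; its (1, N) double/int64 placeholders appear to mirror the existing symmetric choose_qparams_per_token_meta (visible in the hunk header) rather than the eager kernel's per-token float32 outputs.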
