Commit 3ff0f77

andrewor14 authored and facebook-github-bot committed
Fix llama quantize_per_token numerics
Summary:
The existing implementation can produce quantized values outside the quant range, since we add the zero points after clamping. This was not a problem for symmetric quantization, where the zero points are 0, but it causes dqlinear numerics to diverge significantly from the lowered implementation for asymmetric quantization.

Reviewed By: digantdesai

Differential Revision: D54320424

fbshipit-source-id: e8d9136354b0dac1993ef7825fc331f68d0d4c05
1 parent: 9283e50 · commit: 3ff0f77

1 file changed (+3, −2 lines)

examples/models/llama2/quantize.py

Lines changed: 3 additions & 2 deletions
@@ -234,8 +234,9 @@ def quantize_per_token(
     """
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     _per_token_quant_qparam_dim_check(input, scales, zero_points)
-    input = torch.round(input / scales).clamp(quant_min, quant_max).to(dtype)
-    input = input + zero_points
+    input = (
+        torch.round(input / scales + zero_points).clamp(quant_min, quant_max).to(dtype)
+    )
     return input
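As a rough illustration of the numerics issue described in the summary, here is a minimal sketch, not taken from the repository: the tensor values, shapes, and quant parameters are made up purely to show how clamping before adding the zero points can leave the quant range, whereas the fixed ordering clamps last and stays in range.

import torch

# Hypothetical per-token quantization parameters (asymmetric int8), chosen only
# so the out-of-range case is visible; not values from the repository.
x = torch.tensor([[12.5, -0.3]])      # one token, two values
scales = torch.tensor([[0.1]])        # per-token scale
zero_points = torch.tensor([[10]])    # nonzero zero point -> asymmetric
quant_min, quant_max = -128, 127

# Old ordering: clamp first, then add the zero points. round(12.5 / 0.1) = 125
# lies inside [-128, 127], but 125 + 10 = 135 does not, so the "quantized"
# value escapes the quant range. (The cast to int8 is omitted here so the
# out-of-range value stays visible instead of wrapping.)
old = torch.round(x / scales).clamp(quant_min, quant_max) + zero_points
print(old)  # tensor([[135., 7.]]) -- 135 is outside the int8 quant range

# Fixed ordering (what the commit does): add the zero points before clamping,
# so the result is guaranteed to land inside [quant_min, quant_max].
new = torch.round(x / scales + zero_points).clamp(quant_min, quant_max).to(torch.int8)
print(new)  # tensor([[127, 7]], dtype=torch.int8)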
