
Commit e95f171

Qualcomm AI Engine Direct - Quantizer refine for qat
Differential Revision: D65738212
Pull Request resolved: #6747
1 parent 19268de · commit e95f171

File tree

12 files changed: +797 −602 lines

backends/qualcomm/quantizer/utils.py renamed to backends/qualcomm/quantizer/annotators.py

Lines changed: 32 additions & 431 deletions
Large diffs are not rendered by default.

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 4 additions & 6 deletions
@@ -6,14 +6,12 @@
 from typing import Sequence
 
 import torch
+from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a8w_qnn_ptq_config,
-    get_default_8bit_qnn_ptq_config,
-    QuantizationConfig,
-)
-from executorch.backends.qualcomm.quantizer.utils import (
+    get_8a8w_qnn_ptq_config,
     get_ptq_per_channel_quant_config,
-    QUANT_ANNOTATION_KEY,
+    QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.ao.quantization.quantizer import (
@@ -113,7 +111,7 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
     # Annotate 16a8w for matmul op to get better performance
     quantization_config_16a8w = get_16a8w_qnn_ptq_config()
     # Annotate 8a8w for second input of matmul until past_kv_cache
-    quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True)
+    quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             if "nn_module_stack" in node.meta:
New file — Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase


# TODO move to torch/ao/quantization/observer.py.
class PerChannelParamObserver(UniformQuantizationObserverBase):
    """
    Minimize quantization loss caused by outliers via linear search.
    More details can be found at https://arxiv.org/pdf/2209.13325
    """

    def __init__(
        self,
        ch_axis=0,
        use_mse=True,
        steps=100,
        dtype=torch.int8,
        qscheme=torch.per_channel_symmetric,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
        eps=torch.finfo(torch.float32).eps,  # noqa: B008
        is_dynamic=False,
        **kwargs,
    ) -> None:
        super().__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
            eps=eps,
            is_dynamic=is_dynamic,
            **kwargs,
        )

        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.ch_axis = ch_axis
        self.use_mse = use_mse
        self.steps = steps
        self.calibrated = False

    def to_ch_axis(self, x):
        axis_order = list(range(len(x.size())))
        axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis
        return torch.flatten(x.permute(axis_order), start_dim=1)

    def mse(self, pred, expect):
        loss = (pred - expect).abs().pow(2)
        return self.to_ch_axis(loss).mean(1)

    def cosine(self, pred, expect):
        target = torch.ones(pred.shape[self.ch_axis])
        pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1)
        expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1)
        return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target)

    def loss_fn(self, x, new_min, new_max):
        scale, offset = self._calculate_qparams(new_min, new_max)
        x_q = torch.fake_quantize_per_channel_affine(
            x,
            scale.data,
            offset.data.int(),
            self.ch_axis,
            self.quant_min,
            self.quant_max,
        )
        return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x)

    def line_search(self, x):
        x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1)
        x_range = torch.max(x_min.abs(), x_max)
        optimal_loss = torch.zeros_like(x_min) + 1e9

        # check which clip range produces the smallest loss
        for i in range(1, self.steps + 1):
            thres = x_range / self.steps * i
            current_loss = self.loss_fn(x, -thres, thres)
            x_min = torch.where(current_loss < optimal_loss, -thres, x_min)
            x_max = torch.where(current_loss < optimal_loss, thres, x_max)
            optimal_loss = torch.min(current_loss, optimal_loss)

        return x_min, x_max

    def forward(self, x_orig):
        # since params are static, one calibration is enough
        if not self.calibrated:
            x = x_orig.detach().to(self.min_val.dtype)
            self.min_val, self.max_val = self.line_search(x)
            self.calibrated = True

        # return the fake-quantized result so outliers are saturated
        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
        return torch.fake_quantize_per_channel_affine(
            x_orig,
            scale.data,
            zero_point.data.int(),
            self.ch_axis,
            self.quant_min,
            self.quant_max,
        )

    @torch.jit.export
    def calculate_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)
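
Since the observer calibrates on its first forward call and then only saturates outliers, exercising it end to end takes a few lines. A minimal usage sketch, assuming the class above is importable from its new module; the tensor shapes and the injected outlier are made up purely for illustration:

import torch

# Toy weight: 8 output channels x 16 inputs, with one outlier that a plain
# min/max observer would let dominate the scale of channel 0.
weight = 0.05 * torch.randn(8, 16)
weight[0, 0] = 4.0

observer = PerChannelParamObserver(
    ch_axis=0,
    dtype=torch.int8,
    qscheme=torch.per_channel_symmetric,
)

fq_weight = observer(weight)  # first call runs the line search and caches min/max
scale, zero_point = observer.calculate_qparams()
print(scale.shape, zero_point.shape)  # torch.Size([8]) torch.Size([8])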
