Commit d2d44e1
Author: Joey Tsai (committed)

[Qualcomm AI Engine Direct - Quantizer refine for QAT]
- Reorganize qualcomm/quantizer:
  - Split quantizer/utils.py into qconfig, annotators, and an observers directory
  - Update the corresponding call sites
- Rename get_default_Nbit_qnn_ptq_config to get_NaNw_qnn_ptq_config
- Add a 16a4w conv test (output is not yet compared against the original model)
1 parent ecdc007 commit d2d44e1
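
For callers affected by the rename, here is a minimal sketch of the updated imports, using only the paths that appear in the hunks below (the NaNw naming reads as N-bit activation / N-bit weight):

# Sketch of post-refactor imports; paths are taken from the custom_annotation.py hunk below.
from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
from executorch.backends.qualcomm.quantizer.quantizer import (
    get_8a8w_qnn_ptq_config,   # replaces get_default_8bit_qnn_ptq_config
    get_16a8w_qnn_ptq_config,  # 16-bit activations, 8-bit weights
)

# 8-bit activation / 8-bit weight PTQ config with symmetric activations,
# as used for the second matmul input in the diff below.
config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
config_16a8w = get_16a8w_qnn_ptq_config()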

File tree

12 files changed: +786 -601 lines changed


backends/qualcomm/quantizer/utils.py renamed to backends/qualcomm/quantizer/annotators.py

Lines changed: 32 additions & 431 deletions
Large diffs are not rendered by default.

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 3 additions & 6 deletions
@@ -6,15 +6,12 @@
 from typing import Sequence

 import torch
+from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a8w_qnn_ptq_config,
-    get_default_8bit_qnn_ptq_config,
+    get_8a8w_qnn_ptq_config,
     QuantizationConfig,
 )
-from executorch.backends.qualcomm.quantizer.utils import (
-    get_ptq_per_channel_quant_config,
-    QUANT_ANNOTATION_KEY,
-)
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -113,7 +110,7 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
     # Annotate 16a8w for matmul op to get better performance
     quantization_config_16a8w = get_16a8w_qnn_ptq_config()
     # Annotate 8a8w for second input of matmul until past_kv_cache
-    quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True)
+    quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             if "nn_module_stack" in node.meta:
Lines changed: 103 additions & 0 deletions
import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase


class ParamObserver(UniformQuantizationObserverBase):
    def __init__(
        self,
        ch_axis=0,
        use_mse=True,
        steps=100,
        dtype=torch.int8,
        qscheme=torch.per_channel_symmetric,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
        eps=torch.finfo(torch.float32).eps,  # noqa: B008
        is_dynamic=False,
        **kwargs,
    ) -> None:
        super().__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
            eps=eps,
            is_dynamic=is_dynamic,
            **kwargs,
        )

        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.ch_axis = ch_axis
        self.use_mse = use_mse
        self.steps = steps
        self.calibrated = False

    def to_ch_axis(self, x):
        axis_order = list(range(len(x.size())))
        axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis
        return torch.flatten(x.permute(axis_order), start_dim=1)

    def mse(self, pred, expect):
        loss = (pred - expect).abs().pow(2)
        return self.to_ch_axis(loss).mean(1)

    def cosine(self, pred, expect):
        target = torch.ones(pred.shape[self.ch_axis])
        pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1)
        expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1)
        return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target)

    def loss_fn(self, x, new_min, new_max):
        scale, offset = self._calculate_qparams(new_min, new_max)
        x_q = torch.fake_quantize_per_channel_affine(
            x,
            scale.data,
            offset.data.int(),
            self.ch_axis,
            self.quant_min,
            self.quant_max,
        )
        return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x)

    def line_search(self, x):
        x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1)
        x_range = torch.max(x_min.abs(), x_max)
        optimal_loss = torch.zeros_like(x_min) + 1e9

        # check which clip range could produce smallest loss
        for i in range(1, self.steps + 1):
            thres = x_range / self.steps * i
            current_loss = self.loss_fn(x, -thres, thres)
            x_min = torch.where(current_loss < optimal_loss, -thres, x_min)
            x_max = torch.where(current_loss < optimal_loss, thres, x_max)
            optimal_loss = torch.min(current_loss, optimal_loss)

        return x_min, x_max

    def forward(self, x_orig):
        # since params are static, one calibration is enough
        if not self.calibrated:
            x = x_orig.detach().to(self.min_val.dtype)
            self.min_val, self.max_val = self.line_search(x)
            self.calibrated = True

        # return fake-quant result for saturating outliers
        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
        return torch.fake_quantize_per_channel_affine(
            x_orig,
            scale.data,
            zero_point.data.int(),
            self.ch_axis,
            self.quant_min,
            self.quant_max,
        )

    @torch.jit.export
    def calculate_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)
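
This observer line-searches a symmetric per-channel clipping range that minimizes MSE (or cosine) reconstruction error on the first tensor it sees, then fake-quantizes to saturate outliers. As a rough illustration of how such an observer could be attached to weights through the standard torch.ao with_args/QuantizationSpec plumbing (the 4-bit range below is an illustrative assumption, not a value taken from this commit):

# Hedged sketch: ParamObserver as a per-channel weight observer in a QuantizationSpec.
import torch
from torch.ao.quantization.quantizer import QuantizationSpec

weight_qspec = QuantizationSpec(
    dtype=torch.int8,                 # illustrative 4-bit values stored in int8
    quant_min=-7,
    quant_max=7,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    observer_or_fake_quant_ctr=ParamObserver.with_args(
        ch_axis=0,
        use_mse=True,                 # line-search the clipping range by per-channel MSE
        steps=100,
    ),
)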

0 commit comments