
Commit 068f43c

chunit-quic and Joey Tsai authored
Qualcomm AI Engine Direct - Quantizer refine for qat (#6513)
* [Qualcomm AI Engine Direct - Quantizer refine for QAT]
  - Reorganize qualcomm/quantizer
  - Split quantizer/utils.py into
    -- qconfig
    -- annotators
    -- observers directory
  - Change corresponding callees
  - Rename get_default_Nbit_qnn_ptq_config to get_NaNw_qnn_ptq_config
  - Add 16a4w conv test (it is not compared with the original model)
* Fix based on comments
  - Move and rename param_observer.py to per_channel_param_observer.py
  - Add TODO to merge qconfig
* Add a comment
  - Add TODO for per_channel_param_observer.py
* [Fix lint]

Co-authored-by: Joey Tsai <[email protected]>
1 parent f7e26d7 commit 068f43c
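
The reorganization renames the PTQ config helpers and moves the annotation utilities out of quantizer/utils.py. As a minimal caller-side sketch (mirroring the custom_annotation.py diff below, not itself part of this commit), downstream code would migrate its imports roughly like this:

    # Before (old layout, per the removed lines in the diff below):
    # from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
    # from executorch.backends.qualcomm.quantizer.quantizer import get_default_8bit_qnn_ptq_config

    # After this commit:
    from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
    from executorch.backends.qualcomm.quantizer.quantizer import get_8a8w_qnn_ptq_config

    # The renamed getter keeps the same call signature used in custom_annotation.py.
    quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)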

File tree

13 files changed: +790 -584 lines changed


backends/qualcomm/quantizer/utils.py renamed to backends/qualcomm/quantizer/annotators.py

Lines changed: 32 additions & 415 deletions
Large diffs are not rendered by default.

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 3 additions & 3 deletions
@@ -6,12 +6,12 @@
 from typing import Sequence

 import torch
+from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a8w_qnn_ptq_config,
-    get_default_8bit_qnn_ptq_config,
+    get_8a8w_qnn_ptq_config,
     QuantizationConfig,
 )
-from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -110,7 +110,7 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
     # Annotate 16a8w for matmul op to get better performance
     quantization_config_16a8w = get_16a8w_qnn_ptq_config()
     # Annotate 8a8w for second input of matmul until past_kv_cache
-    quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True)
+    quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             if "nn_module_stack" in node.meta:
backends/qualcomm/quantizer/observers/per_channel_param_observer.py (new file)

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase


# TODO move to torch/ao/quantization/observer.py.
class PerChannelParamObserver(UniformQuantizationObserverBase):
    def __init__(
        self,
        ch_axis=0,
        use_mse=True,
        steps=100,
        dtype=torch.int8,
        qscheme=torch.per_channel_symmetric,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
        eps=torch.finfo(torch.float32).eps,  # noqa: B008
        is_dynamic=False,
        **kwargs,
    ) -> None:
        super().__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
            eps=eps,
            is_dynamic=is_dynamic,
            **kwargs,
        )

        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.ch_axis = ch_axis
        self.use_mse = use_mse
        self.steps = steps
        self.calibrated = False

    def to_ch_axis(self, x):
        axis_order = list(range(len(x.size())))
        axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis
        return torch.flatten(x.permute(axis_order), start_dim=1)

    def mse(self, pred, expect):
        loss = (pred - expect).abs().pow(2)
        return self.to_ch_axis(loss).mean(1)

    def cosine(self, pred, expect):
        target = torch.ones(pred.shape[self.ch_axis])
        pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1)
        expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1)
        return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target)

    def loss_fn(self, x, new_min, new_max):
        scale, offset = self._calculate_qparams(new_min, new_max)
        x_q = torch.fake_quantize_per_channel_affine(
            x,
            scale.data,
            offset.data.int(),
            self.ch_axis,
            self.quant_min,
            self.quant_max,
        )
        return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x)

    def line_search(self, x):
        x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1)
        x_range = torch.max(x_min.abs(), x_max)
        optimal_loss = torch.zeros_like(x_min) + 1e9

        # check which clip range could produce smallest loss
        for i in range(1, self.steps + 1):
            thres = x_range / self.steps * i
            current_loss = self.loss_fn(x, -thres, thres)
            x_min = torch.where(current_loss < optimal_loss, -thres, x_min)
            x_max = torch.where(current_loss < optimal_loss, thres, x_max)
            optimal_loss = torch.min(current_loss, optimal_loss)

        return x_min, x_max

    def forward(self, x_orig):
        # since params are static, one calibration is enough
        if not self.calibrated:
            x = x_orig.detach().to(self.min_val.dtype)
            self.min_val, self.max_val = self.line_search(x)
            self.calibrated = True

        # return fake-quant result for saturating outliers
        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
        return torch.fake_quantize_per_channel_affine(
            x_orig,
            scale.data,
            zero_point.data.int(),
            self.ch_axis,
            self.quant_min,
            self.quant_max,
        )

    @torch.jit.export
    def calculate_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)
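
As a usage illustration (not part of the commit), the observer can be exercised standalone on a weight tensor: the first forward call runs the per-channel line search and calibrates min_val/max_val, and later calls only fake-quantize with that range. The tensor shape below is an arbitrary example:

    import torch

    # Hypothetical standalone check of PerChannelParamObserver.
    observer = PerChannelParamObserver(
        ch_axis=0,
        use_mse=True,
        steps=100,
        dtype=torch.int8,
        qscheme=torch.per_channel_symmetric,
    )

    weight = torch.randn(16, 3, 3, 3)  # e.g. a conv2d weight, output channels on axis 0
    _ = observer(weight)               # calibrates min_val/max_val via line_search
    scale, zero_point = observer.calculate_qparams()
    print(scale.shape, zero_point.shape)  # one (scale, zero_point) pair per output channel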
