Skip to content

Commit 7445cba

Browse files
Author: Hao-Wei Hsu (committed)
Commit message: Add observer eps back for better accuracy
1 parent bfff94a · commit 7445cba

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

backends/qualcomm/qnn_quantizer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
 import itertools
 import operator
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

 import torch
 import torch.nn.functional as F
@@ -123,12 +123,14 @@ def _derive_bias_qparams_fn(
 def get_default_qnn_ptq_config(
     enable_per_channel_conv_quant=False,
 ) -> Tuple[QuantizationConfig, QnnQuantizerConfig]:
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
     act_quantization_spec = QuantizationSpec(
         dtype=torch.uint8,
         quant_min=0,
         quant_max=255,
         qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=HistogramObserver.with_args(),
+        observer_or_fake_quant_ctr=HistogramObserver.with_args(**extra_args),
     )

     weight_quantization_spec = QuantizationSpec(
@@ -137,15 +139,15 @@ def get_default_qnn_ptq_config(
         quant_max=127,
         qscheme=torch.per_tensor_symmetric,
         ch_axis=0,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(),
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
     )

     bias_quantization_spec = QuantizationSpec(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.int32).min,
         quant_max=torch.iinfo(torch.int32).max,
         qscheme=torch.per_tensor_symmetric,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(),
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
     )

     quantization_config = QuantizationConfig(
@@ -163,12 +165,14 @@ def get_default_qnn_ptq_config(


 def get_ptq_per_channel_weight_config() -> QuantizationConfig:
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
     act_quantization_spec = QuantizationSpec(
         dtype=torch.uint8,
         quant_min=0,
         quant_max=255,
         qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=HistogramObserver.with_args(),
+        observer_or_fake_quant_ctr=HistogramObserver.with_args(**extra_args),
     )

     weight_quantization_spec = QuantizationSpec(
@@ -177,7 +181,7 @@ def get_ptq_per_channel_weight_config() -> QuantizationConfig:
         quant_max=127,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
-        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(),
+        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
     )

     bias_quantization_spec = _derived_bias_quant_spec

0 commit comments

Comments
 (0)