7
7
import itertools
8
8
import operator
9
9
from dataclasses import dataclass
10
- from typing import Callable , Dict , List , Optional , Sequence , Tuple , Type , Union
10
+ from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Type , Union
11
11
12
12
import torch
13
13
import torch .nn .functional as F
@@ -123,12 +123,14 @@ def _derive_bias_qparams_fn(
123
123
def get_default_qnn_ptq_config (
124
124
enable_per_channel_conv_quant = False ,
125
125
) -> Tuple [QuantizationConfig , QnnQuantizerConfig ]:
126
+ extra_args : Dict [str , Any ] = {"eps" : 2 ** - 12 }
127
+
126
128
act_quantization_spec = QuantizationSpec (
127
129
dtype = torch .uint8 ,
128
130
quant_min = 0 ,
129
131
quant_max = 255 ,
130
132
qscheme = torch .per_tensor_affine ,
131
- observer_or_fake_quant_ctr = HistogramObserver .with_args (),
133
+ observer_or_fake_quant_ctr = HistogramObserver .with_args (** extra_args ),
132
134
)
133
135
134
136
weight_quantization_spec = QuantizationSpec (
@@ -137,15 +139,15 @@ def get_default_qnn_ptq_config(
137
139
quant_max = 127 ,
138
140
qscheme = torch .per_tensor_symmetric ,
139
141
ch_axis = 0 ,
140
- observer_or_fake_quant_ctr = MinMaxObserver .with_args (),
142
+ observer_or_fake_quant_ctr = MinMaxObserver .with_args (** extra_args ),
141
143
)
142
144
143
145
bias_quantization_spec = QuantizationSpec (
144
146
dtype = torch .int32 ,
145
147
quant_min = torch .iinfo (torch .int32 ).min ,
146
148
quant_max = torch .iinfo (torch .int32 ).max ,
147
149
qscheme = torch .per_tensor_symmetric ,
148
- observer_or_fake_quant_ctr = MinMaxObserver .with_args (),
150
+ observer_or_fake_quant_ctr = MinMaxObserver .with_args (** extra_args ),
149
151
)
150
152
151
153
quantization_config = QuantizationConfig (
@@ -163,12 +165,14 @@ def get_default_qnn_ptq_config(
163
165
164
166
165
167
def get_ptq_per_channel_weight_config () -> QuantizationConfig :
168
+ extra_args : Dict [str , Any ] = {"eps" : 2 ** - 12 }
169
+
166
170
act_quantization_spec = QuantizationSpec (
167
171
dtype = torch .uint8 ,
168
172
quant_min = 0 ,
169
173
quant_max = 255 ,
170
174
qscheme = torch .per_tensor_affine ,
171
- observer_or_fake_quant_ctr = HistogramObserver .with_args (),
175
+ observer_or_fake_quant_ctr = HistogramObserver .with_args (** extra_args ),
172
176
)
173
177
174
178
weight_quantization_spec = QuantizationSpec (
@@ -177,7 +181,7 @@ def get_ptq_per_channel_weight_config() -> QuantizationConfig:
177
181
quant_max = 127 ,
178
182
qscheme = torch .per_channel_symmetric ,
179
183
ch_axis = 0 ,
180
- observer_or_fake_quant_ctr = PerChannelMinMaxObserver .with_args (),
184
+ observer_or_fake_quant_ctr = PerChannelMinMaxObserver .with_args (** extra_args ),
181
185
)
182
186
183
187
bias_quantization_spec = _derived_bias_quant_spec
0 commit comments