# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import IntEnum, unique
- from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
+ from typing import Callable, Dict, Optional, Sequence, Set

import torch
from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid
from executorch.backends.qualcomm.passes.remove_clone import RemoveClone
from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer

- from torch import Tensor
from torch._ops import OpOverload
- from torch.ao.quantization.observer import (
-     HistogramObserver,
-     MinMaxObserver,
-     MovingAverageMinMaxObserver,
-     PerChannelMinMaxObserver,
+ from torch.ao.quantization.quantizer import Quantizer
+ from torch.fx import GraphModule
+
+ from .utils import (
+     get_16a4w_qnn_ptq_config,
+     get_default_16bit_qnn_ptq_config,
+     get_default_8bit_qnn_ptq_config,
+     get_ptq_per_channel_weight_config,
+     OP_ANNOTATOR,
+     QuantizationConfig,
)
- from torch.ao.quantization.quantizer import (
-     DerivedQuantizationSpec,
-     QuantizationSpec,
-     Quantizer,
- )
-
- from torch.fx import GraphModule, Node
-
- from .utils import OP_ANNOTATOR, QuantizationConfig

__all__ = [
    "QnnQuantizer",
@@ -54,205 +49,6 @@ class QuantDtype(IntEnum):
    use_8a8w = 2

- def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec:
-     def _derive_bias_qparams_fn(
-         obs_or_fqs: List,
-     ) -> Tuple[Tensor, Tensor]:
-         assert (
-             len(obs_or_fqs) == 2
-         ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
-         act_obs_or_fq = obs_or_fqs[0]
-         weight_obs_or_fq = obs_or_fqs[1]
-         weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
-         act_scale, act_zp = act_obs_or_fq.calculate_qparams()
-         (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors(
-             act_scale, weight_scale
-         )
-         derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
-         derived_zero = torch.zeros(derived_scale.size()).to(torch.int32)
-         return (derived_scale, derived_zero)
-
-     input_act = node.args[0]
-     assert isinstance(input_act, Node)
-     weight = node.args[1]
-     assert isinstance(weight, Node)
-
-     return DerivedQuantizationSpec(
-         derived_from=[(input_act, node), (weight, node)],
-         derive_qparams_fn=_derive_bias_qparams_fn,
-         dtype=torch.int32,
-         quant_min=torch.iinfo(torch.int32).min,
-         quant_max=torch.iinfo(torch.int32).max,
-         ch_axis=0,
-         qscheme=torch.per_channel_symmetric,
-     )
-
-
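Review note: the helper removed above never observes the bias tensor; it derives bias qparams from the activation and weight observers of the same node. A standalone sketch of that rule in plain torch, with illustrative scale values:

    import torch

    # Illustrative inputs: per-tensor activation scale, per-channel weight scales.
    act_scale = torch.tensor([0.02])
    weight_scale = torch.tensor([0.1, 0.05, 0.4])

    # Bias scale is the broadcast elementwise product of the two scales;
    # the zero point is fixed at 0, matching the int32 symmetric spec above.
    act_s, weight_s = torch.broadcast_tensors(act_scale, weight_scale)
    bias_scale = (act_s * weight_s).to(torch.float32)  # [0.0020, 0.0010, 0.0080]
    bias_zero_point = torch.zeros(bias_scale.size()).to(torch.int32)
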
- def get_default_8bit_qnn_ptq_config() -> QuantizationConfig:
-     extra_args: Dict[str, Any] = {"eps": 2**-12}
-
-     act_quantization_spec = QuantizationSpec(
-         dtype=torch.uint8,
-         quant_min=0,
-         quant_max=torch.iinfo(torch.uint8).max,
-         qscheme=torch.per_tensor_affine,
-         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
-     )
-
-     weight_quantization_spec = QuantizationSpec(
-         dtype=torch.int8,
-         quant_min=torch.iinfo(torch.int8).min + 1,
-         quant_max=torch.iinfo(torch.int8).max,
-         qscheme=torch.per_tensor_symmetric,
-         ch_axis=0,
-         observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-     )
-
-     bias_quantization_spec = QuantizationSpec(
-         dtype=torch.int32,
-         quant_min=torch.iinfo(torch.int32).min,
-         quant_max=torch.iinfo(torch.int32).max,
-         qscheme=torch.per_tensor_symmetric,
-         observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-     )
-
-     quantization_config = QuantizationConfig(
-         input_activation=act_quantization_spec,
-         output_activation=act_quantization_spec,
-         weight=weight_quantization_spec,
-         bias=bias_quantization_spec,
-     )
-
-     return quantization_config
-
-
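Review note: the removed 8-bit default pairs asymmetric uint8 activations over [0, 255] with symmetric per-tensor int8 weights whose quant_min is raised to -127 to keep the range symmetric around zero. A hedged check, assuming the helper is imported as in the sketch after the import hunk and that QuantizationConfig exposes its constructor arguments as attributes:

    config = get_default_8bit_qnn_ptq_config()
    assert config.input_activation.dtype == torch.uint8  # [0, 255] affine
    assert config.weight.quant_min == -127               # iinfo(int8).min + 1
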
- # 4-bit quantization only supports specific ops.
- def get_16a4w_qnn_ptq_config() -> QuantizationConfig:
-     extra_args: Dict[str, Any] = {"eps": 2**-20}
-     act_quantization_spec = QuantizationSpec(
-         dtype=torch.int32,
-         quant_min=torch.iinfo(torch.uint16).min,
-         quant_max=torch.iinfo(torch.uint16).max,
-         qscheme=torch.per_tensor_affine,
-         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
-     )
-
-     weight_quantization_spec = QuantizationSpec(
-         dtype=torch.int8,
-         quant_min=-7,
-         quant_max=7,
-         qscheme=torch.per_tensor_symmetric,
-         ch_axis=0,
-         observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-     )
-
-     bias_quantization_spec = QuantizationSpec(
-         dtype=torch.int32,
-         quant_min=torch.iinfo(torch.int32).min,
-         quant_max=torch.iinfo(torch.int32).max,
-         qscheme=torch.per_tensor_symmetric,
-         observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-     )
-
-     quantization_config = QuantizationConfig(
-         input_activation=act_quantization_spec,
-         output_activation=act_quantization_spec,
-         weight=weight_quantization_spec,
-         bias=bias_quantization_spec,
-     )
-
-     return quantization_config
-
-
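Review note: the 16a4w variant uses dtype=torch.int32 for activations only because torch has no uint16 quantized dtype; the [0, 65535] range is what QNN actually consumes. The 4-bit weight range [-7, 7] is the symmetric choice 2**(4 - 1) - 1 = 7 with -8 dropped. A quick arithmetic check in plain torch:

    import torch

    # uint16 range carried in an int32 container, as in the removed spec.
    assert torch.iinfo(torch.uint16).min == 0
    assert torch.iinfo(torch.uint16).max == 65535
    # Symmetric 4-bit weights drop -8: +/- (2**(4 - 1) - 1) == +/- 7.
    assert 2 ** (4 - 1) - 1 == 7
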
- def get_default_16bit_qnn_ptq_config() -> QuantizationConfig:
-     extra_args: Dict[str, Any] = {"eps": 2**-20}
-     act_quantization_spec = QuantizationSpec(
-         dtype=torch.int32,
-         quant_min=torch.iinfo(torch.uint16).min,
-         quant_max=torch.iinfo(torch.uint16).max,
-         qscheme=torch.per_tensor_affine,
-         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
-     )
-
-     weight_quantization_spec = QuantizationSpec(
-         dtype=torch.int16,
-         quant_min=torch.iinfo(torch.int16).min + 1,
-         quant_max=torch.iinfo(torch.int16).max,
-         qscheme=torch.per_tensor_symmetric,
-         ch_axis=0,
-         observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-     )
-
-     # torch does not support uint16 quantization; use int32 to bypass
-     bias_quantization_spec = QuantizationSpec(
-         dtype=torch.int32,
-         quant_min=torch.iinfo(torch.int32).min,
-         quant_max=torch.iinfo(torch.int32).max,
-         qscheme=torch.per_tensor_symmetric,
-         observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-     )
-
-     quantization_config = QuantizationConfig(
-         input_activation=act_quantization_spec,
-         output_activation=act_quantization_spec,
-         weight=weight_quantization_spec,
-         bias=bias_quantization_spec,
-     )
-
-     return quantization_config
-
-
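Review note: the 16-bit default follows the same pattern, uint16-range activations in an int32 container plus symmetric int16 weights with quant_min raised by one. A hedged check under the same import and attribute assumptions as above:

    config = get_default_16bit_qnn_ptq_config()
    assert config.input_activation.quant_max == torch.iinfo(torch.uint16).max  # 65535
    assert config.weight.quant_min == torch.iinfo(torch.int16).min + 1         # -32767
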
- def get_ptq_per_channel_weight_config(
-     act_dtype=torch.uint8, weight_dtype=torch.int8
- ) -> QuantizationConfig:
-     extra_args: Dict[str, Any] = {"eps": 2**-12}
-
-     supported_act_types = {
-         torch.uint8,
-         torch.uint16,
-         torch.int8,
-         torch.int16,
-     }
-     # TODO: accept "int4" temporarily. Remove "int4" once torch supports the torch.int4 dtype
-     supported_weight_dtypes = {"int4", torch.int8, torch.int16}
-     assert (
-         act_dtype in supported_act_types
-     ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
-
-     assert (
-         weight_dtype in supported_weight_dtypes
-     ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"
-
-     # torch does not support uint16 quantization; use int32 to bypass
-     act_quantization_spec = QuantizationSpec(
-         dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-         quant_min=torch.iinfo(act_dtype).min,
-         quant_max=torch.iinfo(act_dtype).max,
-         qscheme=torch.per_tensor_affine,
-         observer_or_fake_quant_ctr=HistogramObserver.with_args(**extra_args),
-     )
-
-     weight_quantization_spec = QuantizationSpec(
-         dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
-         quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
-         quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
-         qscheme=torch.per_channel_symmetric,
-         ch_axis=0,
-         observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
-     )
-
-     bias_quantization_spec = _derived_bias_quant_spec
-
-     quantization_config = QuantizationConfig(
-         input_activation=act_quantization_spec,
-         output_activation=act_quantization_spec,
-         weight=weight_quantization_spec,
-         bias=bias_quantization_spec,
-     )
-
-     return quantization_config
-
-
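Review note: the per-channel variant observes weights along the output-channel axis (ch_axis=0) and passes the _derived_bias_quant_spec function itself as the bias spec, so bias qparams are resolved per node during annotation. A hedged check under the same assumptions as above:

    config = get_ptq_per_channel_weight_config(act_dtype=torch.uint8, weight_dtype=torch.int8)
    assert config.weight.qscheme == torch.per_channel_symmetric
    assert callable(config.bias)  # resolved per node via _derived_bias_quant_spec
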
class QnnQuantizer(Quantizer):
    SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())