 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer


-act_qspec = QuantizationSpec(
-    dtype=torch.uint8,
-    quant_min=0,
-    quant_max=255,
+act_qspec_asym8u = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
     qscheme=torch.per_tensor_affine,
     is_dynamic=False,
     observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
 )

-wgt_qspec = QuantizationSpec(
-    dtype=torch.uint8,
-    quant_min=0,
-    quant_max=255,
+wgt_qspec_asym8u = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
     qscheme=torch.per_tensor_affine,
     is_dynamic=False,
     observer_or_fake_quant_ctr=MinMaxObserver,
 )

+wgt_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=MinMaxObserver,
+)
+
 bias_qspec: Optional[QuantizationSpec] = None

-_default_qconfig = QuantizationConfig(
-    act_qspec,
-    act_qspec,
-    wgt_qspec,
+qconfig_A8uW8u = QuantizationConfig(
+    act_qspec_asym8u,
+    act_qspec_asym8u,
+    wgt_qspec_asym8u,
     None,
 )

+qconfig_A8uW8s = QuantizationConfig(
+    act_qspec_asym8u,
+    act_qspec_asym8u,
+    wgt_qspec_asym8s,
+    None,
+)

 class CadenceAtenQuantizer(Quantizer):
     def __init__(
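The practical difference between the two new weight specs is the qscheme: per_tensor_affine allows a non-zero zero_point, while per_tensor_symmetric pins the zero_point to 0 so the int8 grid is centered on 0.0. A minimal sketch of that distinction, using illustrative scale/zero_point values rather than anything produced by the observers configured above:

import torch

# Illustrative values only; the real scale/zero_point come from the
# MinMaxObserver attached to the weight specs in this file.
w = torch.tensor([-1.0, -0.25, 0.0, 0.5, 1.0])

# per_tensor_affine: zero_point may be non-zero.
q_affine = torch.quantize_per_tensor(w, scale=1 / 128, zero_point=10, dtype=torch.qint8)

# per_tensor_symmetric: zero_point is fixed at 0.
q_symmetric = torch.quantize_per_tensor(w, scale=1 / 127, zero_point=0, dtype=torch.qint8)

print(q_affine.int_repr())     # e.g. [-118,  -22,   10,   74,  127]
print(q_symmetric.int_repr())  # e.g. [-127,  -32,    0,   64,  127]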
@@ -147,19 +162,17 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
         return []


-def get_cadence_default_quantizer_list_with_config(
-    quantization_config: QuantizationConfig,
-) -> List[Quantizer]:
+def get_cadence_default_quantizers() -> List[Quantizer]:
     return [
-        CadenceAtenQuantizer(AddmmPattern(), quantization_config),
-        CadenceAtenQuantizer(BmmPattern(), quantization_config),
-        CadenceAtenQuantizer(Conv1dPattern(), quantization_config),
-        CadenceAtenQuantizer(Conv2dPattern(), quantization_config),
-        CadenceAtenQuantizer(LayerNormPattern(), quantization_config),
-        CadenceAtenQuantizer(LinearPattern(), quantization_config),
-        CadenceAtenQuantizer(MatmulPattern(), quantization_config),
-        CadenceAtenQuantizer(ReluPattern0(), quantization_config),
-        CadenceAtenQuantizer(ReluPattern1(), quantization_config),
+        CadenceAtenQuantizer(AddmmPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(BmmPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8uW8s),
+        CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8uW8s),
+        CadenceAtenQuantizer(LayerNormPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(LinearPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(MatmulPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(ReluPattern0(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(ReluPattern1(), qconfig_A8uW8u),
     ]


@@ -178,10 +191,9 @@ class CadenceDefaultQuantizer(CadenceQuantizer):
     Default quantizer for Cadence backend.
     """

-    def __init__(self, qconfig: Optional[QuantizationConfig] = None) -> None:
-        if qconfig is None:
-            qconfig = _default_qconfig
-        quantizers = get_cadence_default_quantizer_list_with_config(qconfig)
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = get_cadence_default_quantizers()
         super().__init__(quantizers)


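With this change, callers configure CadenceDefaultQuantizer by passing a list of per-pattern quantizers instead of a single QuantizationConfig; by default it picks up the list built by get_cadence_default_quantizers(). A minimal usage sketch, assuming the names below are importable from the Cadence quantizer and pattern modules (the import paths are assumptions, not part of this diff):

# Assumed import paths; adjust to the actual Cadence backend layout.
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceAtenQuantizer,
    CadenceDefaultQuantizer,
    qconfig_A8uW8s,
)
from executorch.backends.cadence.aot.quantizer.patterns import LinearPattern

# Default behavior: use the per-pattern quantizers from get_cadence_default_quantizers().
quantizer = CadenceDefaultQuantizer()

# Custom behavior: inject an explicit quantizer list, e.g. quantize only linear
# layers with the same config this diff assigns to the convolution patterns.
linear_only_quantizer = CadenceDefaultQuantizer(
    quantizers=[CadenceAtenQuantizer(LinearPattern(), qconfig_A8uW8s)]
)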