@@ -1,39 +1,43 @@
 # mypy: allow-untyped-defs
 import itertools
-import typing
-from dataclasses import dataclass
-from typing import Callable, NamedTuple, Optional
+from typing import Callable, Optional

 import torch
 import torch.nn.functional as F
 from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
-from torch.ao.quantization.pt2e.export_utils import _WrapperModule
-from torch.ao.quantization.pt2e.utils import (
-    _get_aten_graph_module_for_pattern,
-    _is_conv_node,
-    _is_conv_transpose_node,
+from torch.fx import Node
+from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
+    SubgraphMatcherWithNameNodeMap,
 )
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e import WrapperModule
+from torchao.quantization.pt2e.graph_utils import get_source_partitions
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    get_bias_qspec,
+    get_input_act_qspec,
+    get_output_act_qspec,
+    get_weight_qspec,
+    OperatorConfig,
+    OperatorPatternType,
     QuantizationAnnotation,
+    QuantizationConfig,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-from torch.fx import Node
-from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
-    SubgraphMatcherWithNameNodeMap,
+from torchao.quantization.pt2e.utils import (
+    _get_aten_graph_module_for_pattern,
+    _is_conv_node,
+    _is_conv_transpose_node,
+    get_new_attr_name_with_prefix,
 )
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

 __all__ = [
     "OperatorConfig",
     "OperatorPatternType",
     "QuantizationConfig",
+    "QuantizationSpec",
     "get_input_act_qspec",
     "get_output_act_qspec",
     "get_weight_qspec",
@@ -43,23 +47,6 @@
 ]


-# In the absence of better name, just winging it with QuantizationConfig
-@dataclass(eq=True, frozen=True)
-class QuantizationConfig:
-    input_activation: Optional[QuantizationSpec]
-    output_activation: Optional[QuantizationSpec]
-    weight: Optional[QuantizationSpec]
-    bias: Optional[QuantizationSpec]
-    # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
-    is_qat: bool = False
-
-
-# Use Annotated because list[Callable].__module__ is read-only.
-OperatorPatternType = typing.Annotated[list[Callable], None]
-OperatorPatternType.__module__ = (
-    "executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils"
-)
-
 AnnotatorType = Callable[
     [
         torch.fx.GraphModule,
@@ -78,19 +65,6 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator


-class OperatorConfig(NamedTuple):
-    # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
-    # Basically we are mapping a quantization config to some list of patterns.
-    # a pattern is defined as a list of nn module, function or builtin function names
-    # e.g. [nn.Conv2d, torch.relu, torch.add]
-    # We have not resolved whether fusion can be considered internal details of the
-    # quantizer hence it does not need communication to user.
-    # Note this pattern is not really informative since it does not really
-    # tell us the graph structure resulting from the list of ops.
-    config: QuantizationConfig
-    operators: list[OperatorPatternType]
-
-
 def is_relu_node(node: Node) -> bool:
     """
     Check if a given node is a relu node
@@ -124,63 +98,6 @@ def _mark_nodes_as_annotated(nodes: list[Node]):
             node.meta["quantization_annotation"]._annotated = True


-def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    if quantization_config.input_activation is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.input_activation
-    assert quantization_spec.qscheme in [
-        torch.per_tensor_affine,
-        torch.per_tensor_symmetric,
-    ]
-    return quantization_spec
-
-
-def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    if quantization_config.output_activation is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.output_activation
-    assert quantization_spec.qscheme in [
-        torch.per_tensor_affine,
-        torch.per_tensor_symmetric,
-    ]
-    return quantization_spec
-
-
-def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    assert quantization_config is not None
-    if quantization_config.weight is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.weight
-    if quantization_spec.qscheme not in [
-        torch.per_tensor_symmetric,
-        torch.per_channel_symmetric,
-        None,
-    ]:
-        raise ValueError(
-            f"Unsupported quantization_spec {quantization_spec} for weight"
-        )
-    return quantization_spec
-
-
-def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    assert quantization_config is not None
-    if quantization_config.bias is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.bias
-    assert (
-        quantization_spec.dtype == torch.float
-    ), "Only float dtype for bias is supported for bias right now"
-    return quantization_spec
-
-
 @register_annotator("linear")
 def _annotate_linear(
     gm: torch.fx.GraphModule,
@@ -204,25 +121,25 @@ def _annotate_linear(
             bias_node = node.args[2]

         if _is_annotated([node]) is False:  # type: ignore[list-item]
-            _annotate_input_qspec_map(
+            annotate_input_qspec_map(
                 node,
                 act_node,
                 input_act_qspec,
             )
-            _annotate_input_qspec_map(
+            annotate_input_qspec_map(
                 node,
                 weight_node,
                 weight_qspec,
             )
             nodes_to_mark_annotated = [node, weight_node]
             if bias_node:
-                _annotate_input_qspec_map(
+                annotate_input_qspec_map(
                     node,
                     bias_node,
                     bias_qspec,
                 )
                 nodes_to_mark_annotated.append(bias_node)
-            _annotate_output_qspec(node, output_act_qspec)
+            annotate_output_qspec(node, output_act_qspec)
             _mark_nodes_as_annotated(nodes_to_mark_annotated)
             annotated_partitions.append(nodes_to_mark_annotated)

@@ -572,7 +489,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
                 "output": output,
             }

-        return _WrapperModule(_conv_bn)
+        return WrapperModule(_conv_bn)

     # Needed for matching, otherwise the matches gets filtered out due to unused
     # nodes returned by batch norm
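
Note: after this change the qspec helpers and QuantizationConfig come from torchao.quantization.pt2e.quantizer instead of the local copies removed above. A minimal usage sketch follows, assuming the torchao classes mirror the removed local definitions; the observer choice and the 8-bit ranges are illustrative, not part of this commit.

import torch
from torch.ao.quantization.observer import MinMaxObserver  # illustrative observer choice
from torchao.quantization.pt2e.quantizer import (
    get_input_act_qspec,
    get_weight_qspec,
    QuantizationConfig,
    QuantizationSpec,
)

# Illustrative 8-bit per-tensor specs; field names follow the removed local
# QuantizationConfig dataclass (input_activation, output_activation, weight, bias).
act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=MinMaxObserver,
)
weight_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=MinMaxObserver,
)
config = QuantizationConfig(act_qspec, act_qspec, weight_qspec, None)

# The getters validate the qscheme and hand back the corresponding spec (or None).
input_spec = get_input_act_qspec(config)
weight_spec = get_weight_qspec(config)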