Commit 08770b7

Generalize all annotators into a generic parameterizable annotator (#7298)
Change-Id: I69ecdbef9d7b83a87655e97758215303374b5f04
1 parent 9a23cff commit 08770b7

18 files changed: +339 -1098 lines changed

backends/arm/quantizer/TARGETS

Lines changed: 11 additions & 1 deletion
@@ -5,12 +5,22 @@ python_library(
     srcs = ["arm_quantizer.py"],
     deps = [
         ":arm_quantizer_utils",
+        ":quantization_annotator",
         "//caffe2:torch",
-        "//executorch/backends/arm/quantizer/quantization_annotation:quantization_annotation",
         "//executorch/exir:lib",
     ],
 )
 
+python_library(
+    name = "quantization_annotator",
+    srcs = ["quantization_annotator.py"],
+    deps = [
+        ":arm_quantizer_utils",
+        ":quantization_config",
+        "//caffe2:torch",
+    ],
+)
+
 python_library(
     name = "quantization_config",
     srcs = ["quantization_config.py"],
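The new quantization_annotator target replaces the per-op quantization_annotation package with a single module. As a rough illustration of the idea in the commit title, here is a minimal sketch of what a generic, parameterizable annotate_graph entry point could look like. Only the signature is grounded in this commit (it matches the call site in arm_quantizer.py below); the op table and annotation details are assumptions for illustration, not the actual contents of quantization_annotator.py.

# Illustrative sketch only; not the real quantization_annotator.py, which this diff does not show.
from typing import Callable, Optional

import torch
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import GraphModule, Node

# Hypothetical table: ops whose tensor inputs and output are all quantized with the
# activation qspecs from the active QuantizationConfig.
_ONE_TO_ONE_OPS = {
    torch.ops.aten.add.Tensor,
    torch.ops.aten.mul.Tensor,
    torch.ops.aten.max_pool2d.default,
}


def annotate_graph(
    model: GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
    """Single generic annotator: one pass over the graph instead of one pass per op."""
    for node in model.graph.nodes:
        if node.op != "call_function" or node.target not in _ONE_TO_ONE_OPS:
            continue
        if filter_fn is not None and not filter_fn(node):
            continue
        # Quantize every tensor input and the output of the matched node.
        input_qspec_map = {
            arg: quantization_config.get_input_act_qspec()
            for arg in node.args
            if isinstance(arg, Node)
        }
        node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.get_output_act_qspec(),
            _annotated=True,
        )

Under a scheme like this, supporting a new operator becomes a table entry rather than a new annotator module, which is presumably the motivation for the generalization.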

backends/arm/quantizer/arm_quantizer.py

Lines changed: 10 additions & 113 deletions
@@ -13,24 +13,16 @@
 
 from __future__ import annotations
 
-import copy
 import functools
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
-import torch.nn.functional as F
 from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
 
 from executorch.backends.arm.quantizer import arm_quantizer_utils
-from executorch.backends.arm.quantizer.arm_quantizer_utils import (
-    mark_nodes_as_annotated,
-    propagate_annotation,
-)
-from executorch.backends.arm.quantizer.quantization_annotation import (
-    OP_TO_ANNOTATOR,
-    OperatorConfig,
-    OperatorPatternType,
-)
+from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated
+from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph
+
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.fake_quantize import (
     FakeQuantize,
@@ -58,44 +50,6 @@
 ]
 
 
-def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
-    supported_operators: Dict[str, List[OperatorPatternType]] = {
-        # Both conv and linear should be able to handle relu + hardtanh fusion since
-        # those are clamp ops
-        "conv2d": [
-            [torch.nn.Conv2d, torch.nn.ReLU],
-            [torch.nn.Conv2d, F.relu],
-            [F.conv2d, torch.nn.ReLU],
-            [F.conv2d, F.relu],
-        ],
-        "linear": [[torch.nn.Linear], [F.linear]],
-        "add": [[torch.add]],
-        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
-        "adaptive_avg_pool2d": [
-            [torch.nn.AdaptiveAvgPool2d],
-            [F.adaptive_avg_pool2d],
-        ],
-        "mul": [[torch.mul]],
-        "sub": [[torch.sub]],
-        "min_max": [[torch.min], [torch.max]],
-    }
-    return copy.deepcopy(supported_operators)
-
-
-def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
-    supported_config_and_operators: List[OperatorConfig] = []
-    for quantization_config in [
-        get_symmetric_quantization_config(),
-        get_symmetric_quantization_config(is_per_channel=True),
-    ]:
-        ops = _supported_symmetric_quantized_operators()
-        for pattern_list in ops.values():
-            supported_config_and_operators.append(
-                OperatorConfig(quantization_config, pattern_list)
-            )
-    return copy.deepcopy(supported_config_and_operators)
-
-
 @functools.lru_cache
 def get_symmetric_quantization_config(
     is_per_channel: bool = False,
@@ -180,10 +134,6 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
-def _get_supported_config_and_operators() -> List[OperatorConfig]:
-    return _get_supported_symmetric_config_and_operators()
-
-
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
@@ -255,26 +205,6 @@ def not_module_type_or_name_filter(n: Node) -> bool:
 
 
 class ArmQuantizer(Quantizer):
-    supported_config_and_operators = _get_supported_config_and_operators()
-
-    # A list of supported static quantization annotators, in order of application.
-    # For example, fusions come before singular ops.
-    # The name must match the name used when registering the annotator.
-    STATIC_ANNOTATION_ORDER = [
-        "linear",
-        "conv",
-        "adaptive_avg_pool2d",
-        "max_pool2d",
-        "add",
-        "sub",
-        "mul",
-        "min_max",
-        "mm",
-        "one_to_one",
-        "generic",
-        "upsample_nearest2d",
-    ]
-
     def __init__(self) -> None:
         super().__init__()
         self.global_config: Optional[QuantizationConfig] = None
@@ -331,7 +261,6 @@ def annotate(self, model: GraphModule) -> GraphModule:
             The annotated model.
         """
         model = self._annotate_for_static_quantization_config(model)
-        propagate_annotation(model)
         return model
 
     def _annotate_all_static_patterns(
@@ -353,8 +282,7 @@ def _annotate_all_static_patterns(
         if quantization_config is None:
             return model
 
-        for op in self.STATIC_ANNOTATION_ORDER:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        annotate_graph(model, quantization_config, filter_fn)
         return model
 
     def _annotate_for_static_quantization_config(
@@ -363,6 +291,9 @@
         """Matches the correct QuantizationConfig with the correct module using a filter
         when running _annotate_all_static_patterns.
         """
+        if self.io_config:
+            self._annotate_io(model, self.io_config)
+
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
             self._annotate_all_static_patterns(
@@ -381,9 +312,6 @@ def _annotate_for_static_quantization_config(
                 _get_not_module_type_or_name_filter(tp_list, module_name_list),
             )
 
-        if self.io_config:
-            self._annotate_io(model, self.io_config)
-
         return model
 
     def _annotate_io(
@@ -399,44 +327,13 @@
                     node,
                     quantization_config.get_output_act_qspec(),
                 )
-                mark_nodes_as_annotated([node])
+                mark_node_as_annotated(node)
             if node.op == "output":
                 parent = node.all_input_nodes[0]
                 _annotate_input_qspec_map(
                     node, parent, quantization_config.get_input_act_qspec()
                 )
-                mark_nodes_as_annotated([node])
+                mark_node_as_annotated(node)
 
     def validate(self, model: GraphModule) -> None:
         pass
-
-    @classmethod
-    def get_supported_operators(cls) -> List[OperatorConfig]:
-        return cls.supported_config_and_operators
-
-    @classmethod
-    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
-        return list(op_configs)
-
-    @classmethod
-    def get_supported_operator_for_quantization_config(
-        cls, quantization_config: Optional[QuantizationConfig]
-    ) -> List[OperatorPatternType]:
-        if quantization_config is None:
-            all_ops = []
-            for _, ops in cls.supported_config_and_operators:
-                all_ops.extend(ops)
-            return all_ops
-
-        for config, ops in cls.supported_config_and_operators:
-            # note: this assumes each entry in cls.supported_spec_and_operators
-            # corresponds to one spec, e.g. we don't have
-            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
-            # where the first and second entry have the same spec but did not
-            # merge the op list
-            if config == quantization_config:
-                return ops
-        return []
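
With STATIC_ANNOTATION_ORDER and the get_supported_* class methods removed, annotation is driven entirely by annotate_graph. Below is a usage sketch of the quantizer after this change, assuming the standard PT2E flow (export_for_training, prepare_pt2e, convert_pt2e from torch) and ArmQuantizer's set_global API; none of this sketch is part of the diff above.

# Usage sketch (assumed flow, not taken from this commit).
import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

quantizer = ArmQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

# prepare_pt2e calls quantizer.annotate(), which now runs the single annotate_graph pass.
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration
quantized = convert_pt2e(prepared)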
