Commit b234dad

Generalize all annotators into a generic parameterizable annotator

Change-Id: I69ecdbef9d7b83a87655e97758215303374b5f04

1 parent 579d958

17 files changed: +337 −1049 lines

backends/arm/quantizer/TARGETS

Lines changed: 11 additions & 1 deletion
@@ -5,12 +5,22 @@ python_library(
     srcs = ["arm_quantizer.py"],
     deps = [
         ":arm_quantizer_utils",
+        ":quantization_annotator",
         "//caffe2:torch",
-        "//executorch/backends/arm/quantizer/quantization_annotation:quantization_annotation",
         "//executorch/exir:lib",
     ],
 )
 
+python_library(
+    name = "quantization_annotator",
+    srcs = ["quantization_annotator.py"],
+    deps = [
+        ":arm_quantizer_utils",
+        ":quantization_config",
+        "//caffe2:torch",
+    ],
+)
+
 python_library(
     name = "quantization_config",
     srcs = ["quantization_config.py"],

backends/arm/quantizer/arm_quantizer.py

Lines changed: 10 additions & 111 deletions
@@ -13,24 +13,16 @@
 
 from __future__ import annotations
 
-import copy
 import functools
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
-import torch.nn.functional as F
 from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
 
 from executorch.backends.arm.quantizer import arm_quantizer_utils
-from executorch.backends.arm.quantizer.arm_quantizer_utils import (
-    mark_nodes_as_annotated,
-    propagate_annotation,
-)
-from executorch.backends.arm.quantizer.quantization_annotation import (
-    OP_TO_ANNOTATOR,
-    OperatorConfig,
-    OperatorPatternType,
-)
+from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated
+from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph
+
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.fake_quantize import (
     FakeQuantize,
@@ -58,43 +50,6 @@
 ]
 
 
-def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
-    supported_operators: Dict[str, List[OperatorPatternType]] = {
-        # Both conv and linear should be able to handle relu + hardtanh fusion since
-        # those are clamp ops
-        "conv2d": [
-            [torch.nn.Conv2d, torch.nn.ReLU],
-            [torch.nn.Conv2d, F.relu],
-            [F.conv2d, torch.nn.ReLU],
-            [F.conv2d, F.relu],
-        ],
-        "linear": [[torch.nn.Linear], [F.linear]],
-        "add": [[torch.add]],
-        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
-        "adaptive_avg_pool2d": [
-            [torch.nn.AdaptiveAvgPool2d],
-            [F.adaptive_avg_pool2d],
-        ],
-        "mul": [[torch.mul]],
-        "sub": [[torch.sub]],
-    }
-    return copy.deepcopy(supported_operators)
-
-
-def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
-    supported_config_and_operators: List[OperatorConfig] = []
-    for quantization_config in [
-        get_symmetric_quantization_config(),
-        get_symmetric_quantization_config(is_per_channel=True),
-    ]:
-        ops = _supported_symmetric_quantized_operators()
-        for pattern_list in ops.values():
-            supported_config_and_operators.append(
-                OperatorConfig(quantization_config, pattern_list)
-            )
-    return copy.deepcopy(supported_config_and_operators)
-
-
 @functools.lru_cache
 def get_symmetric_quantization_config(
     is_per_channel: bool = False,
@@ -179,10 +134,6 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
-def _get_supported_config_and_operators() -> List[OperatorConfig]:
-    return _get_supported_symmetric_config_and_operators()
-
-
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
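
A note on the NodeFilterType alias kept by this hunk: a node filter is just a predicate over torch.fx nodes. As a minimal illustration (this helper is hypothetical and not part of the diff; torch and Node are already imported at the top of arm_quantizer.py):

    def _is_add_node(n: Node) -> bool:
        # Keep only aten add nodes; every other node is left unannotated.
        return n.target in (torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar)

The module-name and module-type filters passed around by _annotate_for_static_quantization_config further down are built the same way.
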
@@ -254,25 +205,6 @@ def not_module_type_or_name_filter(n: Node) -> bool:
 
 
 class ArmQuantizer(Quantizer):
-    supported_config_and_operators = _get_supported_config_and_operators()
-
-    # A list of supported static quantization annotators, in order of application.
-    # For example, fusions come before singular ops.
-    # The name must match the name used when registering the annotator.
-    STATIC_ANNOTATION_ORDER = [
-        "linear",
-        "conv",
-        "adaptive_avg_pool2d",
-        "max_pool2d",
-        "add",
-        "sub",
-        "mul",
-        "mm",
-        "one_to_one",
-        "generic",
-        "upsample_nearest2d",
-    ]
-
     def __init__(self) -> None:
         super().__init__()
         self.global_config: Optional[QuantizationConfig] = None
@@ -329,7 +261,6 @@ def annotate(self, model: GraphModule) -> GraphModule:
             The annotated model.
         """
         model = self._annotate_for_static_quantization_config(model)
-        propagate_annotation(model)
         return model
 
     def _annotate_all_static_patterns(
@@ -351,8 +282,7 @@ def _annotate_all_static_patterns(
         if quantization_config is None:
            return model
 
-        for op in self.STATIC_ANNOTATION_ORDER:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        annotate_graph(model, quantization_config, filter_fn)
         return model
 
     def _annotate_for_static_quantization_config(
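
This hunk is the core of the change: the fixed STATIC_ANNOTATION_ORDER dispatch through OP_TO_ANNOTATOR collapses into a single call to the new generic annotator. quantization_annotator.py itself is among the 17 changed files but is not shown in this view, so the following is only a sketch of what a generic, parameterizable annotate_graph could look like, inferred from its call signature here and the qspec getters used in _annotate_io; the _SUPPORTED_OPS table and the annotation logic are assumptions:

    from typing import Callable, Optional

    import torch
    from torch.ao.quantization.quantizer import QuantizationAnnotation
    from torch.fx import GraphModule, Node

    # Hypothetical: one table parameterizes annotation for every supported op,
    # replacing the per-op annotator modules of quantization_annotation.
    _SUPPORTED_OPS = {
        torch.ops.aten.add.Tensor,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.linear.default,
    }

    def annotate_graph(
        model: GraphModule,
        quantization_config,  # QuantizationConfig
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> None:
        for node in model.graph.nodes:
            if node.op != "call_function" or node.target not in _SUPPORTED_OPS:
                continue
            if filter_fn is not None and not filter_fn(node):
                continue
            # Quantize all tensor inputs and the output with the shared config.
            input_qspec_map = {
                arg: quantization_config.get_input_act_qspec()
                for arg in node.args
                if isinstance(arg, Node)
            }
            node.meta["quantization_annotation"] = QuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                output_qspec=quantization_config.get_output_act_qspec(),
                _annotated=True,
            )
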
@@ -361,6 +291,9 @@
         """Matches the correct QuantizationConfig with the correct module using a filter
         when running _annotate_all_static_patterns.
         """
+        if self.io_config:
+            self._annotate_io(model, self.io_config)
+
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
             self._annotate_all_static_patterns(
@@ -379,9 +312,6 @@ def _annotate_for_static_quantization_config(
             _get_not_module_type_or_name_filter(tp_list, module_name_list),
         )
 
-        if self.io_config:
-            self._annotate_io(model, self.io_config)
-
         return model
 
     def _annotate_io(
@@ -397,44 +327,13 @@ def _annotate_io(
                     node,
                     quantization_config.get_output_act_qspec(),
                 )
-                mark_nodes_as_annotated([node])
+                mark_node_as_annotated(node)
             if node.op == "output":
                 parent = node.all_input_nodes[0]
                 _annotate_input_qspec_map(
                     node, parent, quantization_config.get_input_act_qspec()
                 )
-                mark_nodes_as_annotated([node])
+                mark_node_as_annotated(node)
 
     def validate(self, model: GraphModule) -> None:
         pass
-
-    @classmethod
-    def get_supported_operators(cls) -> List[OperatorConfig]:
-        return cls.supported_config_and_operators
-
-    @classmethod
-    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
-        return list(op_configs)
-
-    @classmethod
-    def get_supported_operator_for_quantization_config(
-        cls, quantization_config: Optional[QuantizationConfig]
-    ) -> List[OperatorPatternType]:
-        if quantization_config is None:
-            all_ops = []
-            for _, ops in cls.supported_config_and_operators:
-                all_ops.extend(ops)
-            return all_ops
-
-        for config, ops in cls.supported_config_and_operators:
-            # note: this assumes each entry in cls.supported_spec_and_operators
-            # corresponds to one spec, e.g. we don't have
-            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
-            # where the first and second entry have the same spec but did not
-            # merge the op list
-            if config == quantization_config:
-                return ops
-        return []
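
End-to-end usage is unchanged by this commit; below is a sketch of the typical PT2E flow with the quantizer (set_global is defined elsewhere in arm_quantizer.py, outside these hunks, and the export call may differ by torch/ExecuTorch version):

    import torch
    from executorch.backends.arm.quantizer.arm_quantizer import (
        ArmQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 16),)

    quantizer = ArmQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

    # All supported ops are now annotated by the single generic annotator.
    exported = torch.export.export_for_training(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibration pass
    quantized = convert_pt2e(prepared)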
