
Generalize all annotators into a generic parameterizable annotator #7298


Merged · 1 commit · Jan 8, 2025
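This PR replaces the per-operator annotator modules in quantization_annotation/ (dispatched through OP_TO_ANNOTATOR in a fixed STATIC_ANNOTATION_ORDER) with a single data-driven annotate_graph pass in the new quantization_annotator.py. The rough idea: a table of per-op quantization properties consumed by one generic loop. The sketch below is illustrative only; the table layout is an assumption, not the actual contents of quantization_annotator.py, though the helper functions shown are used elsewhere in this diff.

import torch
from torch.fx import GraphModule, Node
# Import path for the _annotate_* helpers is assumed
# (torch.ao.quantization.quantizer.utils).
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated

# Hypothetical property table: which input argument indices get an input
# qspec, and whether the op's output gets an output qspec.
_OP_QUANT_PROPERTIES = {
    torch.ops.aten.add.Tensor: ([0, 1], True),
    torch.ops.aten.mul.Tensor: ([0, 1], True),
    torch.ops.aten.max_pool2d.default: ([0], True),
}

def annotate_graph(model: GraphModule, quantization_config, filter_fn=None) -> None:
    # One generic pass over the graph instead of one annotator module per op.
    for node in model.graph.nodes:
        if node.op != "call_function" or node.target not in _OP_QUANT_PROPERTIES:
            continue
        if filter_fn is not None and not filter_fn(node):
            continue
        input_indices, annotate_output = _OP_QUANT_PROPERTIES[node.target]
        for idx in input_indices:
            parent = node.args[idx]
            if isinstance(parent, Node):
                _annotate_input_qspec_map(
                    node, parent, quantization_config.get_input_act_qspec()
                )
        if annotate_output:
            _annotate_output_qspec(node, quantization_config.get_output_act_qspec())
        mark_node_as_annotated(node)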
12 changes: 11 additions & 1 deletion backends/arm/quantizer/TARGETS
@@ -5,12 +5,22 @@ python_library(
srcs = ["arm_quantizer.py"],
deps = [
":arm_quantizer_utils",
":quantization_annotator",
"//caffe2:torch",
"//executorch/backends/arm/quantizer/quantization_annotation:quantization_annotation",
"//executorch/exir:lib",
],
)

python_library(
    name = "quantization_annotator",
    srcs = ["quantization_annotator.py"],
    deps = [
        ":arm_quantizer_utils",
        ":quantization_config",
        "//caffe2:torch",
    ],
)

python_library(
name = "quantization_config",
srcs = ["quantization_config.py"],
123 changes: 10 additions & 113 deletions backends/arm/quantizer/arm_quantizer.py
@@ -13,24 +13,16 @@

from __future__ import annotations

import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager

from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.arm_quantizer_utils import (
    mark_nodes_as_annotated,
    propagate_annotation,
)
from executorch.backends.arm.quantizer.quantization_annotation import (
    OP_TO_ANNOTATOR,
    OperatorConfig,
    OperatorPatternType,
)
from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated
from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph

from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
@@ -58,44 +50,6 @@
]


def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
    supported_operators: Dict[str, List[OperatorPatternType]] = {
        # Both conv and linear should be able to handle relu + hardtanh fusion since
        # those are clamp ops
        "conv2d": [
            [torch.nn.Conv2d, torch.nn.ReLU],
            [torch.nn.Conv2d, F.relu],
            [F.conv2d, torch.nn.ReLU],
            [F.conv2d, F.relu],
        ],
        "linear": [[torch.nn.Linear], [F.linear]],
        "add": [[torch.add]],
        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
        "adaptive_avg_pool2d": [
            [torch.nn.AdaptiveAvgPool2d],
            [F.adaptive_avg_pool2d],
        ],
        "mul": [[torch.mul]],
        "sub": [[torch.sub]],
        "min_max": [[torch.min], [torch.max]],
    }
    return copy.deepcopy(supported_operators)


def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
    supported_config_and_operators: List[OperatorConfig] = []
    for quantization_config in [
        get_symmetric_quantization_config(),
        get_symmetric_quantization_config(is_per_channel=True),
    ]:
        ops = _supported_symmetric_quantized_operators()
        for pattern_list in ops.values():
            supported_config_and_operators.append(
                OperatorConfig(quantization_config, pattern_list)
            )
    return copy.deepcopy(supported_config_and_operators)


@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = False,
@@ -180,10 +134,6 @@ def get_symmetric_quantization_config(
    return quantization_config


def _get_supported_config_and_operators() -> List[OperatorConfig]:
    return _get_supported_symmetric_config_and_operators()
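Note that because get_symmetric_quantization_config is wrapped in functools.lru_cache, repeated calls with the same arguments return one shared QuantizationConfig instance:

per_channel = get_symmetric_quantization_config(is_per_channel=True)
assert per_channel is get_symmetric_quantization_config(is_per_channel=True)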


NodeFilterType = Callable[[Node], bool]
"""Type for a Node Filter used by annotators. A Node filter is a function that takes
a Node and returns whether the node should be annotated or not.
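For illustration, a node filter that restricts annotation to ReLU nodes could look like this (hypothetical example, not part of this diff):

import torch
from torch.fx import Node

def only_relu(n: Node) -> bool:
    # Annotate only aten.relu nodes; every other node is skipped.
    return n.op == "call_function" and n.target == torch.ops.aten.relu.default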
@@ -255,26 +205,6 @@ def not_module_type_or_name_filter(n: Node) -> bool:


class ArmQuantizer(Quantizer):
    supported_config_and_operators = _get_supported_config_and_operators()

    # A list of supported static quantization annotators, in order of application.
    # For example, fusions come before singular ops.
    # The name must match the name used when registering the annotator.
    STATIC_ANNOTATION_ORDER = [
        "linear",
        "conv",
        "adaptive_avg_pool2d",
        "max_pool2d",
        "add",
        "sub",
        "mul",
        "min_max",
        "mm",
        "one_to_one",
        "generic",
        "upsample_nearest2d",
    ]

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None
@@ -331,7 +261,6 @@ def annotate(self, model: GraphModule) -> GraphModule:
            The annotated model.
        """
        model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
@@ -353,8 +282,7 @@ def _annotate_all_static_patterns(
        if quantization_config is None:
            return model

        for op in self.STATIC_ANNOTATION_ORDER:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        annotate_graph(model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
@@ -363,6 +291,9 @@ def _annotate_for_static_quantization_config(
        """Matches the correct QuantizationConfig with the correct module using a filter
        when running _annotate_all_static_patterns.
        """
        if self.io_config:
            self._annotate_io(model, self.io_config)

        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_static_patterns(
@@ -381,9 +312,6 @@
            _get_not_module_type_or_name_filter(tp_list, module_name_list),
        )

        if self.io_config:
            self._annotate_io(model, self.io_config)

        return model
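For context, a typical configuration flow with this quantizer might look as follows. The set_global and set_io setters are assumed to follow the usual PT2E quantizer pattern, and exported_graph_module stands in for a torch.export'ed model; verify both against the actual ArmQuantizer API:

quantizer = ArmQuantizer()
config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(config)  # fallback config for all matched patterns
quantizer.set_io(config)      # also annotate model inputs and outputs
# With this change, io_config is applied before the per-module-name,
# per-module-type, and global passes, each of which now runs the single
# annotate_graph pass with its filter.
annotated_model = quantizer.annotate(exported_graph_module)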

    def _annotate_io(
@@ -399,44 +327,13 @@
                    node,
                    quantization_config.get_output_act_qspec(),
                )
                mark_nodes_as_annotated([node])
                mark_node_as_annotated(node)
            if node.op == "output":
                parent = node.all_input_nodes[0]
                _annotate_input_qspec_map(
                    node, parent, quantization_config.get_input_act_qspec()
                )
                mark_nodes_as_annotated([node])
                mark_node_as_annotated(node)

    def validate(self, model: GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return cls.supported_config_and_operators

    @classmethod
    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
        op_configs: Set[QuantizationConfig] = set({})
        for spec, _ in cls.supported_config_and_operators:
            op_configs.add(spec)
        return list(op_configs)

    @classmethod
    def get_supported_operator_for_quantization_config(
        cls, quantization_config: Optional[QuantizationConfig]
    ) -> List[OperatorPatternType]:
        if quantization_config is None:
            all_ops = []
            for _, ops in cls.supported_config_and_operators:
                all_ops.extend(ops)
            return all_ops

        for config, ops in cls.supported_config_and_operators:
            # note: this assumes each entry in cls.supported_spec_and_operators
            # corresponds to one spec, e.g. we don't have
            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
            # where the first and second entry have the same spec but did not
            # merge the op list
            if config == quantization_config:
                return ops
        return []