Arm backend: Add and use ArmQuantizer #2561

Closed

4 changes: 4 additions & 0 deletions backends/arm/README.md
@@ -20,6 +20,10 @@ kend Architecture](#arm-backend-architecture). For examples of use see `executor
- `tosa_mapping.py` - utilities for mapping edge dialect to TOSA
- `tosa_quant_utils.py` - utilities for mapping quantization information to TOSA encoding

Quantization:
- `arm_quantizer.py` - Quantizer for Arm backend
- `arm_quantizer_utils.py` - Utilities for quantization

Runtime:
- `runtime/ArmBackendEthosU.cpp` - The Arm backend implementation of the ExecuTorch runtime backend (PyTorchBackendInterface) for Ethos-U

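For orientation, here is a minimal sketch of how the two new quantization modules listed above are meant to be used, mirroring the call sites changed later in this diff; the snippet is illustrative and not part of the PR itself.

```python
# Illustrative sketch (not part of the PR): construct the new ArmQuantizer and
# apply a global per-tensor symmetric config, as the examples and tests below do.
from executorch.backends.arm.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)

quantizer = ArmQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))
```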
31 changes: 18 additions & 13 deletions backends/arm/arm_quantizer.py
@@ -1,9 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Quantizer for Arm backend
#

from __future__ import annotations

import copy
@@ -14,6 +19,15 @@
import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F

from executorch.backends.arm.arm_quantizer_utils import (
_convert_scalars_to_attrs,
OP_TO_ANNOTATOR,
OperatorConfig,
OperatorPatternType,
propagate_annotation,
QuantizationConfig,
)
from torch.ao.quantization.fake_quantize import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
@@ -31,20 +45,11 @@

from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
_convert_scalars_to_attrs,
OP_TO_ANNOTATOR,
OperatorConfig,
OperatorPatternType,
propagate_annotation,
QuantizationConfig,
)

from torch.fx import Node


__all__ = [
"XNNPACKQuantizer",
"ArmQuantizer",
"get_symmetric_quantization_config",
]

@@ -260,7 +265,7 @@ def not_module_type_or_name_filter(n: Node) -> bool:
return not_module_type_or_name_filter


class XNNPACKQuantizer(Quantizer):
class ArmQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()
STATIC_QAT_ONLY_OPS = [
"conv_bn_relu",
@@ -325,15 +330,15 @@ def get_supported_operator_for_quantization_config(
return ops
return []

def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
def set_global(self, quantization_config: QuantizationConfig) -> ArmQuantizer:
self.global_config = quantization_config
return self

def set_operator_type(
self,
operator_type: torch._ops.OpOverloadPacket,
quantization_config: QuantizationConfig,
) -> XNNPACKQuantizer:
) -> ArmQuantizer:
self.operator_type_config[operator_type] = quantization_config
return self

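Note that set_global and set_operator_type both return the quantizer (see the -> ArmQuantizer annotations above), so configuration can be written as a fluent chain. A sketch under the assumption that a per-operator override is wanted; the specific operator is chosen only for illustration:

```python
# Sketch only: chain a global config with a per-operator override.
import torch

from executorch.backends.arm.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)

config = get_symmetric_quantization_config(is_per_channel=False)
quantizer = (
    ArmQuantizer()
    .set_global(config)
    # Operator choice is illustrative; set_operator_type takes an OpOverloadPacket.
    .set_operator_type(torch.ops.aten.conv2d, config)
)
```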
11 changes: 7 additions & 4 deletions backends/arm/arm_quantizer_utils.py
@@ -1,9 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Utility functions for ArmQuantizer
#

import itertools
import operator
from dataclasses import dataclass
@@ -62,9 +67,7 @@ class QuantizationConfig:


OperatorPatternType = List[Callable]
OperatorPatternType.__module__ = (
"torch.ao.quantization.quantizer.xnnpack_quantizer_utils"
)
OperatorPatternType.__module__ = "executorch.backends.arm.arm_quantizer_utils"

AnnotatorType = Callable[
[
@@ -604,7 +607,7 @@ def _annotate_max_pool2d(
maxpool_node = n
assert (
maxpool_node is not None
), "XNNPACKQuantizer only works with torch.ops.aten.max_pool2d.default, "
), "ArmQuantizer only works with torch.ops.aten.max_pool2d.default, "
"please make sure you are exporting the model correctly"
if _is_annotated([output_node, maxpool_node]): # type: ignore[list-item]
continue
13 changes: 6 additions & 7 deletions backends/arm/test/test_tosa.py
@@ -25,15 +25,14 @@
_check_ir_validity=False,
)

from executorch.exir import EdgeCompileConfig
from executorch.exir.program import to_edge
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

## For quantization
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
from executorch.backends.arm.arm_quantizer import (
ArmQuantizer,
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from executorch.exir import EdgeCompileConfig
from executorch.exir.program import to_edge
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TestBasicNN(unittest.TestCase):
@@ -88,7 +87,7 @@ def prepare_model_and_ref(test_model, profile=TosaProfile.MI):
model, copy.deepcopy(model.inputs[profile])
)
# Setup the quantizer
quantizer = XNNPACKQuantizer()
quantizer = ArmQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)

11 changes: 5 additions & 6 deletions backends/arm/test/tester/arm_tester.py
@@ -13,6 +13,10 @@
)

from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.backends.arm.arm_quantizer import (
ArmQuantizer,
get_symmetric_quantization_config,
)

from executorch.backends.arm.test.tosautil.tosa_test_utils import (
QuantizationParams,
@@ -29,10 +33,6 @@
)

from executorch.exir import EdgeCompileConfig
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import ExportedProgram


@@ -88,9 +88,8 @@ def __init__(

def quantize(self, quantize_stage: Optional[Quantize] = None):
if quantize_stage is None:
# Using the XNNPACKQuantizer for now
quantize_stage = Quantize(
XNNPACKQuantizer(),
ArmQuantizer(),
get_symmetric_quantization_config(is_per_channel=False),
)
return super().quantize(quantize_stage)
13 changes: 6 additions & 7 deletions examples/arm/aot_arm_compiler.py
@@ -23,21 +23,20 @@
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.WARNING, format=FORMAT)

# Quantize model if required using the standard export quantization flow.
# For now we're using the xnnpack quantizer as this produces reasonable
# output for our arithmetic behaviour.
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
from executorch.backends.arm.arm_quantizer import (
ArmQuantizer,
get_symmetric_quantization_config,
XNNPACKQuantizer,
)

# Quantize model if required using the standard export quantization flow.
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize(model, example_inputs):
"""This is the official recommended flow for quantization in pytorch 2.0 export"""
logging.info("Quantizing Model...")
logging.debug(f"Original model: {model}")
quantizer = XNNPACKQuantizer()
quantizer = ArmQuantizer()
# if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
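The hunk above ends at quantizer.set_global(...); for completeness, here is a hedged sketch of the rest of the standard PT2E flow that quantize() feeds into. The calibration call and the prior graph-capture step are assumptions, not shown in this diff.

```python
# Sketch of the surrounding PT2E flow, assuming `captured_model` is the graph
# module produced by the export/capture step that precedes quantization.
from executorch.backends.arm.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_sketch(captured_model, example_inputs):
    quantizer = ArmQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

    prepared = prepare_pt2e(captured_model, quantizer)  # insert observers
    prepared(*example_inputs)  # one-batch calibration, for illustration only
    return convert_pt2e(prepared)  # fold observations into quantize/dequantize ops
```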