Commit 3152d7f

Arm backend: Add and use ArmQuantizer (#2561)
Summary:
- Add and use ArmQuantizer in Arm backend.

Pull Request resolved: #2561
Reviewed By: mergennachin
Differential Revision: D55200846
Pulled By: digantdesai
fbshipit-source-id: 4085a14e498311ec4ce245bcb062ab79122144d5
1 parent d06ccd2 commit 3152d7f
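For downstream code the change is essentially an import swap: the Arm flow no longer borrows XNNPACKQuantizer from the XNNPACK backend and instead uses its own quantizer with the same configuration helper. A minimal sketch of the new call-site pattern, built only from names this commit introduces or touches:

    # Previously (removed by this commit):
    # from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    #     XNNPACKQuantizer,
    #     get_symmetric_quantization_config,
    # )

    # Now:
    from executorch.backends.arm.arm_quantizer import (
        ArmQuantizer,
        get_symmetric_quantization_config,
    )

    quantizer = ArmQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))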

6 files changed: 46 additions & 37 deletions


backends/arm/README.md

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ kend Architecture](#arm-backend-architecture). For examples of use see `executor
 - `tosa_mapping.py` - utilities for mapping edge dialect to TOSA
 - `tosa_quant_utils.py` - utilities for mapping quantization information to TOSA encoding
 
+Quantization:
+- `arm_quantizer.py` - Quantizer for Arm backend
+- `arm_quantizer_utils.py` - Utilities for quantization
+
 Runtime:
 - `runtime/ArmBackendEthosU.cpp` - The Arm backend implementation of the ExecuTorch runtime backend (PyTorchBackendInterface) for Ethos-U

backends/arm/arm_quantizer.py

Lines changed: 18 additions & 13 deletions
@@ -1,9 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+#
+# Quantizer for Arm backend
+#
+
 from __future__ import annotations
 
 import copy
@@ -14,6 +19,15 @@
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn.functional as F
+
+from executorch.backends.arm.arm_quantizer_utils import (
+    _convert_scalars_to_attrs,
+    OP_TO_ANNOTATOR,
+    OperatorConfig,
+    OperatorPatternType,
+    propagate_annotation,
+    QuantizationConfig,
+)
 from torch.ao.quantization.fake_quantize import (
     FakeQuantize,
     FusedMovingAvgObsFakeQuantize,
@@ -31,20 +45,11 @@
 
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 
-from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
-    _convert_scalars_to_attrs,
-    OP_TO_ANNOTATOR,
-    OperatorConfig,
-    OperatorPatternType,
-    propagate_annotation,
-    QuantizationConfig,
-)
-
 from torch.fx import Node
 
 
 __all__ = [
-    "XNNPACKQuantizer",
+    "ArmQuantizer",
     "get_symmetric_quantization_config",
 ]
 
@@ -260,7 +265,7 @@ def not_module_type_or_name_filter(n: Node) -> bool:
     return not_module_type_or_name_filter
 
 
-class XNNPACKQuantizer(Quantizer):
+class ArmQuantizer(Quantizer):
     supported_config_and_operators = _get_supported_config_and_operators()
     STATIC_QAT_ONLY_OPS = [
         "conv_bn_relu",
@@ -325,15 +330,15 @@ def get_supported_operator_for_quantization_config(
                 return ops
         return []
 
-    def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
+    def set_global(self, quantization_config: QuantizationConfig) -> ArmQuantizer:
         self.global_config = quantization_config
         return self
 
     def set_operator_type(
         self,
         operator_type: torch._ops.OpOverloadPacket,
         quantization_config: QuantizationConfig,
-    ) -> XNNPACKQuantizer:
+    ) -> ArmQuantizer:
         self.operator_type_config[operator_type] = quantization_config
         return self
 
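The renamed class keeps the configuration surface inherited from XNNPACKQuantizer: set_global applies one QuantizationConfig to every supported operator, and set_operator_type overrides it for a specific aten operator; both return self, so calls can be chained. A hedged sketch of that API follows; the torch.ops.aten.linear override and the per-channel config are illustrative choices, not taken from this diff (the files in this commit use is_per_channel=False throughout):

    import torch

    from executorch.backends.arm.arm_quantizer import (
        ArmQuantizer,
        get_symmetric_quantization_config,
    )

    # Default: symmetric, per-tensor quantization for all supported operators.
    quantizer = ArmQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False)
    )

    # Illustrative per-operator override; set_operator_type expects an
    # OpOverloadPacket such as torch.ops.aten.linear.
    quantizer.set_operator_type(
        torch.ops.aten.linear,
        get_symmetric_quantization_config(is_per_channel=True),
    )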

backends/arm/arm_quantizer_utils.py

Lines changed: 7 additions & 4 deletions
@@ -1,9 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+#
+# Utility functions for ArmQuantizer
+#
+
 import itertools
 import operator
 from dataclasses import dataclass
@@ -62,9 +67,7 @@ class QuantizationConfig:
 
 
 OperatorPatternType = List[Callable]
-OperatorPatternType.__module__ = (
-    "torch.ao.quantization.quantizer.xnnpack_quantizer_utils"
-)
+OperatorPatternType.__module__ = "executorch.backends.arm.arm_quantizer_utils"
 
 AnnotatorType = Callable[
     [
@@ -604,7 +607,7 @@ def _annotate_max_pool2d(
             maxpool_node = n
         assert (
            maxpool_node is not None
-        ), "XNNPACKQuantizer only works with torch.ops.aten.max_pool2d.default, "
+        ), "ArmQuantizer only works with torch.ops.aten.max_pool2d.default, "
        "please make sure you are exporting the model correctly"
        if _is_annotated([output_node, maxpool_node]):  # type: ignore[list-item]
            continue

backends/arm/test/test_tosa.py

Lines changed: 6 additions & 7 deletions
@@ -25,15 +25,14 @@
     _check_ir_validity=False,
 )
 
-from executorch.exir import EdgeCompileConfig
-from executorch.exir.program import to_edge
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
-
 ## For quantization
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+from executorch.backends.arm.arm_quantizer import (
+    ArmQuantizer,
     get_symmetric_quantization_config,
-    XNNPACKQuantizer,
 )
+from executorch.exir import EdgeCompileConfig
+from executorch.exir.program import to_edge
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 
 class TestBasicNN(unittest.TestCase):
@@ -88,7 +87,7 @@ def prepare_model_and_ref(test_model, profile=TosaProfile.MI):
            model, copy.deepcopy(model.inputs[profile])
        )
        # Setup the quantizer
-        quantizer = XNNPACKQuantizer()
+        quantizer = ArmQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)

backends/arm/test/tester/arm_tester.py

Lines changed: 5 additions & 6 deletions
@@ -13,6 +13,10 @@
 )
 
 from executorch.backends.arm.arm_partitioner import ArmPartitioner
+from executorch.backends.arm.arm_quantizer import (
+    ArmQuantizer,
+    get_symmetric_quantization_config,
+)
 
 from executorch.backends.arm.test.tosautil.tosa_test_utils import (
     QuantizationParams,
@@ -29,10 +33,6 @@
 )
 
 from executorch.exir import EdgeCompileConfig
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    get_symmetric_quantization_config,
-    XNNPACKQuantizer,
-)
 from torch.export import ExportedProgram
 
 
@@ -115,9 +115,8 @@ def __init__(
 
     def quantize(self, quantize_stage: Optional[Quantize] = None):
         if quantize_stage is None:
-            # Using the XNNPACKQuantizer for now
             quantize_stage = Quantize(
-                XNNPACKQuantizer(),
+                ArmQuantizer(),
                 get_symmetric_quantization_config(is_per_channel=False),
             )
         return super().quantize(quantize_stage)

examples/arm/aot_arm_compiler.py

Lines changed: 6 additions & 7 deletions
@@ -23,21 +23,20 @@
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.WARNING, format=FORMAT)
 
-# Quantize model if required using the standard export quantizaion flow.
-# For now we're using the xnnpack quantizer as this produces reasonable
-# output for our arithmetic behaviour.
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+from executorch.backends.arm.arm_quantizer import (
+    ArmQuantizer,
     get_symmetric_quantization_config,
-    XNNPACKQuantizer,
 )
 
+# Quantize model if required using the standard export quantizaion flow.
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+
 
 def quantize(model, example_inputs):
     """This is the official recommended flow for quantization in pytorch 2.0 export"""
     logging.info("Quantizing Model...")
     logging.debug(f"Original model: {model}")
-    quantizer = XNNPACKQuantizer()
+    quantizer = ArmQuantizer()
     # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
     operator_config = get_symmetric_quantization_config(is_per_channel=False)
     quantizer.set_global(operator_config)
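Only the quantizer setup inside quantize() is visible in this hunk. For context, the rest of the standard PT2E flow that the imports above (prepare_pt2e, convert_pt2e) point to looks roughly like the sketch below; the already-captured graph_module input and the single-batch calibration step are assumptions about the surrounding code, not part of this diff:

    from executorch.backends.arm.arm_quantizer import (
        ArmQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


    def quantize_sketch(graph_module, example_inputs):
        # graph_module: the model already captured to an FX graph (assumed input).
        quantizer = ArmQuantizer()
        quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

        # Insert observers, run representative data through them for calibration,
        # then convert the observed graph into a quantized one.
        prepared = prepare_pt2e(graph_module, quantizer)
        prepared(*example_inputs)  # calibration pass (assumed: one example batch)
        return convert_pt2e(prepared)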
