
Commit 86ba7e7

Migrate pt2e arm (#11053)

Migrate the Arm backend to use pt2e from torchao: the quantizer modules and their tests now import observers, fake-quantize classes, and quantizer types from torchao.quantization.pt2e instead of torch.ao.quantization, and the formerly underscore-prefixed annotation helpers are used under their public names.

1 parent a4985a8
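The same pattern repeats across all nine files: imports move from torch.ao.quantization.* to torchao.quantization.pt2e.*, and private helpers lose their leading underscore. A minimal before/after sketch (module paths taken from the diffs below; the old imports are commented out so the snippet stands alone):

# Before: pt2e pieces lived under torch.ao.quantization
# from torch.ao.quantization.observer import HistogramObserver
# from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
# from torch.ao.quantization.quantizer.utils import _annotate_output_qspec

# After: the same pieces come from torchao, with public helper names
from torchao.quantization.pt2e import HistogramObserver
from torchao.quantization.pt2e.quantizer import (
    annotate_output_qspec,
    QuantizationSpec,
    Quantizer,
)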

File tree

9 files changed: +36 −31 lines

.lintrunner.toml
Lines changed: 0 additions & 2 deletions

@@ -386,8 +386,6 @@ exclude_patterns = [
     "third-party/**",
     # TODO: remove exceptions as we migrate
     # backends
-    "backends/arm/quantizer/**",
-    "backends/arm/test/ops/**",
     "backends/vulkan/quantizer/**",
     "backends/vulkan/test/**",
     "backends/qualcomm/quantizer/**",

backends/arm/quantizer/TARGETS
Lines changed: 4 additions & 0 deletions

@@ -6,6 +6,7 @@ python_library(
     srcs = ["quantization_config.py"],
     deps = [
         "//caffe2:torch",
+        "//pytorch/ao:torchao",
     ],
 )

@@ -18,6 +19,7 @@ python_library(
         ":quantization_annotator",
         "//caffe2:torch",
         "//executorch/exir:lib",
+        "//pytorch/ao:torchao",
     ],
 )

@@ -28,6 +30,7 @@ python_library(
         ":arm_quantizer_utils",
         ":quantization_config",
         "//caffe2:torch",
+        "//pytorch/ao:torchao",
     ],
 )

@@ -36,6 +39,7 @@ python_library(
     srcs = ["arm_quantizer_utils.py"],
     deps = [
         ":quantization_config",
+        "//pytorch/ao:torchao",
     ],
 )

backends/arm/quantizer/arm_quantizer.py
Lines changed: 13 additions & 12 deletions

@@ -30,25 +30,26 @@
     is_vgf,
 )  # usort: skip
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-from torch.ao.quantization.fake_quantize import (
+
+from torch.fx import GraphModule, Node
+from torchao.quantization.pt2e import (
     FakeQuantize,
     FusedMovingAvgObsFakeQuantize,
-)
-from torch.ao.quantization.observer import (
     HistogramObserver,
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     MovingAveragePerChannelMinMaxObserver,
+    ObserverOrFakeQuantizeConstructor,
     PerChannelMinMaxObserver,
     PlaceholderObserver,
 )
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
-from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
+
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    QuantizationSpec,
+    Quantizer,
 )
-from torch.fx import GraphModule, Node

 __all__ = [
     "TOSAQuantizer",

@@ -97,7 +98,7 @@ def get_symmetric_quantization_config(
     weight_qscheme = (
         torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
     )
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
         MinMaxObserver
     )
     if is_qat:

@@ -337,14 +338,14 @@ def _annotate_io(
         if is_annotated(node):
             continue
         if node.op == "placeholder" and len(node.users) > 0:
-            _annotate_output_qspec(
+            annotate_output_qspec(
                 node,
                 quantization_config.get_output_act_qspec(),
             )
             mark_node_as_annotated(node)
         if node.op == "output":
             parent = node.all_input_nodes[0]
-            _annotate_input_qspec_map(
+            annotate_input_qspec_map(
                 node, parent, quantization_config.get_input_act_qspec()
            )
             mark_node_as_annotated(node)
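For orientation, the quantizer this file defines plugs into the standard pt2e flow. A rough sketch under stated assumptions: the model and TOSA spec string are illustrative, and prepare_pt2e/convert_pt2e are assumed to be importable from torchao.quantization.pt2e.quantize_pt2e after this migration:

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Hypothetical eager-mode model and example input.
model = torch.nn.Linear(16, 16).eval()
example_inputs = (torch.randn(1, 16),)

# Build the quantizer; the TOSA spec string here is illustrative.
tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")
quantizer = TOSAQuantizer(tosa_spec)
quantizer.set_global(get_symmetric_quantization_config())

# Standard pt2e steps: export, insert observers, calibrate, convert.
exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)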

backends/arm/quantizer/arm_quantizer_utils.py
Lines changed: 2 additions & 2 deletions

@@ -15,10 +15,10 @@

 import torch
 from torch._subclasses import FakeTensor
-
-from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import GraphModule, Node

+from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
+

 def is_annotated(node: Node) -> bool:
     """Given a node return whether the node is annotated."""

backends/arm/quantizer/quantization_annotator.py
Lines changed: 9 additions & 7 deletions

@@ -13,12 +13,14 @@
 import torch.nn.functional as F
 from executorch.backends.arm.quantizer import QuantizationConfig
 from executorch.backends.arm.tosa_utils import get_node_debug_info
-from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
+
 from torch.fx import Node
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    QuantizationSpecBase,
+    SharedQuantizationSpec,
+)

 from .arm_quantizer_utils import (
     is_annotated,

@@ -119,7 +121,7 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
         strict=True,
     ):
         assert isinstance(n_arg, Node)
-        _annotate_input_qspec_map(node, n_arg, qspec)
+        annotate_input_qspec_map(node, n_arg, qspec)
         if quant_property.mark_annotated:
             mark_node_as_annotated(n_arg)  # type: ignore[attr-defined]

@@ -130,7 +132,7 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
     assert not quant_property.optional
     assert quant_property.index == 0, "Only one output annotation supported currently"

-    _annotate_output_qspec(node, quant_property.qspec)
+    annotate_output_qspec(node, quant_property.qspec)


 def _match_pattern(
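The renamed helpers are thin wrappers over FX node metadata. A sketch of the behavior, inferred from their torch.ao predecessors (an assumption; the torchao copies may differ in detail):

from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
    QuantizationAnnotation,
    QuantizationSpecBase,
)

def annotate_output_qspec_sketch(node: Node, qspec: QuantizationSpecBase) -> None:
    # Reuse an existing annotation on the node, or start a fresh one.
    annotation = node.meta.get("quantization_annotation", QuantizationAnnotation())
    # Record the spec to quantize this node's output with.
    annotation.output_qspec = qspec
    node.meta["quantization_annotation"] = annotation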

backends/arm/quantizer/quantization_config.py
Lines changed: 2 additions & 2 deletions

@@ -9,9 +9,9 @@
 from dataclasses import dataclass

 import torch
-from torch.ao.quantization import ObserverOrFakeQuantize
+from torchao.quantization.pt2e import ObserverOrFakeQuantize

-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
     QuantizationSpec,

backends/arm/test/ops/test_add.py
Lines changed: 2 additions & 2 deletions

@@ -19,8 +19,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec

 aten_op = "torch.ops.aten.add.Tensor"
 exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
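These tests pair HistogramObserver with a hand-built QuantizationSpec. A minimal sketch of that pattern under the new imports (the field values below are illustrative, not taken from the tests):

import torch
from torchao.quantization.pt2e import HistogramObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

# Symmetric int8 activation spec driven by a histogram observer.
act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)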

backends/arm/test/ops/test_sigmoid_16bit.py
Lines changed: 2 additions & 2 deletions

@@ -18,8 +18,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec


 def _get_16_bit_quant_config():

backends/arm/test/ops/test_sigmoid_32bit.py
Lines changed: 2 additions & 2 deletions

@@ -14,8 +14,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec


 def _get_16_bit_quant_config():
