Commit 98a58e0

benkli01 authored and facebook-github-bot committed
Add generic annotator for data layout ops (#5814)
Summary: Data layout ops like unsqueeze are not annotated by the quantizer by default, which leads to issues further down the line. Therefore we add a generic annotator to explicitly annotate those ops.

Pull Request resolved: #5814

Reviewed By: mergennachin

Differential Revision: D63812934

Pulled By: digantdesai

fbshipit-source-id: 72e85a66a92c5d655168b7575035edc2b7d66255
1 parent 00d804c commit 98a58e0

4 files changed, +167 -0 lines changed
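For context (not part of this commit), below is a minimal sketch of how the Arm quantizer, and therefore the new "generic" annotator, is typically driven through the PT2E flow. The import locations, get_symmetric_quantization_config, set_global, and the export step are assumptions based on the ExecuTorch Arm backend and current PyTorch APIs; details vary between versions.

# Hypothetical usage sketch; names below are assumptions, not taken from the diff.
import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyModel(torch.nn.Module):
    def forward(self, x):
        # A data layout op that previously stayed un-annotated by the quantizer.
        return torch.unsqueeze(x, 0)


model = TinyModel().eval()
example_inputs = (torch.rand(8, 8),)

quantizer = ArmQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

# Export to an FX graph (the capture API differs between PyTorch releases),
# insert observers (annotators, including "generic", run here), calibrate, convert.
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)
quantized = convert_pt2e(prepared)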

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 0 deletions
@@ -271,6 +271,7 @@ class ArmQuantizer(Quantizer):
         "mm",
         "cat",
         "one_to_one",
+        "generic",
     ]
 
     def __init__(self) -> None:

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def decorator(annotator: AnnotatorType):
5353
add_annotator,
5454
cat_annotator,
5555
conv_annotator,
56+
generic_annotator,
5657
linear_annotator,
5758
max_pool2d_annotator,
5859
mm_annotator,
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Callable, List, Optional

import torch
import torch.fx
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import SharedQuantizationSpec
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import Node


_SUPPORTED_OPS = [
    # DATA LAYOUT OPS
    torch.ops.aten.squeeze.default,
    torch.ops.aten.squeeze_copy.default,
    torch.ops.aten.unsqueeze.default,
    torch.ops.aten.unsqueeze_copy.default,
    torch.ops.aten.reshape.default,
    # Disabling these as there seems to be an issue with support for complex
    # datatypes in torch:
    # torch.ops.aten.view_as_complex.default,
    # torch.ops.aten.view_as_complex_copy.default,
    # torch.ops.aten.view_as_real.default,
    # torch.ops.aten.view_as_real_copy.default,
    torch.ops.aten.view_copy.default,
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.slice_copy.Tensor,
    # 'concat' should be handled separately as it has a sequence of inputs and
    # makes the implementation unnecessarily complicated.
    # torch.ops.aten.concat.default,
    torch.ops.aten.transpose.Dimname,
    torch.ops.aten.transpose.int,
    torch.ops.aten.transpose_copy.int,
    torch.ops.aten.tile.default,
    torch.ops.aten.flip.default,
]


@register_annotator("generic")
def _annotate_generic(
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Propagate qspecs to generic ops like unsqueeze, reshape etc."""
    annotated_partitions = []

    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in _SUPPORTED_OPS:
            continue
        if filter_fn and not filter_fn(node):
            continue
        if arm_quantizer_utils.is_annotated(node):
            continue

        input_node = node.args[0]

        # Using a non-shared quantization spec here as a SharedQuantizationSpec
        # can lead to a recursion.
        _annotate_input_qspec_map(
            node, input_node, quantization_config.get_input_act_qspec()
        )
        _annotate_output_qspec(node, SharedQuantizationSpec((input_node, node)))

        arm_quantizer_utils.mark_nodes_as_annotated([node])
        annotated_partitions.append([node])

    return annotated_partitions
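For readers less familiar with the torch.ao annotation helpers used above, here is a rough illustration of what _annotate_input_qspec_map and _annotate_output_qspec amount to for one data layout node. The helper annotate_one and the direct node.meta assignment are illustrative assumptions, not part of the patch.

# Illustrative sketch only: roughly the annotation the two helper calls attach
# to a node such as aten.unsqueeze.
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import Node


def annotate_one(node: Node, input_node: Node, input_act_qspec) -> None:
    # The input edge gets a regular (non-shared) activation qspec, while the
    # output shares whatever quantization parameters that input edge receives,
    # so the data layout op does not become a new requantization point.
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={input_node: input_act_qspec},
        output_qspec=SharedQuantizationSpec((input_node, node)),
        _annotated=True,
    )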
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer_utils import is_annotated
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class SingleOpModel(torch.nn.Module):
    def __init__(self, op, example_input, **op_kwargs) -> None:
        super().__init__()
        self.op = op
        self._example_input = example_input
        self.op_kwargs = op_kwargs

    def forward(self, x):
        return self.op(x, **self.op_kwargs)

    def example_inputs(self):
        return self._example_input


class TestGenericAnnotator(unittest.TestCase):
    def check_annotation(self, model):
        tester = ArmTester(
            model, model.example_inputs(), common.get_tosa_compile_spec()
        )
        quant_model = tester.quantize().get_artifact()
        partitions = get_source_partitions(quant_model.graph, [model.op])
        partitions = list(itertools.chain.from_iterable(partitions.values()))

        assert len(partitions) == 1
        partition = partitions[0]
        assert all(is_annotated(node) for node in partition.nodes)

    def test_squeeze(self):
        self.check_annotation(SingleOpModel(torch.squeeze, (torch.rand(8, 8, 1),)))
        self.check_annotation(SingleOpModel(torch.squeeze_copy, (torch.rand(8, 8, 1),)))

    def test_unsqueeze(self):
        self.check_annotation(
            SingleOpModel(torch.unsqueeze, (torch.rand(8, 8),), dim=0)
        )
        self.check_annotation(
            SingleOpModel(torch.unsqueeze_copy, (torch.rand(8, 8),), dim=0)
        )

    def test_reshape(self):
        self.check_annotation(
            SingleOpModel(torch.reshape, (torch.randn(8, 8),), shape=(64,)),
        )

    def test_view(self):
        self.check_annotation(
            SingleOpModel(torch.view_copy, (torch.randn(4, 4),), size=(2, 8)),
        )

    def test_slice(self):
        self.check_annotation(
            SingleOpModel(torch.slice_copy, (torch.randn(3, 4),)),
        )

    def test_transpose(self):
        self.check_annotation(
            SingleOpModel(torch.transpose, (torch.randn(2, 3),), dim0=0, dim1=1),
        )
        self.check_annotation(
            SingleOpModel(torch.transpose_copy, (torch.randn(2, 3),), dim0=0, dim1=1),
        )

    def test_tile(self):
        self.check_annotation(
            SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),
        )

    def test_flip(self):
        self.check_annotation(
            SingleOpModel(torch.flip, (torch.randn(2, 4),), dims=(0, 1)),
        )
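The capture does not show the new test file's path. Assuming it lives under the Arm backend's test tree, one way to run just these tests locally might look like the sketch below; the start directory and file pattern are guesses, not taken from the commit.

# Hypothetical test runner; adjust start_dir/pattern to wherever the file lands.
import unittest

suite = unittest.defaultTestLoader.discover(
    start_dir="backends/arm/test", pattern="*generic_annotator*.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)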
