
Commit 265cdb7
Add generic annotator for data layout ops
Data layout ops like unsqueeze are not annotated by the quantizer by default, which leads to issues further down the pipeline. Therefore we add a generic annotator to explicitly annotate those ops.

Signed-off-by: Benjamin Klimczak <[email protected]>
Change-Id: Id3919abcb3df0b81159f3cccaab9785f8706b9cd
1 parent 393553c
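For context, a minimal sketch (not part of the commit) of the flow in which the new annotator runs: quantizing a model whose graph contains a data layout op such as unsqueeze. get_symmetric_quantization_config and set_global are assumed from the existing arm_quantizer module, and the graph-capture call follows the upstream PT2E flow, which varies by torch version.

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,  # assumed existing helper
)
from torch.ao.quantization.quantize_pt2e import prepare_pt2e


class UnsqueezeModel(torch.nn.Module):
    def forward(self, x):
        # A data layout op that, before this commit, was left unannotated.
        return torch.unsqueeze(x, dim=0)


example_inputs = (torch.rand(8, 8),)
# Capture a graph; depending on the torch version this may instead be
# torch._export.capture_pre_autograd_graph(model, example_inputs).
exported = torch.export.export(UnsqueezeModel(), example_inputs).module()
quantizer = ArmQuantizer().set_global(get_symmetric_quantization_config())
# prepare_pt2e invokes the quantizer's annotators, now including "generic".
prepared = prepare_pt2e(exported, quantizer)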

4 files changed, +167 −0 lines changed

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 0 deletions
@@ -271,6 +271,7 @@ class ArmQuantizer(Quantizer):
         "mm",
         "cat",
         "one_to_one",
+        "generic",
     ]

     def __init__(self) -> None:

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ def decorator(annotator: AnnotatorType):
    add_annotator,
    cat_annotator,
    conv_annotator,
+    generic_annotator,
    linear_annotator,
    max_pool2d_annotator,
    mm_annotator,
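A hedged sketch of the registration machinery implied by the two diffs above: register_annotator (its inner decorator is visible in the hunk header) stores each annotator under a string key, and ArmQuantizer's name list ("mm", "cat", ..., "generic") indexes this registry. The dict name OP_TO_ANNOTATOR is an assumption; only register_annotator and AnnotatorType appear in the diff.

from typing import Callable, Dict, List, Optional

import torch.fx
from torch.fx import Node

# Signature mirrors _annotate_generic in the new file below.
AnnotatorType = Callable[
    [torch.fx.GraphModule, "QuantizationConfig", Optional[Callable[[Node], bool]]],
    Optional[List[List[Node]]],
]
OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {}


def register_annotator(op: str) -> Callable[[AnnotatorType], AnnotatorType]:
    def decorator(annotator: AnnotatorType) -> AnnotatorType:
        # Register the annotator under its string name for lookup by the quantizer.
        OP_TO_ANNOTATOR[op] = annotator
        return annotator

    return decorator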
backends/arm/quantizer/quantization_annotation/generic_annotator.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Callable, List, Optional

import torch
import torch.fx
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import SharedQuantizationSpec
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import Node


_SUPPORTED_OPS = [
    # DATA LAYOUT OPS
    torch.ops.aten.squeeze.default,
    torch.ops.aten.squeeze_copy.default,
    torch.ops.aten.unsqueeze.default,
    torch.ops.aten.unsqueeze_copy.default,
    torch.ops.aten.reshape.default,
    # Disabling these as there seems to be an issue with support for complex
    # datatypes in torch:
    # torch.ops.aten.view_as_complex.default,
    # torch.ops.aten.view_as_complex_copy.default,
    # torch.ops.aten.view_as_real.default,
    # torch.ops.aten.view_as_real_copy.default,
    torch.ops.aten.view_copy.default,
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.slice_copy.Tensor,
    # 'concat' should be handled separately as it has a sequence of inputs,
    # which makes the implementation unnecessarily complicated.
    # torch.ops.aten.concat.default,
    torch.ops.aten.transpose.Dimname,
    torch.ops.aten.transpose.int,
    torch.ops.aten.transpose_copy.int,
    torch.ops.aten.tile.default,
    torch.ops.aten.flip.default,
]


@register_annotator("generic")
def _annotate_generic(
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Propagate qspecs to generic ops like unsqueeze, reshape etc."""
    annotated_partitions = []

    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in _SUPPORTED_OPS:
            continue
        if filter_fn and not filter_fn(node):
            continue
        if arm_quantizer_utils.is_annotated(node):
            continue

        input_node = node.args[0]

        # Using a non-shared quantization spec here as a SharedQuantizationSpec
        # can lead to recursion.
        _annotate_input_qspec_map(
            node, input_node, quantization_config.get_input_act_qspec()
        )
        _annotate_output_qspec(node, SharedQuantizationSpec((input_node, node)))

        arm_quantizer_utils.mark_nodes_as_annotated([node])
        annotated_partitions.append([node])

    return annotated_partitions
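The _annotate_* helpers above write a QuantizationAnnotation into node.meta under the quantization_annotation key, so the effect of the pass can be inspected directly. A small sketch, assuming a prepared module from a prepare_pt2e-style flow as in the earlier example:

# Inspect what the generic annotator attached (sketch; `prepared` assumed).
for node in prepared.graph.nodes:
    annotation = node.meta.get("quantization_annotation")
    if annotation is None:
        continue
    # input_qspec_map maps each input node to its own activation qspec;
    # output_qspec is a SharedQuantizationSpec tied to the (input, node)
    # edge, so the op's input and output share quantization parameters.
    print(node.target, annotation.input_qspec_map, annotation.output_qspec)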
Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer_utils import is_annotated
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class SingleOpModel(torch.nn.Module):
    def __init__(self, op, example_input, **op_kwargs) -> None:
        super().__init__()
        self.op = op
        self._example_input = example_input
        self.op_kwargs = op_kwargs

    def forward(self, x):
        return self.op(x, **self.op_kwargs)

    def example_inputs(self):
        return self._example_input


class TestGenericAnnotator(unittest.TestCase):
    def check_annotation(self, model):
        tester = ArmTester(
            model, model.example_inputs(), common.get_tosa_compile_spec()
        )
        quant_model = tester.quantize().get_artifact()
        partitions = get_source_partitions(quant_model.graph, [model.op])
        partitions = list(itertools.chain.from_iterable(partitions.values()))

        assert len(partitions) == 1
        partition = partitions[0]
        assert all(is_annotated(node) for node in partition.nodes)

    def test_squeeze(self):
        self.check_annotation(SingleOpModel(torch.squeeze, (torch.rand(8, 8, 1),)))
        self.check_annotation(SingleOpModel(torch.squeeze_copy, (torch.rand(8, 8, 1),)))

    def test_unsqueeze(self):
        self.check_annotation(
            SingleOpModel(torch.unsqueeze, (torch.rand(8, 8),), dim=0)
        )
        self.check_annotation(
            SingleOpModel(torch.unsqueeze_copy, (torch.rand(8, 8),), dim=0)
        )

    def test_reshape(self):
        self.check_annotation(
            SingleOpModel(torch.reshape, (torch.randn(8, 8),), shape=(64,)),
        )

    def test_view(self):
        self.check_annotation(
            SingleOpModel(torch.view_copy, (torch.randn(4, 4),), size=(2, 8)),
        )

    def test_slice(self):
        self.check_annotation(
            SingleOpModel(torch.slice_copy, (torch.randn(3, 4),)),
        )

    def test_transpose(self):
        self.check_annotation(
            SingleOpModel(torch.transpose, (torch.randn(2, 3),), dim0=0, dim1=1),
        )
        self.check_annotation(
            SingleOpModel(torch.transpose_copy, (torch.randn(2, 3),), dim0=0, dim1=1),
        )

    def test_tile(self):
        self.check_annotation(
            SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),
        )

    def test_flip(self):
        self.check_annotation(
            SingleOpModel(torch.flip, (torch.randn(2, 4),), dims=(0, 1)),
        )
