Add generic annotator for data layout ops #5814

Closed
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer.py
@@ -271,6 +271,7 @@ class ArmQuantizer(Quantizer):
"mm",
"cat",
"one_to_one",
"generic",
Contributor:
nit: too vague? Maybe it's just me.

Collaborator (Author):
No, it is a bit vague (generic even 😅), but naming things is hard...

]

def __init__(self) -> None:
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotation/__init__.py
@@ -53,6 +53,7 @@ def decorator(annotator: AnnotatorType):
add_annotator,
cat_annotator,
conv_annotator,
generic_annotator,
linear_annotator,
max_pool2d_annotator,
mm_annotator,
79 changes: 79 additions & 0 deletions backends/arm/quantizer/quantization_annotation/generic_annotator.py
@@ -0,0 +1,79 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Callable, List, Optional

import torch
import torch.fx
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import SharedQuantizationSpec
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)
from torch.fx import Node


_SUPPORTED_OPS = [
# DATA LAYOUT OPS
torch.ops.aten.squeeze.default,
torch.ops.aten.squeeze_copy.default,
torch.ops.aten.unsqueeze.default,
torch.ops.aten.unsqueeze_copy.default,
torch.ops.aten.reshape.default,
# Disabling these as there seems to be an issue with support for complex
# datatypes in torch:
# torch.ops.aten.view_as_complex.default,
# torch.ops.aten.view_as_complex_copy.default,
# torch.ops.aten.view_as_real.default,
# torch.ops.aten.view_as_real_copy.default,
torch.ops.aten.view_copy.default,
torch.ops.aten.slice.Tensor,
torch.ops.aten.slice_copy.Tensor,
# 'concat' should be handled separately, as it takes a sequence of inputs
# and would make this implementation unnecessarily complicated.
# torch.ops.aten.concat.default,
torch.ops.aten.transpose.Dimname,
torch.ops.aten.transpose.int,
torch.ops.aten.transpose_copy.int,
torch.ops.aten.tile.default,
torch.ops.aten.flip.default,
]


@register_annotator("generic")
def _annotate_generic(
gm: torch.fx.GraphModule,
quantization_config: QuantizationConfig,
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
"""Propagate qspecs to generic ops like unsqueeze, reshape etc."""
annotated_partitions = []

for node in gm.graph.nodes:
if node.op != "call_function" or node.target not in _SUPPORTED_OPS:
continue
if filter_fn and not filter_fn(node):
continue
if arm_quantizer_utils.is_annotated(node):
continue

input_node = node.args[0]

# Using a non-shared quantization spec for the input here, as a
# SharedQuantizationSpec can lead to recursion.
_annotate_input_qspec_map(
node, input_node, quantization_config.get_input_act_qspec()
)
_annotate_output_qspec(node, SharedQuantizationSpec((input_node, node)))

arm_quantizer_utils.mark_nodes_as_annotated([node])
annotated_partitions.append([node])

return annotated_partitions
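
For context, here is a minimal sketch (not part of this diff) of how the "generic" annotator gets exercised through the standard PT2E flow, which is roughly what the ArmTester-based test below drives internally. The get_symmetric_quantization_config and set_global names are taken from the existing ArmQuantizer API, and the capture call is an assumption that depends on the torch version (older releases used capture_pre_autograd_graph instead of torch.export.export_for_training):

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class ReshapeModel(torch.nn.Module):
    def forward(self, x):
        # A data layout op now covered by the "generic" annotator.
        return torch.reshape(x, (64,))


example = (torch.randn(8, 8),)
# Capture for PT2E; the exact capture API depends on the torch version in use.
exported = torch.export.export_for_training(ReshapeModel(), example).module()

quantizer = ArmQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

prepared = prepare_pt2e(exported, quantizer)  # runs the registered annotators
prepared(*example)                            # calibrate
quantized = convert_pt2e(prepared)            # reshape input/output share qparams

Because the output qspec is a SharedQuantizationSpec tied to (input_node, node), the reshape's output reuses the quantization parameters observed for its input rather than introducing a second observer.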
86 changes: 86 additions & 0 deletions backends/arm/test/quantizer/test_generic_annotater.py
@@ -0,0 +1,86 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer_utils import is_annotated
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class SingleOpModel(torch.nn.Module):
def __init__(self, op, example_input, **op_kwargs) -> None:
super().__init__()
self.op = op
self._example_input = example_input
self.op_kwargs = op_kwargs

def forward(self, x):
return self.op(x, **self.op_kwargs)

def example_inputs(self):
return self._example_input


class TestGenericAnnotator(unittest.TestCase):
def check_annotation(self, model):
Contributor:
nice!
tester = ArmTester(
model, model.example_inputs(), common.get_tosa_compile_spec()
)
quant_model = tester.quantize().get_artifact()
partitions = get_source_partitions(quant_model.graph, [model.op])
partitions = list(itertools.chain.from_iterable(partitions.values()))

assert len(partitions) == 1
partition = partitions[0]
assert all(is_annotated(node) for node in partition.nodes)

def test_squeeze(self):
self.check_annotation(SingleOpModel(torch.squeeze, (torch.rand(8, 8, 1),)))
self.check_annotation(SingleOpModel(torch.squeeze_copy, (torch.rand(8, 8, 1),)))

def test_unsqueeze(self):
self.check_annotation(
SingleOpModel(torch.unsqueeze, (torch.rand(8, 8),), dim=0)
)
self.check_annotation(
SingleOpModel(torch.unsqueeze_copy, (torch.rand(8, 8),), dim=0)
)

def test_reshape(self):
self.check_annotation(
SingleOpModel(torch.reshape, (torch.randn(8, 8),), shape=(64,)),
)

def test_view(self):
self.check_annotation(
SingleOpModel(torch.view_copy, (torch.randn(4, 4),), size=(2, 8)),
)

def test_slice(self):
self.check_annotation(
SingleOpModel(torch.slice_copy, (torch.randn(3, 4),)),
)

def test_transpose(self):
self.check_annotation(
SingleOpModel(torch.transpose, (torch.randn(2, 3),), dim0=0, dim1=1),
)
self.check_annotation(
SingleOpModel(torch.transpose_copy, (torch.randn(2, 3),), dim0=0, dim1=1),
)

def test_tile(self):
self.check_annotation(
SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),
)

def test_flip(self):
self.check_annotation(
SingleOpModel(torch.flip, (torch.randn(2, 4),), dims=(0, 1)),
)