Skip to content

Arm backend: Refactor Quantizer test to allow for TOSA 1.0 #10905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 16, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 79 additions & 71 deletions backends/arm/test/quantizer/test_generic_annotater.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import unittest

from typing import Tuple

import torch
from executorch.backends.arm.quantizer import is_annotated
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI

from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


input_t1 = Tuple[torch.Tensor] # Input x


class SingleOpModel(torch.nn.Module):
def __init__(self, op, example_input, **op_kwargs) -> None:
super().__init__()
Expand All @@ -27,69 +30,74 @@ def example_inputs(self):
return self._example_input


class TestGenericAnnotator(unittest.TestCase):
def check_annotation(self, model):
tester = ArmTester(
model,
model.example_inputs(),
common.get_tosa_compile_spec("TOSA-0.80+BI"),
)
quant_model = tester.quantize().get_artifact()
partitions = get_source_partitions(quant_model.graph, [model.op])
partitions = list(itertools.chain.from_iterable(partitions.values()))

assert len(partitions) == 1
partition = partitions[0]
assert all(is_annotated(node) for node in partition.nodes)

def test_squeeze(self):
self.check_annotation(SingleOpModel(torch.squeeze, (torch.rand(8, 8, 1),)))
self.check_annotation(SingleOpModel(torch.squeeze_copy, (torch.rand(8, 8, 1),)))

def test_unsqueeze(self):
self.check_annotation(
SingleOpModel(torch.unsqueeze, (torch.rand(8, 8),), dim=0)
)
self.check_annotation(
SingleOpModel(torch.unsqueeze_copy, (torch.rand(8, 8),), dim=0)
)

def test_reshape(self):
self.check_annotation(
SingleOpModel(torch.reshape, (torch.randn(8, 8),), shape=(64,)),
)

def test_view(self):
self.check_annotation(
SingleOpModel(torch.view_copy, (torch.randn(4, 4),), size=(2, 8)),
)

def test_slice(self):
self.check_annotation(
SingleOpModel(torch.slice_copy, (torch.randn(3, 4),)),
)

def test_transpose(self):
self.check_annotation(
SingleOpModel(torch.transpose, (torch.randn(2, 3),), dim0=0, dim1=1),
)
self.check_annotation(
SingleOpModel(torch.transpose_copy, (torch.randn(2, 3),), dim0=0, dim1=1),
)

def test_tile(self):
self.check_annotation(
SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),
)

def test_flip(self):
self.check_annotation(
SingleOpModel(torch.flip, (torch.randn(2, 4),), dims=(0, 1)),
)

def test_concat(self):
self.check_annotation(
SingleOpModel(
torch.concatenate, ((torch.randn(2, 3), torch.randn(2, 3)),), dim=0
),
)
def check_annotation(model):
pipeline = TosaPipelineBI[input_t1](model, model.example_inputs(), [], [])
pipeline.pop_stage("check_count.exir")
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()

artifact = pipeline.tester.get_artifact("Quantize")

partitions = get_source_partitions(artifact.graph, [model.op])
partitions = list(itertools.chain.from_iterable(partitions.values()))

assert len(partitions) == 1
partition = partitions[0]
assert all(is_annotated(node) for node in partition.nodes)


def test_squeeze():
check_annotation(SingleOpModel(torch.squeeze, (torch.rand(8, 8, 1),)))
check_annotation(SingleOpModel(torch.squeeze_copy, (torch.rand(8, 8, 1),)))


def test_unsqueeze():
check_annotation(SingleOpModel(torch.unsqueeze, (torch.rand(8, 8),), dim=0))
check_annotation(SingleOpModel(torch.unsqueeze_copy, (torch.rand(8, 8),), dim=0))


def test_reshape():
check_annotation(
SingleOpModel(torch.reshape, (torch.randn(8, 8),), shape=(64,)),
)


def test_view():
check_annotation(
SingleOpModel(torch.view_copy, (torch.randn(4, 4),), size=(2, 8)),
)


def test_slice():
check_annotation(
SingleOpModel(torch.slice_copy, (torch.randn(3, 4),)),
)


def test_transpose():
check_annotation(
SingleOpModel(torch.transpose, (torch.randn(2, 3),), dim0=0, dim1=1),
)
check_annotation(
SingleOpModel(torch.transpose_copy, (torch.randn(2, 3),), dim0=0, dim1=1),
)


def test_tile():
check_annotation(
SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),
)


def test_flip():
check_annotation(
SingleOpModel(torch.flip, (torch.randn(2, 4),), dims=(0, 1)),
)


def test_concat():
check_annotation(
SingleOpModel(
torch.concatenate, ((torch.randn(2, 3), torch.randn(2, 3)),), dim=0
),
)
Loading