Skip to content

Arm backend: Add missing ops to annotator #11517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,12 @@ def _match_pattern(
torch.ops.aten.squeeze_copy.dim,
torch.ops.aten.squeeze.dim,
torch.ops.aten.squeeze.dims,
torch.ops.aten.unbind.int,
torch.ops.aten.unsqueeze.default,
torch.ops.aten.unsqueeze_copy.default,
torch.ops.aten.reshape.default,
torch.ops.aten.repeat.default,
torch.ops.aten.repeat_interleave.self_int,
torch.ops.aten.expand_copy.default,
torch.ops.aten.expand.default,
# Disabling these as there seems to be an issue with support for complex
Expand Down Expand Up @@ -256,6 +258,7 @@ def _match_pattern(
torch.ops.aten.amin.default,
torch.ops.aten.clamp.default,
torch.ops.aten.clamp.Tensor,
torch.ops.aten.unflatten.int,
]

_one_to_one_shared_input_or_input_act_qspec = [
Expand All @@ -271,6 +274,7 @@ def _match_pattern(
torch.ops.aten.avg_pool2d.default,
torch.ops.aten.max_pool2d.default,
torch.ops.aten.full.default,
torch.ops.aten.full,
torch.ops.aten.flatten.using_ints,
torch.ops.aten.dropout.default,
torch.ops.aten.dropout_.default,
Expand Down Expand Up @@ -539,6 +543,7 @@ def annotate_graph( # type: ignore[return]
if node.target in [
torch.ops.aten.full_like.default,
torch.ops.aten.full.default,
torch.ops.aten.full,
torch.ops.aten.scalar_tensor.default,
]:
node.kwargs = {}
52 changes: 42 additions & 10 deletions backends/arm/runtime/EthosUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,24 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
event_tracer,
"+EthosUBackend::execute()handles.input.permute_CHW_to_HWC()");
// permuted byte copy CHW to HWC
int c, h, w;
if (tensor_in.dim() == 4) {
c = tensor_in.size(1);
h = tensor_in.size(2);
w = tensor_in.size(3);
} else if (tensor_in.dim() == 5) {
c = tensor_in.size(2);
h = tensor_in.size(3);
w = tensor_in.size(4);
} else {
ET_LOG(
Error,
"Unsupported input tensor dimension %d, expected 4 or 5",
tensor_in.dim());
return Error::InvalidProgram;
}
permute_CHW_to_HWC(
tensor_in.mutable_data_ptr<char>(),
scratch_addr,
tensor_in.size(1),
tensor_in.size(2),
tensor_in.size(3));
tensor_in.mutable_data_ptr<char>(), scratch_addr, c, h, w);
} else if (both_char or both_int or both_short) {
EXECUTORCH_PROF_SCOPE(
event_tracer, "+EthosUBackend::execute()handles.input.memcpy()");
Expand Down Expand Up @@ -364,12 +376,24 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
"+EthosUBackend::execute()handles.output.permute_HWC_to_CHW()");

char* output_address = (char*)output_addr;
int c, h, w;
if (tensor_out.dim() == 4) {
c = tensor_out.size(1);
h = tensor_out.size(2);
w = tensor_out.size(3);
} else if (tensor_out.dim() == 5) {
c = tensor_out.size(2);
h = tensor_out.size(3);
w = tensor_out.size(4);
} else {
ET_LOG(
Error,
"Unsupported output tensor dimension %d, expected 4 or 5",
tensor_out.dim());
return Error::InvalidProgram;
}
permute_HWC_to_CHW(
output_address,
tensor_out.mutable_data_ptr<char>(),
tensor_out.size(1),
tensor_out.size(2),
tensor_out.size(3));
output_address, tensor_out.mutable_data_ptr<char>(), c, h, w);
} else {
EXECUTORCH_PROF_SCOPE(
event_tracer, "+EthosUBackend::execute()handles.output.move()");
Expand Down Expand Up @@ -430,6 +454,14 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
if (permuted_shape) {
ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
}
} else if (tensor.dim() == 5) {
// Same as above, but for 5D tensors.
permuted_shape = tensor.size(0) == io->shape[0] &&
tensor.size(1) == io->shape[1] && tensor.size(2) == io->shape[4] &&
tensor.size(3) == io->shape[2] && tensor.size(4) == io->shape[3];
if (permuted_shape) {
ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
}
}
*is_permuted = permuted_shape;
return Error::Ok;
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/scripts/parse_test_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"bitwise_right_shift.Tensor",
"bitwise_left_shift.Tensor",
"native_group_norm.default",
"unbind.int",
"unflatten.int",
"_native_batch_norm_legit_no_training.default",
"_native_batch_norm_legit.no_stats",
]
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/models/test_deit_tiny_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_deit_tiny_tosa_BI():
aten_op=[],
exir_op=[],
use_to_edge_transform_and_lower=True,
atol=2.5, # This needs to go down: MLETORCH-956
atol=1,
qtol=1,
)
pipeline.run()
6 changes: 0 additions & 6 deletions backends/arm/test/models/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,4 @@ def test_llama_tosa_BI():
exir_op=[],
use_to_edge_transform_and_lower=True,
)
pipeline.change_args(
"run_method_and_compare_outputs",
atol=9.9,
rtol=1.5, # TODO: Tolerance needs to be updated after MLETORCH-907
inputs=llama_inputs,
)
pipeline.run()
87 changes: 55 additions & 32 deletions backends/arm/test/ops/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,68 +21,91 @@
)

input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x, Input y
aten_op = "torch.ops.aten.repeat.default"


"""Tests Tensor.repeat for different ranks and dimensions."""


class Repeat(torch.nn.Module):
# (input tensor, multiples)
test_parameters = {
"1_x_1": lambda: (torch.randn(3), (2,)),
"2_x_2": lambda: (torch.randn(3, 4), (2, 1)),
"4_x_4": lambda: (torch.randn(1, 1, 2, 2), (1, 2, 3, 4)),
"1_x_2": lambda: (torch.randn(3), (2, 2)),
"1_x_3": lambda: (torch.randn(3), (1, 2, 3)),
"2_x_3": lambda: (torch.randn((3, 3)), (2, 2, 2)),
"1_x_4": lambda: (torch.randn((3, 3, 3)), (2, 1, 2, 4)),
}

def forward(self, x: torch.Tensor, multiples: Sequence):
return x.repeat(multiples)


@common.parametrize("test_data", Repeat.test_parameters)
aten_op = "torch.ops.aten.repeat.default"

def __init__(self, multiples: Sequence[int]):
super().__init__()
self.multiples = multiples

def forward(self, x: torch.Tensor):
return x.repeat(self.multiples)


class RepeatInterleaveInt(torch.nn.Module):
aten_op = "torch.ops.aten.repeat_interleave.self_int"

def __init__(self, repeats: int, dim: int):
super().__init__()
self.repeats = repeats
self.dim = dim

def forward(self, x: torch.Tensor):
return x.repeat_interleave(self.repeats, self.dim)


test_data_suite = {
# test_name : lambda: (module, test_data)
"1_x_1": lambda: (Repeat((2,)), (torch.randn(3),)),
"2_x_2": lambda: (Repeat((2, 1)), (torch.randn(3, 4),)),
"4_x_4": lambda: (Repeat((1, 2, 3, 4)), (torch.randn(1, 1, 2, 2),)),
"1_x_2": lambda: (Repeat((2, 2)), (torch.randn(3),)),
"1_x_3": lambda: (Repeat((1, 2, 3)), (torch.randn(3),)),
"2_x_3": lambda: (Repeat((2, 2, 2)), (torch.randn((3, 3)),)),
"1_x_4": lambda: (Repeat((2, 1, 2, 4)), (torch.randn((3, 3, 3)),)),
"interleave_int_3_x_1": lambda: (RepeatInterleaveInt(3, 1), (torch.randn(3, 4),)),
}


@common.parametrize("test_data", test_data_suite)
def test_repeat_tosa_MI(test_data: Tuple):
module, test_data = test_data()
pipeline = TosaPipelineMI[input_t1](
Repeat(),
test_data(),
aten_op,
module,
test_data,
module.aten_op,
exir_op=[],
)
pipeline.run()


@common.parametrize("test_data", Repeat.test_parameters)
@common.parametrize("test_data", test_data_suite)
def test_repeat_tosa_BI(test_data: Tuple):
module, test_data = test_data()
pipeline = TosaPipelineBI[input_t1](
Repeat(),
test_data(),
aten_op,
module,
test_data,
module.aten_op,
exir_op=[],
)
pipeline.run()


@common.parametrize("test_data", Repeat.test_parameters)
@common.parametrize("test_data", test_data_suite)
def test_repeat_u55_BI(test_data: Tuple):
module, test_data = test_data()
pipeline = EthosU55PipelineBI[input_t1](
Repeat(),
test_data(),
aten_op,
module,
test_data,
module.aten_op,
exir_ops=[],
run_on_fvp=False,
)
pipeline.run()


@common.parametrize("test_data", Repeat.test_parameters)
@common.parametrize("test_data", test_data_suite)
def test_repeat_u85_BI(test_data: Tuple):
module, test_data = test_data()
pipeline = EthosU85PipelineBI[input_t1](
Repeat(),
test_data(),
aten_op,
module,
test_data,
module.aten_op,
exir_ops=[],
run_on_fvp=False,
)
Expand Down
55 changes: 55 additions & 0 deletions backends/arm/test/ops/test_unbind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
TosaPipelineBI,
TosaPipelineMI,
)

input_t = tuple[torch.Tensor]
test_data_t = tuple[int, torch.dtype]


class Unbind(torch.nn.Module):
aten_op: str = "torch.ops.aten.unbind.int"

def __init__(self, dim: int):
super().__init__()
self.dim = dim

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
return torch.unbind(x, self.dim)

test_data: dict[str, test_data_t] = {
"randn_4d": (lambda: (torch.randn(1, 5, 4, 3),), (2,)),
"randn_3d": (lambda: (torch.randn(5, 4, 3),), (0,)),
}


@common.parametrize("test_data", Unbind.test_data)
def test_unbind_int_tosa_MI(test_data: test_data_t):
input_data, init_data = test_data
pipeline = TosaPipelineMI[input_t](
Unbind(*init_data),
input_data(),
Unbind.aten_op,
)
pipeline.run()


@common.parametrize("test_data", Unbind.test_data)
def test_unbind_int_tosa_BI(test_data: test_data_t):
input_data, init_data = test_data
pipeline = TosaPipelineBI[input_t](
Unbind(*init_data),
input_data(),
Unbind.aten_op,
)
pipeline.run()
56 changes: 56 additions & 0 deletions backends/arm/test/ops/test_unflatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
TosaPipelineBI,
TosaPipelineMI,
)

input_t = tuple[torch.Tensor]
test_data_t = tuple[torch.nn.Module, input_t]


class Unflatten(torch.nn.Module):
aten_op: str = "torch.ops.aten.unflatten.int"

def __init__(self, dim: int, sizes: Tuple[int, ...]):
super().__init__()
self.dim = dim
self.sizes = sizes

def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.unflatten(x, self.dim, self.sizes)

test_data: dict[str, test_data_t] = {
"randn_4d": (lambda: (Unflatten(1, (2, 2)), (torch.randn(3, 4, 5, 1),))),
"rand_3d": (lambda: (Unflatten(1, (-1, 2)), (torch.rand(3, 4, 4),))),
}


@common.parametrize("test_data", Unflatten.test_data)
def test_unflatten_int_tosa_MI(test_data: test_data_t):
module, inputs = test_data()
pipeline = TosaPipelineMI[input_t](
module,
inputs,
Unflatten.aten_op,
)
pipeline.run()


@common.parametrize("test_data", Unflatten.test_data)
def test_unflatten_int_tosa_BI(test_data: test_data_t):
module, inputs = test_data()
pipeline = TosaPipelineBI[input_t](
module,
inputs,
Unflatten.aten_op,
)
pipeline.run()
Loading
Loading