Skip to content

Commit 5e555d5

Browse files
Arm backend: Add missing ops to annotator (#11517)
- Adds missing operators to annotator. These are needed to improve numerical precision for llama and deit_tiny. - Extends 5D support to runtime by permuting input/output to/from channels_last. - Adds unit tests for operators with 5D tensors. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 18e9149 commit 5e555d5

File tree

9 files changed

+224
-57
lines changed

9 files changed

+224
-57
lines changed

backends/arm/quantizer/quantization_annotator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,12 @@ def _match_pattern(
221221
torch.ops.aten.squeeze_copy.dim,
222222
torch.ops.aten.squeeze.dim,
223223
torch.ops.aten.squeeze.dims,
224+
torch.ops.aten.unbind.int,
224225
torch.ops.aten.unsqueeze.default,
225226
torch.ops.aten.unsqueeze_copy.default,
226227
torch.ops.aten.reshape.default,
227228
torch.ops.aten.repeat.default,
229+
torch.ops.aten.repeat_interleave.self_int,
228230
torch.ops.aten.expand_copy.default,
229231
torch.ops.aten.expand.default,
230232
# Disabling these as there seems to be an issue with support for complex
@@ -256,6 +258,7 @@ def _match_pattern(
256258
torch.ops.aten.amin.default,
257259
torch.ops.aten.clamp.default,
258260
torch.ops.aten.clamp.Tensor,
261+
torch.ops.aten.unflatten.int,
259262
]
260263

261264
_one_to_one_shared_input_or_input_act_qspec = [
@@ -271,6 +274,7 @@ def _match_pattern(
271274
torch.ops.aten.avg_pool2d.default,
272275
torch.ops.aten.max_pool2d.default,
273276
torch.ops.aten.full.default,
277+
torch.ops.aten.full,
274278
torch.ops.aten.flatten.using_ints,
275279
torch.ops.aten.dropout.default,
276280
torch.ops.aten.dropout_.default,
@@ -539,6 +543,7 @@ def annotate_graph( # type: ignore[return]
539543
if node.target in [
540544
torch.ops.aten.full_like.default,
541545
torch.ops.aten.full.default,
546+
torch.ops.aten.full,
542547
torch.ops.aten.scalar_tensor.default,
543548
]:
544549
node.kwargs = {}

backends/arm/runtime/EthosUBackend.cpp

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,24 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
261261
event_tracer,
262262
"+EthosUBackend::execute()handles.input.permute_CHW_to_HWC()");
263263
// permuted byte copy CHW to HWC
264+
int c, h, w;
265+
if (tensor_in.dim() == 4) {
266+
c = tensor_in.size(1);
267+
h = tensor_in.size(2);
268+
w = tensor_in.size(3);
269+
} else if (tensor_in.dim() == 5) {
270+
c = tensor_in.size(2);
271+
h = tensor_in.size(3);
272+
w = tensor_in.size(4);
273+
} else {
274+
ET_LOG(
275+
Error,
276+
"Unsupported input tensor dimension %d, expected 4 or 5",
277+
tensor_in.dim());
278+
return Error::InvalidProgram;
279+
}
264280
permute_CHW_to_HWC(
265-
tensor_in.mutable_data_ptr<char>(),
266-
scratch_addr,
267-
tensor_in.size(1),
268-
tensor_in.size(2),
269-
tensor_in.size(3));
281+
tensor_in.mutable_data_ptr<char>(), scratch_addr, c, h, w);
270282
} else if (both_char or both_int or both_short) {
271283
EXECUTORCH_PROF_SCOPE(
272284
event_tracer, "+EthosUBackend::execute()handles.input.memcpy()");
@@ -364,12 +376,24 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
364376
"+EthosUBackend::execute()handles.output.permute_HWC_to_CHW()");
365377

366378
char* output_address = (char*)output_addr;
379+
int c, h, w;
380+
if (tensor_out.dim() == 4) {
381+
c = tensor_out.size(1);
382+
h = tensor_out.size(2);
383+
w = tensor_out.size(3);
384+
} else if (tensor_out.dim() == 5) {
385+
c = tensor_out.size(2);
386+
h = tensor_out.size(3);
387+
w = tensor_out.size(4);
388+
} else {
389+
ET_LOG(
390+
Error,
391+
"Unsupported output tensor dimension %d, expected 4 or 5",
392+
tensor_out.dim());
393+
return Error::InvalidProgram;
394+
}
367395
permute_HWC_to_CHW(
368-
output_address,
369-
tensor_out.mutable_data_ptr<char>(),
370-
tensor_out.size(1),
371-
tensor_out.size(2),
372-
tensor_out.size(3));
396+
output_address, tensor_out.mutable_data_ptr<char>(), c, h, w);
373397
} else {
374398
EXECUTORCH_PROF_SCOPE(
375399
event_tracer, "+EthosUBackend::execute()handles.output.move()");
@@ -430,6 +454,14 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
430454
if (permuted_shape) {
431455
ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
432456
}
457+
} else if (tensor.dim() == 5) {
458+
// Same as above, but for 5D tensors.
459+
permuted_shape = tensor.size(0) == io->shape[0] &&
460+
tensor.size(1) == io->shape[1] && tensor.size(2) == io->shape[4] &&
461+
tensor.size(3) == io->shape[2] && tensor.size(4) == io->shape[3];
462+
if (permuted_shape) {
463+
ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
464+
}
433465
}
434466
*is_permuted = permuted_shape;
435467
return Error::Ok;

backends/arm/scripts/parse_test_names.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
"bitwise_right_shift.Tensor",
1818
"bitwise_left_shift.Tensor",
1919
"native_group_norm.default",
20+
"unbind.int",
21+
"unflatten.int",
2022
"_native_batch_norm_legit_no_training.default",
2123
"_native_batch_norm_legit.no_stats",
2224
]

backends/arm/test/models/test_deit_tiny_arm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_deit_tiny_tosa_BI():
5252
aten_op=[],
5353
exir_op=[],
5454
use_to_edge_transform_and_lower=True,
55-
atol=2.5, # This needs to go down: MLETORCH-956
55+
atol=1,
5656
qtol=1,
5757
)
5858
pipeline.run()

backends/arm/test/models/test_llama.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,4 @@ def test_llama_tosa_BI():
126126
exir_op=[],
127127
use_to_edge_transform_and_lower=True,
128128
)
129-
pipeline.change_args(
130-
"run_method_and_compare_outputs",
131-
atol=9.9,
132-
rtol=1.5, # TODO: Tolerance needs to be updated after MLETORCH-907
133-
inputs=llama_inputs,
134-
)
135129
pipeline.run()

backends/arm/test/ops/test_repeat.py

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,68 +21,91 @@
2121
)
2222

2323
input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x, Input y
24-
aten_op = "torch.ops.aten.repeat.default"
2524

2625

2726
"""Tests Tensor.repeat for different ranks and dimensions."""
2827

2928

3029
class Repeat(torch.nn.Module):
31-
# (input tensor, multiples)
32-
test_parameters = {
33-
"1_x_1": lambda: (torch.randn(3), (2,)),
34-
"2_x_2": lambda: (torch.randn(3, 4), (2, 1)),
35-
"4_x_4": lambda: (torch.randn(1, 1, 2, 2), (1, 2, 3, 4)),
36-
"1_x_2": lambda: (torch.randn(3), (2, 2)),
37-
"1_x_3": lambda: (torch.randn(3), (1, 2, 3)),
38-
"2_x_3": lambda: (torch.randn((3, 3)), (2, 2, 2)),
39-
"1_x_4": lambda: (torch.randn((3, 3, 3)), (2, 1, 2, 4)),
40-
}
41-
42-
def forward(self, x: torch.Tensor, multiples: Sequence):
43-
return x.repeat(multiples)
44-
45-
46-
@common.parametrize("test_data", Repeat.test_parameters)
30+
aten_op = "torch.ops.aten.repeat.default"
31+
32+
def __init__(self, multiples: Sequence[int]):
33+
super().__init__()
34+
self.multiples = multiples
35+
36+
def forward(self, x: torch.Tensor):
37+
return x.repeat(self.multiples)
38+
39+
40+
class RepeatInterleaveInt(torch.nn.Module):
41+
aten_op = "torch.ops.aten.repeat_interleave.self_int"
42+
43+
def __init__(self, repeats: int, dim: int):
44+
super().__init__()
45+
self.repeats = repeats
46+
self.dim = dim
47+
48+
def forward(self, x: torch.Tensor):
49+
return x.repeat_interleave(self.repeats, self.dim)
50+
51+
52+
test_data_suite = {
53+
# test_name : lambda: (module, test_data)
54+
"1_x_1": lambda: (Repeat((2,)), (torch.randn(3),)),
55+
"2_x_2": lambda: (Repeat((2, 1)), (torch.randn(3, 4),)),
56+
"4_x_4": lambda: (Repeat((1, 2, 3, 4)), (torch.randn(1, 1, 2, 2),)),
57+
"1_x_2": lambda: (Repeat((2, 2)), (torch.randn(3),)),
58+
"1_x_3": lambda: (Repeat((1, 2, 3)), (torch.randn(3),)),
59+
"2_x_3": lambda: (Repeat((2, 2, 2)), (torch.randn((3, 3)),)),
60+
"1_x_4": lambda: (Repeat((2, 1, 2, 4)), (torch.randn((3, 3, 3)),)),
61+
"interleave_int_3_x_1": lambda: (RepeatInterleaveInt(3, 1), (torch.randn(3, 4),)),
62+
}
63+
64+
65+
@common.parametrize("test_data", test_data_suite)
4766
def test_repeat_tosa_MI(test_data: Tuple):
67+
module, test_data = test_data()
4868
pipeline = TosaPipelineMI[input_t1](
49-
Repeat(),
50-
test_data(),
51-
aten_op,
69+
module,
70+
test_data,
71+
module.aten_op,
5272
exir_op=[],
5373
)
5474
pipeline.run()
5575

5676

57-
@common.parametrize("test_data", Repeat.test_parameters)
77+
@common.parametrize("test_data", test_data_suite)
5878
def test_repeat_tosa_BI(test_data: Tuple):
79+
module, test_data = test_data()
5980
pipeline = TosaPipelineBI[input_t1](
60-
Repeat(),
61-
test_data(),
62-
aten_op,
81+
module,
82+
test_data,
83+
module.aten_op,
6384
exir_op=[],
6485
)
6586
pipeline.run()
6687

6788

68-
@common.parametrize("test_data", Repeat.test_parameters)
89+
@common.parametrize("test_data", test_data_suite)
6990
def test_repeat_u55_BI(test_data: Tuple):
91+
module, test_data = test_data()
7092
pipeline = EthosU55PipelineBI[input_t1](
71-
Repeat(),
72-
test_data(),
73-
aten_op,
93+
module,
94+
test_data,
95+
module.aten_op,
7496
exir_ops=[],
7597
run_on_fvp=False,
7698
)
7799
pipeline.run()
78100

79101

80-
@common.parametrize("test_data", Repeat.test_parameters)
102+
@common.parametrize("test_data", test_data_suite)
81103
def test_repeat_u85_BI(test_data: Tuple):
104+
module, test_data = test_data()
82105
pipeline = EthosU85PipelineBI[input_t1](
83-
Repeat(),
84-
test_data(),
85-
aten_op,
106+
module,
107+
test_data,
108+
module.aten_op,
86109
exir_ops=[],
87110
run_on_fvp=False,
88111
)

backends/arm/test/ops/test_unbind.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Tuple
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
TosaPipelineBI,
13+
TosaPipelineMI,
14+
)
15+
16+
input_t = tuple[torch.Tensor]
17+
test_data_t = tuple[int, torch.dtype]
18+
19+
20+
class Unbind(torch.nn.Module):
21+
aten_op: str = "torch.ops.aten.unbind.int"
22+
23+
def __init__(self, dim: int):
24+
super().__init__()
25+
self.dim = dim
26+
27+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
28+
return torch.unbind(x, self.dim)
29+
30+
test_data: dict[str, test_data_t] = {
31+
"randn_4d": (lambda: (torch.randn(1, 5, 4, 3),), (2,)),
32+
"randn_3d": (lambda: (torch.randn(5, 4, 3),), (0,)),
33+
}
34+
35+
36+
@common.parametrize("test_data", Unbind.test_data)
37+
def test_unbind_int_tosa_MI(test_data: test_data_t):
38+
input_data, init_data = test_data
39+
pipeline = TosaPipelineMI[input_t](
40+
Unbind(*init_data),
41+
input_data(),
42+
Unbind.aten_op,
43+
)
44+
pipeline.run()
45+
46+
47+
@common.parametrize("test_data", Unbind.test_data)
48+
def test_unbind_int_tosa_BI(test_data: test_data_t):
49+
input_data, init_data = test_data
50+
pipeline = TosaPipelineBI[input_t](
51+
Unbind(*init_data),
52+
input_data(),
53+
Unbind.aten_op,
54+
)
55+
pipeline.run()
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Tuple
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
TosaPipelineBI,
13+
TosaPipelineMI,
14+
)
15+
16+
input_t = tuple[torch.Tensor]
17+
test_data_t = tuple[torch.nn.Module, input_t]
18+
19+
20+
class Unflatten(torch.nn.Module):
21+
aten_op: str = "torch.ops.aten.unflatten.int"
22+
23+
def __init__(self, dim: int, sizes: Tuple[int, ...]):
24+
super().__init__()
25+
self.dim = dim
26+
self.sizes = sizes
27+
28+
def forward(self, x: torch.Tensor) -> torch.Tensor:
29+
return torch.unflatten(x, self.dim, self.sizes)
30+
31+
test_data: dict[str, test_data_t] = {
32+
"randn_4d": (lambda: (Unflatten(1, (2, 2)), (torch.randn(3, 4, 5, 1),))),
33+
"rand_3d": (lambda: (Unflatten(1, (-1, 2)), (torch.rand(3, 4, 4),))),
34+
}
35+
36+
37+
@common.parametrize("test_data", Unflatten.test_data)
38+
def test_unflatten_int_tosa_MI(test_data: test_data_t):
39+
module, inputs = test_data()
40+
pipeline = TosaPipelineMI[input_t](
41+
module,
42+
inputs,
43+
Unflatten.aten_op,
44+
)
45+
pipeline.run()
46+
47+
48+
@common.parametrize("test_data", Unflatten.test_data)
49+
def test_unflatten_int_tosa_BI(test_data: test_data_t):
50+
module, inputs = test_data()
51+
pipeline = TosaPipelineBI[input_t](
52+
module,
53+
inputs,
54+
Unflatten.aten_op,
55+
)
56+
pipeline.run()

0 commit comments

Comments
 (0)