Skip to content

Commit c7c6ad9

Browse files
authored
Arm backend: Test popular torch modules/ functions, adress issues (#9221)
- Adds tests for top 10 torch.nn.Modules, torch.* functions, and torch.nn.functional. - Add check to not partion ops with int64 input - A few minor fixes Signed-off-by: Erik Lundell <[email protected]>
1 parent 6de5b36 commit c7c6ad9

File tree

9 files changed

+415
-8
lines changed

9 files changed

+415
-8
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
199199
)
200200

201201
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
202-
self.add_pass(ScalarsToAttributePass())
203202
self.add_pass(ReplaceScalarWithTensorArgPass())
203+
self.add_pass(ScalarsToAttributePass())
204204
self.add_pass(DecomposeLayerNormPass())
205205
self.add_pass(DecomposeVarPass())
206206
self.add_pass(DecomposeMeanDimPass())

backends/arm/_passes/arm_pass_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-unsafe
99

1010
from inspect import isclass
11-
from typing import Optional
11+
from typing import Optional, Sequence
1212

1313
import torch
1414
import torch.fx
@@ -149,7 +149,7 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
149149
If the node contains many fake tensors, return the first one.
150150
"""
151151
if isinstance(
152-
node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
152+
node.meta["val"], (Sequence, torch.fx.immutable_collections.immutable_list)
153153
):
154154
fake_tensor = node.meta["val"][0]
155155
else:

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
FuseQuantizedActivationPass,
1919
)
2020
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
21+
from executorch.exir import ExportedProgram
2122
from executorch.exir.dialects._ops import ops as exir_ops
23+
from torch.export.graph_signature import InputKind
2224
from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
2325
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
2426

@@ -84,9 +86,10 @@ def get_registered_tosa_support_checks(
8486

8587
def tosa_support_factory(
8688
tosa_spec: TosaSpecification,
89+
exported_program: ExportedProgram,
8790
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
8891
) -> OperatorSupportBase:
89-
negative_checks: list[OperatorSupportBase] = []
92+
negative_checks: list[OperatorSupportBase] = [CheckInt64Inputs(exported_program)]
9093
if not tosa_spec.support_float():
9194
negative_checks.append(NeedsDecompositionCheck())
9295
negative_checks.append(CheckProperQuantization())
@@ -247,6 +250,10 @@ def is_node_supported(
247250
exir_ops.edge.aten._log_softmax.default,
248251
exir_ops.edge.aten.var.correction,
249252
exir_ops.edge.aten.var.dim,
253+
exir_ops.edge.aten.add.Scalar,
254+
exir_ops.edge.aten.sub.Scalar,
255+
exir_ops.edge.aten.mul.Scalar,
256+
exir_ops.edge.aten.div.Scalar,
250257
]
251258
return not needs_decomp
252259

@@ -312,6 +319,8 @@ def is_node_supported(
312319
exir_ops.edge.aten.bmm.default,
313320
exir_ops.edge.aten.convolution.default,
314321
exir_ops.edge.aten.exp.default,
322+
exir_ops.edge.aten.full.default,
323+
exir_ops.edge.aten.full_like.default,
315324
exir_ops.edge.aten.hardtanh.default,
316325
exir_ops.edge.aten.linear.default,
317326
exir_ops.edge.aten.log.default,
@@ -371,3 +380,29 @@ def is_node_supported(
371380
if not output_quantized:
372381
return False
373382
return True
383+
384+
385+
class CheckInt64Inputs(OperatorSupportBase):
386+
387+
def __init__(self, exported_program: ExportedProgram):
388+
self.input_names = [
389+
spec.arg.name
390+
for spec in exported_program.graph_signature.input_specs
391+
if spec.kind == InputKind.USER_INPUT
392+
]
393+
super().__init__()
394+
395+
def is_node_supported(
396+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
397+
) -> bool:
398+
399+
for input_node in node.all_input_nodes:
400+
# We can cast constant placeholders AOT, not call_functions.
401+
if (
402+
input_node.name in self.input_names
403+
or not input_node.op == "placeholder"
404+
):
405+
tensor = get_first_fake_tensor(input_node)
406+
if tensor.dtype == torch.int64:
407+
return False
408+
return True

backends/arm/test/models/test_conformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_conformer_tosa_BI(self):
9393
)
9494
)
9595

96-
@unittest.expectedFailure # TODO(MLETORCH-635)
96+
@conftest.expectedFailureOnFVP # TODO(MLETORCH-635)
9797
def test_conformer_u55_BI(self):
9898
tester = (
9999
ArmTester(
@@ -115,7 +115,7 @@ def test_conformer_u55_BI(self):
115115
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
116116
)
117117

118-
@unittest.expectedFailure # TODO(MLETORCH-635)
118+
@conftest.expectedFailureOnFVP # TODO(MLETORCH-635)
119119
def test_conformer_u85_BI(self):
120120
tester = (
121121
ArmTester(
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""
7+
Tests 10 popular torch.nn.functional not tested in other ways or training related
8+
- normalize
9+
- grid_sample
10+
- one_hot
11+
- softplus
12+
- cosine_similarity
13+
- unfold
14+
- elu
15+
- fold
16+
- affine_grid
17+
- max_pool1d
18+
- threshold
19+
"""
20+
from typing import Callable
21+
22+
import torch
23+
from executorch.backends.arm.test.common import parametrize
24+
from executorch.backends.arm.test.tester.test_pipeline import (
25+
TosaPipelineBI,
26+
TosaPipelineMI,
27+
)
28+
29+
30+
def module_factory(function: Callable) -> torch.nn.Module:
31+
class ModuleWrapper(torch.nn.Module):
32+
def forward(self, *args):
33+
return function(*args)
34+
35+
return ModuleWrapper()
36+
37+
38+
example_input = torch.rand(1, 6, 16, 16)
39+
40+
module_tests = {
41+
"normalize": (module_factory(torch.nn.functional.normalize), (example_input,)),
42+
"grid_sample": (
43+
module_factory(torch.nn.functional.grid_sample),
44+
(torch.rand(1, 1, 4, 4), torch.rand(1, 5, 5, 2)),
45+
),
46+
"one_hot": (
47+
module_factory(torch.nn.functional.one_hot),
48+
(torch.randint(0, 5, (2, 2, 5, 5)), 5),
49+
),
50+
"softplus": (module_factory(torch.nn.functional.softplus), (example_input,)),
51+
"cosine_similarity": (
52+
module_factory(torch.nn.functional.cosine_similarity),
53+
(example_input, example_input),
54+
),
55+
"unfold": (
56+
module_factory(torch.nn.functional.unfold),
57+
(torch.randn(1, 3, 10, 12), (4, 5)),
58+
),
59+
"elu": (module_factory(torch.nn.functional.elu), (example_input,)),
60+
"fold": (
61+
module_factory(torch.nn.functional.fold),
62+
(torch.randn(1, 12, 12), (4, 5), (2, 2)),
63+
),
64+
"affine_grid": (
65+
module_factory(torch.nn.functional.affine_grid),
66+
(torch.rand(1, 2, 3), (1, 2, 10, 10)),
67+
),
68+
"max_pool1d": (
69+
module_factory(torch.nn.functional.max_pool1d),
70+
(torch.randn(20, 16, 50), 4),
71+
),
72+
"threshold": (
73+
module_factory(torch.nn.functional.threshold),
74+
(example_input, 0.5, 0.1),
75+
),
76+
}
77+
78+
input_t = tuple[torch.Tensor]
79+
80+
81+
@parametrize(
82+
"test_data", module_tests, xfails={"max_pool1d": "ValueError: Invalid TOSA graph"}
83+
)
84+
def test_nn_functional_MI(test_data):
85+
module, inputs = test_data
86+
pipeline = TosaPipelineMI[input_t](
87+
module, inputs, "", use_to_edge_transform_and_lower=True
88+
)
89+
pipeline.pop_stage("check.aten")
90+
pipeline.pop_stage("check_count.exir")
91+
try:
92+
pipeline.run()
93+
except RuntimeError as e:
94+
if (
95+
"Ran model with TosaReferenceModelDispatch but never ran TOSABackend delegate."
96+
not in str(e)
97+
):
98+
raise e
99+
100+
101+
@parametrize("test_data", module_tests)
102+
def test_nn_functional_BI(test_data):
103+
module, inputs = test_data
104+
pipeline = TosaPipelineBI[input_t](
105+
module, inputs, "", use_to_edge_transform_and_lower=True
106+
)
107+
pipeline.pop_stage("check.aten")
108+
pipeline.pop_stage("check_count.exir")
109+
pipeline.pop_stage("check.quant_nodes")
110+
pipeline.pop_stage("check_not.quant_nodes")
111+
try:
112+
pipeline.run()
113+
except RuntimeError as e:
114+
if (
115+
"Ran model with TosaReferenceModelDispatch but never ran TOSABackend delegate."
116+
not in str(e)
117+
):
118+
raise e
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""
7+
Tests 10 popular nn modules not tested in other ways or training related.
8+
- Embedding
9+
- LeakyReLU
10+
- BatchNorm1d
11+
- AdaptiveAvgPool2d
12+
- ConvTranspose2d
13+
- GRU
14+
- GroupNorm
15+
- InstanceNorm2d
16+
- PReLU
17+
- Transformer
18+
"""
19+
20+
import torch
21+
from executorch.backends.arm.test.common import parametrize
22+
from executorch.backends.arm.test.tester.test_pipeline import (
23+
TosaPipelineBI,
24+
TosaPipelineMI,
25+
)
26+
27+
example_input = torch.rand(1, 6, 16, 16)
28+
29+
module_tests = [
30+
(torch.nn.Embedding(10, 10), (torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),)),
31+
(torch.nn.LeakyReLU(), (example_input,)),
32+
(torch.nn.BatchNorm1d(16), (torch.rand(6, 16, 16),)),
33+
(torch.nn.AdaptiveAvgPool2d((12, 12)), (example_input,)),
34+
(torch.nn.ConvTranspose2d(6, 3, 2), (example_input,)),
35+
(torch.nn.GRU(10, 20, 2), (torch.randn(5, 3, 10), torch.randn(2, 3, 20))),
36+
(torch.nn.GroupNorm(2, 6), (example_input,)),
37+
(torch.nn.InstanceNorm2d(16), (example_input,)),
38+
(torch.nn.PReLU(), (example_input,)),
39+
(
40+
torch.nn.Transformer(
41+
d_model=64,
42+
nhead=1,
43+
num_encoder_layers=1,
44+
num_decoder_layers=1,
45+
dtype=torch.float32,
46+
),
47+
(torch.rand((10, 32, 64)), torch.rand((20, 32, 64))),
48+
),
49+
]
50+
51+
input_t = tuple[torch.Tensor]
52+
53+
test_parameters = {str(test[0].__class__.__name__): test for test in module_tests}
54+
55+
56+
@parametrize(
57+
"test_data",
58+
test_parameters,
59+
xfails={"Transformer": "Output 0 does not match reference output."},
60+
)
61+
def test_nn_Modules_MI(test_data):
62+
module, inputs = test_data
63+
pipeline = TosaPipelineMI[input_t](
64+
module, inputs, "", use_to_edge_transform_and_lower=True
65+
)
66+
pipeline.pop_stage("check.aten")
67+
pipeline.pop_stage("check_count.exir")
68+
try:
69+
pipeline.run()
70+
except RuntimeError as e:
71+
if (
72+
"Ran model with TosaReferenceModelDispatch but never ran TOSABackend delegate."
73+
not in str(e)
74+
):
75+
raise e
76+
77+
78+
@parametrize(
79+
"test_data",
80+
test_parameters,
81+
xfails={
82+
"GRU": "RuntimeError: Node aten_linear_default with op <EdgeOpOverload: aten.linear[...]> was not decomposed or delegated.",
83+
"PReLU": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
84+
"Transformer": "RuntimeError: Expected out tensor to have dtype signed char, but got float",
85+
},
86+
)
87+
def test_nn_Modules_BI(test_data):
88+
module, inputs = test_data
89+
pipeline = TosaPipelineBI[input_t](
90+
module, inputs, "", use_to_edge_transform_and_lower=True
91+
)
92+
pipeline.pop_stage("check.aten")
93+
pipeline.pop_stage("check_count.exir")
94+
pipeline.pop_stage("check.quant_nodes")
95+
pipeline.pop_stage("check_not.quant_nodes")
96+
try:
97+
pipeline.run()
98+
except RuntimeError as e:
99+
if (
100+
"Ran model with TosaReferenceModelDispatch but never ran TOSABackend delegate."
101+
not in str(e)
102+
):
103+
raise e

0 commit comments

Comments
 (0)