Skip to content

Commit 0b16f27

Browse files
authored
Arm backend: Fix helper method for getting input names (#7996)
The previous version missed ConstantInputs which caused problems in the runtime when the provided inputs did not match the GraphSignature. You could argue that having to provide constant inputs that are already encoded in the graph as arguments is unnecessary. However, this solution is more general and does not stop us from adding a pass that prunes unwanted inputs from the graph signature in the future. (if that is possible) Encountered an issue with too long input names, had to rename some parameters in testcases to shorter names (added ticket to adress this, #MLETORCH-628) Signed-off-by: Erik Lundell <[email protected]>
1 parent 42ff569 commit 0b16f27

File tree

4 files changed

+59
-35
lines changed

4 files changed

+59
-35
lines changed

backends/arm/test/ops/test_cat.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
33
# All rights reserved.
44
#
55
# This source code is licensed under the BSD-style license found in the
@@ -33,6 +33,8 @@ class Cat(torch.nn.Module):
3333
),
3434
-1,
3535
),
36+
((torch.randn(1, 2, 4, 4), torch.randn(1, 2, 4, 1)), 3),
37+
((torch.randn(1, 2, 4, 4), torch.randn(1, 2, 4, 4)), 0),
3638
((torch.randn(2, 2, 4, 4), torch.randn(2, 2, 4, 1)), 3),
3739
(
3840
(
@@ -47,8 +49,8 @@ class Cat(torch.nn.Module):
4749
def __init__(self):
4850
super().__init__()
4951

50-
def forward(self, tensors: tuple[torch.Tensor, ...], dim: int) -> torch.Tensor:
51-
return torch.cat(tensors, dim=dim)
52+
def forward(self, t: tuple[torch.Tensor, ...], dim: int) -> torch.Tensor:
53+
return torch.cat(t, dim=dim)
5254

5355
def _test_cat_tosa_MI_pipeline(
5456
self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int]
@@ -134,22 +136,38 @@ def test_cat_tosa_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
134136
test_data = (operands, dim)
135137
self._test_cat_tosa_BI_pipeline(self.Cat(), test_data)
136138

137-
# Mismatch in provided number of inputs and model signature, MLETORCH 519
138-
@parameterized.expand(Cat.test_parameters)
139+
@parameterized.expand(Cat.test_parameters[:-3])
139140
@pytest.mark.corstone_fvp
140-
@conftest.expectedFailureOnFVP
141141
def test_cat_u55_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
142142
test_data = (operands, dim)
143143
self._test_cat_ethosu_BI_pipeline(
144144
self.Cat(), common.get_u55_compile_spec(), test_data
145145
)
146146

147-
# Mismatch in provided number of inputs and model signature, MLETORCH 519
148-
@parameterized.expand(Cat.test_parameters)
147+
# MLETORCH-630 Cat does not work on FVP with batch>1
148+
@parameterized.expand(Cat.test_parameters[-3:])
149149
@pytest.mark.corstone_fvp
150150
@conftest.expectedFailureOnFVP
151+
def test_cat_u55_BI_xfails(self, operands: tuple[torch.Tensor, ...], dim: int):
152+
test_data = (operands, dim)
153+
self._test_cat_ethosu_BI_pipeline(
154+
self.Cat(), common.get_u55_compile_spec(), test_data
155+
)
156+
157+
@parameterized.expand(Cat.test_parameters[:-3])
158+
@pytest.mark.corstone_fvp
151159
def test_cat_u85_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
152160
test_data = (operands, dim)
153161
self._test_cat_ethosu_BI_pipeline(
154162
self.Cat(), common.get_u85_compile_spec(), test_data
155163
)
164+
165+
# MLETORCH-630 Cat does not work on FVP with batch>1
166+
@parameterized.expand(Cat.test_parameters[-3:])
167+
@pytest.mark.corstone_fvp
168+
@conftest.expectedFailureOnFVP
169+
def test_cat_u85_BI_xfails(self, operands: tuple[torch.Tensor, ...], dim: int):
170+
test_data = (operands, dim)
171+
self._test_cat_ethosu_BI_pipeline(
172+
self.Cat(), common.get_u85_compile_spec(), test_data
173+
)

backends/arm/test/ops/test_expand.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,17 @@ class Expand(torch.nn.Module):
3737
test_parameters = [
3838
(torch.rand(1), (2,)),
3939
(torch.randn(1, 4), (1, -1)),
40-
(torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
4140
(torch.randn(1), (2, 2, 4)),
42-
(torch.rand(3, 2, 4, 1), (-1, -1, -1, 3)),
41+
(torch.randn(1, 1, 1, 5), (1, 4, -1, -1)),
4342
(torch.randn(1, 1, 192), (1, -1, -1)),
43+
(torch.randn(1, 1), (1, 2, 2, 4)),
44+
(torch.randn(1, 1), (2, 2, 2, 4)),
4445
(torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)),
46+
(torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
4547
]
4648

47-
def forward(self, x: torch.Tensor, multiples: Sequence):
48-
return x.expand(multiples)
49+
def forward(self, x: torch.Tensor, m: Sequence):
50+
return x.expand(m)
4951

5052
def _test_expand_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
5153
(
@@ -113,20 +115,34 @@ def test_expand_tosa_MI(self, test_input, multiples):
113115
def test_expand_tosa_BI(self, test_input, multiples):
114116
self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))
115117

116-
# Mismatch in provided number of inputs and model signature, MLETORCH 519
117-
@parameterized.expand(Expand.test_parameters)
118+
@parameterized.expand(Expand.test_parameters[:-3])
118119
@pytest.mark.corstone_fvp
119-
@conftest.expectedFailureOnFVP
120120
def test_expand_u55_BI(self, test_input, multiples):
121121
self._test_expand_ethosu_BI_pipeline(
122122
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
123123
)
124124

125-
# Mismatch in provided number of inputs and model signature, MLETORCH 519
126-
@parameterized.expand(Expand.test_parameters)
125+
# MLETORCH-629: Expand does not work on FVP with batch>1
126+
@parameterized.expand(Expand.test_parameters[-3:])
127127
@pytest.mark.corstone_fvp
128128
@conftest.expectedFailureOnFVP
129+
def test_expand_u55_BI_xfails(self, test_input, multiples):
130+
self._test_expand_ethosu_BI_pipeline(
131+
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
132+
)
133+
134+
@parameterized.expand(Expand.test_parameters[:-3])
135+
@pytest.mark.corstone_fvp
129136
def test_expand_u85_BI(self, test_input, multiples):
130137
self._test_expand_ethosu_BI_pipeline(
131138
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
132139
)
140+
141+
# MLETORCH-629: Expand does not work on FVP with batch>1
142+
@parameterized.expand(Expand.test_parameters[-3:])
143+
@pytest.mark.corstone_fvp
144+
@conftest.expectedFailureOnFVP
145+
def test_expand_u85_BI_xfails(self, test_input, multiples):
146+
self._test_expand_ethosu_BI_pipeline(
147+
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
148+
)

backends/arm/test/ops/test_full.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,20 +143,16 @@ def test_full_tosa_MI(self, test_tensor: Tuple):
143143
def test_full_tosa_BI(self, test_tensor: Tuple):
144144
self._test_full_tosa_BI_pipeline(self.AddVariableFull(), test_tensor)
145145

146-
# Mismatch in provided number of inputs and model signature, MLETORCH 519
147146
@parameterized.expand(AddVariableFull.test_parameters)
148147
@pytest.mark.corstone_fvp
149-
@conftest.expectedFailureOnFVP
150148
def test_full_u55_BI(self, test_tensor: Tuple):
151149
self._test_full_tosa_u55_pipeline(
152150
self.AddVariableFull(),
153151
test_tensor,
154152
)
155153

156-
# Mismatch in provided number of inputs and model signature, MLETORCH 519
157154
@parameterized.expand(AddVariableFull.test_parameters)
158155
@pytest.mark.corstone_fvp
159-
@conftest.expectedFailureOnFVP
160156
def test_full_u85_BI(self, test_tensor: Tuple):
161157
self._test_full_tosa_u85_pipeline(
162158
self.AddVariableFull(),

backends/arm/test/runner_utils.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,7 @@ def get_input_names(program: ExportedProgram) -> list[str]:
6565
Returns:
6666
A list of strings with the names of the model input.
6767
"""
68-
input_names = []
69-
70-
# E.g. bias and weights are 'placeholders' as well. This is used to
71-
# get only the use inputs.
72-
usr_inputs = program.graph_signature.user_inputs
73-
for node in program.graph.nodes:
74-
if node.op == "placeholder" and node.name in usr_inputs:
75-
input_names.append(node.name)
76-
77-
return input_names
68+
return [spec.arg.name for spec in program.graph_signature.input_specs]
7869

7970

8071
def get_input_quantization_params(
@@ -334,13 +325,16 @@ def run_corstone(
334325

335326

336327
def prep_data_for_save(
337-
data: torch.Tensor,
328+
data,
338329
input_name: str,
339330
quant_param: Optional[QuantizationParams] = None,
340331
):
341-
data_np = np.array(data.detach(), order="C").astype(
342-
torch_to_numpy_dtype_dict[data.dtype]
343-
)
332+
if isinstance(data, torch.Tensor):
333+
data_np = np.array(data.detach(), order="C").astype(
334+
torch_to_numpy_dtype_dict[data.dtype]
335+
)
336+
else:
337+
data_np = np.array(data)
344338
if quant_param is not None:
345339
assert quant_param.node_name in input_name, (
346340
f"The quantization params name '{quant_param.node_name}' does not "

0 commit comments

Comments
 (0)