Skip to content

Commit 5a71ead

Browse files
SaoirseARMYIWENX14
authored andcommitted
Fix for multiple outputs in FVP tests (#7650)
Fix for multiple outputs in corstone - Update to ensure all output nodes are consumed. - Update to ensure output quant scales are used.
1 parent dde1dc2 commit 5a71ead

File tree

4 files changed

+114
-57
lines changed

4 files changed

+114
-57
lines changed

backends/arm/test/misc/test_multiple_outputs.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
import unittest
88

9+
import pytest
910
import torch
10-
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test import common, conftest
1112
from executorch.backends.arm.test.tester.arm_tester import ArmTester
13+
from executorch.exir.backend.compile_spec_schema import CompileSpec
1214

1315

1416
class TestMultipleOutputs(unittest.TestCase):
@@ -51,3 +53,46 @@ def test_tosa_BI_pipeline(self):
5153
.to_executorch()
5254
.run_method_and_compare_outputs(inputs=inputs, qtol=1.0)
5355
)
56+
57+
def _test_ethosu_BI_pipeline(
58+
self,
59+
module: torch.nn.Module,
60+
test_data: tuple[torch.Tensor],
61+
compile_spec: CompileSpec,
62+
):
63+
tester = (
64+
ArmTester(
65+
module,
66+
example_inputs=test_data,
67+
compile_spec=compile_spec,
68+
)
69+
.quantize()
70+
.export()
71+
.to_edge_transform_and_lower()
72+
.to_executorch()
73+
.serialize()
74+
)
75+
if conftest.is_option_enabled("corstone_fvp"):
76+
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
77+
78+
@pytest.mark.corstone_fvp
79+
def test_u85_BI(self):
80+
module = self.MultipleOutputsModule()
81+
test_data = module.get_inputs()
82+
self._test_ethosu_BI_pipeline(
83+
module,
84+
test_data,
85+
common.get_u85_compile_spec(),
86+
)
87+
88+
@pytest.mark.corstone_fvp
89+
@conftest.expectedFailureOnFVP
90+
# TODO MLETORCH-598
91+
def test_u55_BI(self):
92+
module = self.MultipleOutputsModule()
93+
test_data = module.get_inputs()
94+
self._test_ethosu_BI_pipeline(
95+
module,
96+
test_data,
97+
common.get_u55_compile_spec(),
98+
)

backends/arm/test/runner_utils.py

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -115,50 +115,53 @@ def _get_input_quantization_params(
115115
return quant_params
116116

117117

118-
def _get_output_node(program: ExportedProgram) -> Node:
118+
def _get_output_nodes(program: ExportedProgram) -> list[Node]:
119119
"""
120120
Get output node to this model.
121121
122122
Args:
123-
program (ExportedProgram): The program to get output node from.
123+
program (ExportedProgram): The program to get the output nodes from.
124124
Returns:
125-
The node that is the output of 'program'.
125+
The nodes that are the outputs of the 'program'.
126126
"""
127-
127+
output_nodes = []
128128
for node in program.graph.nodes:
129129
if node.op == "output":
130-
return node
131-
raise RuntimeError("No output node found.")
130+
for output in node.args[0]:
131+
output_nodes.append(output)
132+
if len(output_nodes) == 0:
133+
raise RuntimeError("No output nodes found.")
134+
else:
135+
return output_nodes
132136

133137

134138
def _get_output_quantization_params(
135-
program: ExportedProgram, output_node: Node
136-
) -> Optional[QuantizationParams]:
139+
output_nodes: list[Node],
140+
) -> List[QuantizationParams]:
137141
"""
138142
Get output QuantizationParams from a program.
139143
Args:
140-
program (ExportedProgram): The program to get output quantization parameters from.
144+
output_nodes (list(Node)): A list of output nodes to get output quantization parameters from.
141145
Returns:
142146
QuantizationParams: The found quantization parameters.
143147
Raises:
144148
RuntimeError if no output quantization parameters are found.
145149
"""
146-
147-
quant_params = None
148-
for node in program.graph.nodes:
149-
if (
150-
node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default
151-
and node == output_node.args[0][0]
152-
):
153-
quant_params = QuantizationParams(
154-
node_name=node.args[0].name,
155-
scale=node.args[1],
156-
zp=node.args[2],
157-
qmin=node.args[3],
158-
qmax=node.args[4],
159-
dtype=node.args[5],
150+
quant_params = []
151+
for node in output_nodes:
152+
if node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default:
153+
quant_params.append(
154+
QuantizationParams(
155+
node_name=node.args[0].name,
156+
scale=node.args[1],
157+
zp=node.args[2],
158+
qmin=node.args[3],
159+
qmax=node.args[4],
160+
dtype=node.args[5],
161+
)
160162
)
161-
break # break early, there's only one output node
163+
if len(quant_params) == 0:
164+
raise RuntimeError("No Quantization parameters not found in exported model.")
162165
return quant_params
163166

164167

@@ -211,7 +214,7 @@ def __init__(
211214
self.input_names: list[str] = None
212215
self.output_name: str = None
213216
self.qp_input: list[QuantizationParams] = None
214-
self.qp_output: QuantizationParams = None
217+
self.qp_output: list[QuantizationParams] = None
215218
self.timeout = 480
216219
self.target_board: str = None
217220

@@ -226,19 +229,17 @@ def init_run(
226229
):
227230

228231
self.input_names = _get_input_names(edge_program)
229-
self.output_node = _get_output_node(exported_program)
230-
self.output_name = self.output_node.name
232+
self.output_nodes = _get_output_nodes(exported_program)
233+
231234
self.is_quantized = is_quantized
232235
self.target_board = target_board
233236

234237
if is_quantized:
235238
self.qp_input = _get_input_quantization_params(exported_program)
236-
self.qp_output = _get_output_quantization_params(
237-
exported_program, self.output_node
238-
)
239+
self.qp_output = _get_output_quantization_params(self.output_nodes)
239240
else:
240241
self.qp_input = [None] * len(self.input_names)
241-
self.qp_output = None
242+
self.qp_output = [None] * len(self.output_nodes)
242243

243244
self._has_init_run = True
244245

@@ -265,7 +266,7 @@ def run_corstone(
265266
save_bytes(self.intermediate_path, data, False, input_name, quant_param)
266267

267268
out_path = os.path.join(self.intermediate_path, "out")
268-
out_path_with_suffix = out_path + "-0.bin"
269+
269270
input_paths = []
270271
for name in self.input_names:
271272
input_paths.append(
@@ -281,6 +282,7 @@ def run_corstone(
281282
), f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?"
282283

283284
cmd_line = f"executor_runner -m {pte_path} -o {out_path}"
285+
284286
for input_path in input_paths:
285287
cmd_line += f" -i {input_path}"
286288

@@ -362,11 +364,14 @@ def run_corstone(
362364
raise RuntimeError(
363365
f"Corstone simulation failed:\ncmd: {command_args[self.target_board]}\n, log: \n {result_stdout}\n{result.stderr.decode()}"
364366
)
365-
366-
tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
367-
output_shape = self.output_node.args[0][0].meta["val"].shape
368-
tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape)
369-
return tosa_ref_output
367+
output_np = []
368+
for i, node in enumerate(self.output_nodes):
369+
tosa_ref_output = np.fromfile(
370+
os.path.join(self.intermediate_path, f"out-{i}.bin"), dtype=np.float32
371+
)
372+
output_shape = node.meta["val"].shape
373+
output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
374+
return tuple(output_np)
370375

371376
def run_tosa_graph(
372377
self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor]

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -9,7 +9,7 @@
99
import torch
1010
from executorch.backends.arm.test.runner_utils import (
1111
_get_input_quantization_params,
12-
_get_output_node,
12+
_get_output_nodes,
1313
_get_output_quantization_params,
1414
)
1515

@@ -228,9 +228,9 @@ def dump_error_output(
228228
export_stage = tester.stages.get(tester.stage_name(Export), None)
229229
quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
230230
if export_stage is not None and quantize_stage is not None:
231-
output_node = _get_output_node(export_stage.artifact)
231+
output_nodes = _get_output_nodes(export_stage.artifact)
232232
qp_input = _get_input_quantization_params(export_stage.artifact)
233-
qp_output = _get_output_quantization_params(export_stage.artifact, output_node)
233+
qp_output = _get_output_quantization_params(output_nodes)
234234
logger.error(f"Input QuantArgs: {qp_input}")
235235
logger.error(f"Output QuantArgs: {qp_output}")
236236

backends/arm/test/tester/arm_tester.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import serializer.tosa_serializer as ts
1515

1616
import torch.fx
17+
import torch.utils._pytree as pytree
1718

1819
from executorch.backends.arm.arm_backend import get_intermediate_path
1920
from executorch.backends.arm.arm_partitioner import ArmPartitioner
@@ -302,21 +303,22 @@ def run_method_and_compare_outputs(
302303

303304
exported_program = self.stages[self.stage_name(tester.Export)].artifact
304305
edge_program = edge_stage.artifact.exported_program()
306+
305307
self.runner_util.init_run(
306308
exported_program,
307309
edge_program,
308310
is_quantized,
309311
target_board,
310312
)
311313

312-
quantization_scale = None
313314
if is_quantized:
314315
reference_stage = self.stages[self.stage_name(tester.Quantize)]
315316
# bool output is quantized with none quantized output so allow
316317
# self.runner_util.qp_output to be none
317318
if self.runner_util.qp_output is not None:
318-
quantization_scale = self.runner_util.qp_output.scale
319+
quantization_scales = [qp.scale for qp in self.runner_util.qp_output]
319320
else:
321+
quantization_scales = [None] * len(self.runner_util.output_nodes)
320322
reference_stage = self.stages[self.stage_name(InitialModel)]
321323

322324
logger.info(
@@ -334,21 +336,26 @@ def run_method_and_compare_outputs(
334336
input_shape_str = ", ".join([str(list(i)) for i in input_shapes])
335337
logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")
336338

337-
reference_output = reference_stage.run_artifact(reference_input)
338-
if not isinstance(reference_output, tuple):
339-
reference_output = (reference_output,)
340-
test_output = test_stage.run_artifact(reference_input)
341-
342-
self._compare_outputs(
343-
reference_output,
344-
test_output,
345-
quantization_scale,
346-
atol,
347-
rtol,
348-
qtol,
349-
error_callbacks,
339+
reference_outputs, _ = pytree.tree_flatten(
340+
reference_stage.run_artifact(reference_input)
341+
)
342+
test_outputs, _ = pytree.tree_flatten(
343+
test_stage.run_artifact(reference_input)
350344
)
351345

346+
for reference_output, test_output, quantization_scale in zip(
347+
reference_outputs, test_outputs, quantization_scales
348+
):
349+
self._compare_outputs(
350+
reference_output,
351+
test_output,
352+
quantization_scale,
353+
atol,
354+
rtol,
355+
qtol,
356+
error_callbacks,
357+
)
358+
352359
return self
353360

354361
def get_graph(self, stage: str | None = None) -> Graph:

0 commit comments

Comments
 (0)