Skip to content

Commit 93c3b66

Browse files
Arm backend: Make stage ids unique in the Arm TestPipeline (#8338)
Some stages in the TestPipeline may be added multiple times, such as .check(). To be able to target these by id, give them an unique suffix -> 'id.suffix' Refering to stages in terms of id instead of an index is more self documenting and future proof. This change modifies the add/pop_stage interface: - pos arg in add_stage is now an optional kwarg, appending to the pipline as default - suffix is added to add_stage as an optional_kwarg. If a suffix is not given to a non unique stage, a number is added instead. - pop_stage now allows to use ids for referring to stages. Additionally adds .visualize(stage) for quickly adding visualizing stages to the pipeline.
1 parent 8d96d74 commit 93c3b66

File tree

2 files changed

+103
-41
lines changed

2 files changed

+103
-41
lines changed

backends/arm/test/ops/test_conv2d.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -371,7 +370,7 @@ def test_conv2d_tosa_BI(test_module):
371370
pipeline = TosaPipelineBI[input_t](
372371
test_module, test_module.get_inputs(), aten_op, exir_op
373372
)
374-
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
373+
pipeline.change_args("run_method_and_compare_outputs.0", qtol=1)
375374
pipeline.run()
376375

377376

backends/arm/test/tester/test_pipeline.py

Lines changed: 102 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ class PipelineStage:
4646
is_called: keeps track of if the function has been called
4747
"""
4848

49-
def __init__(self, func, *args, **kwargs):
50-
self.id: str = func.__name__
49+
def __init__(self, func: Callable, id: str, *args, **kwargs):
50+
self.id: str = id
5151
self.func: Callable = func
5252
self.args = args
5353
self.kwargs = kwargs
@@ -86,72 +86,130 @@ def __init__(
8686
self.test_data = test_data
8787
self._stages = []
8888

89-
self.add_stage(-1, self.tester.export)
90-
self.add_stage(-1, self.tester.check, self.aten_ops)
89+
self.add_stage(self.tester.export)
90+
self.add_stage(self.tester.check, self.aten_ops, suffix="aten")
9191
if use_to_edge_transform_and_lower:
92-
self.add_stage(-1, self.tester.to_edge_transform_and_lower)
93-
92+
self.add_stage(self.tester.to_edge_transform_and_lower)
9493
else:
95-
self.add_stage(-1, self.tester.to_edge)
96-
self.add_stage(-1, self.tester.check, self.exir_ops)
97-
self.add_stage(-1, self.tester.partition)
98-
self.add_stage(-1, self.tester.check_not, self.exir_ops)
94+
self.add_stage(self.tester.to_edge)
95+
self.add_stage(self.tester.check, self.exir_ops, suffix="exir")
96+
self.add_stage(self.tester.partition)
97+
self.add_stage(self.tester.check_not, self.exir_ops, suffix="exir")
9998
self.add_stage(
100-
-1,
10199
self.tester.check_count,
102100
{"torch.ops.higher_order.executorch_call_delegate": 1},
101+
suffix="exir",
103102
)
104-
self.add_stage(-1, self.tester.to_executorch)
103+
self.add_stage(self.tester.to_executorch)
104+
105+
def add_stage(self, func: Callable, *args, **kwargs):
106+
"""
107+
Adds a stage defined by a function with args and kwargs. By default appends to the pipeline.
108+
For stages which may be added multiple times to a pipeline, s.a. checks and debug stages,
109+
a suffix is appended with a dot to make sure every id is unique, e.g. check becomes check.0
105110
106-
def add_stage(self, pos: int, func: Callable, *args, **kwargs):
107-
"""Adds a stage defined by a function with arguments to the pipeline at index pos. Pos wraps around the list for negative values."""
108-
pipeline_stage = self.PipelineStage(func, *args, **kwargs)
111+
Special kwargs:
112+
pos : specifies position in pipeline to add stage at.
113+
suffix : specifies a custom suffix to identify non unique stages, instead of a number.
114+
"""
109115
pipeline_length = len(self._stages)
110116

117+
pos = -1
118+
if "pos" in kwargs:
119+
pos = kwargs.pop("pos")
120+
111121
if pos < 0:
112122
pos = pipeline_length + (pos + 1)
113-
114123
if not -pipeline_length <= pos <= pipeline_length:
115124
raise ValueError(
116125
f"Pos must be between [-{pipeline_length}, {pipeline_length}]"
117126
)
118127

128+
suffix = None
129+
if "suffix" in kwargs:
130+
suffix = kwargs.pop("suffix")
131+
132+
stage_id = func.__name__
133+
unique_stages = [
134+
"quantize",
135+
"export",
136+
"to_edge_transform_and_lower",
137+
"to_edge",
138+
"partition",
139+
"to_executorch",
140+
"serialize",
141+
]
142+
id_list = [stage.id for stage in self._stages]
143+
if stage_id in unique_stages:
144+
if stage_id in id_list:
145+
raise RuntimeError(f"Tried adding {stage_id} to pipeline twice.")
146+
else:
147+
if suffix is None:
148+
stages_containing_stage_id = [
149+
id for id in id_list if stage_id == id.split(".")[0]
150+
]
151+
152+
suffix = str(len(stages_containing_stage_id))
153+
154+
stage_id = stage_id + "." + suffix
155+
156+
if stage_id in id_list:
157+
raise ValueError("Suffix must be unique in pipeline")
158+
159+
pipeline_stage = self.PipelineStage(func, stage_id, *args, **kwargs)
119160
self._stages.insert(pos, pipeline_stage)
120161

121-
logger.debug(f"Added stage {func.__name__} to {type(self).__name__}")
162+
logger.debug(f"Added stage {stage_id} to {type(self).__name__}")
122163

123164
return self
124165

125-
def pop_stage(self, pos: int):
166+
def pop_stage(self, identifier: int | str):
126167
"""Removes and returns the stage at postion pos"""
127-
return self._stages.pop(pos)
168+
if isinstance(identifier, int):
169+
stage = self._stages.pop(identifier)
170+
elif isinstance(identifier, str):
171+
pos = self.find_pos(identifier)
172+
stage = self._stages.pop(pos)
173+
174+
logger.debug(f"Removed stage {stage.id} from {type(self).__name__}")
175+
176+
return stage
128177

129178
def find_pos(self, stage_id: str):
130-
"""Returns the position of the stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
179+
"""Returns the position of the stage id."""
131180
for i, stage in enumerate(self._stages):
132181
if stage.id == stage_id:
133182
return i
134183

135184
raise Exception(f"Stage id {stage_id} not found in pipeline")
136185

137186
def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs):
138-
"""Adds a stage after the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
139-
pos = self.find_pos(stage_id)
140-
self.add_stage(pos + 1, func, *args, **kwargs)
187+
"""Adds a stage after the given stage id."""
188+
pos = self.find_pos(stage_id) + 1
189+
kwargs["pos"] = pos
190+
191+
self.add_stage(func, *args, **kwargs)
192+
return self
193+
194+
def dump_artifact(self, stage_id: str, suffix: str = None):
195+
"""Adds a dump_artifact stage after the given stage id."""
196+
self.add_stage_after(stage_id, self.tester.dump_artifact, suffix=suffix)
141197
return self
142198

143-
def dump_artifact(self, stage_id: str):
144-
"""Adds a dump_artifact stage after the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
145-
self.add_stage_after(stage_id, self.tester.dump_artifact)
199+
def dump_operator_distribution(self, stage_id: str, suffix: str = None):
200+
"""Adds a dump_operator_distribution stage after the given stage id."""
201+
self.add_stage_after(
202+
stage_id, self.tester.dump_operator_distribution, suffix=suffix
203+
)
146204
return self
147205

148-
def dump_operator_distribution(self, stage_id: str):
149-
"""Adds a dump_operator_distribution stage after the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
150-
self.add_stage_after(stage_id, self.tester.dump_operator_distribution)
206+
def visualize(self, stage_id: str, suffix: str = None):
207+
"""Adds a dump_operator_distribution stage after the given stage id."""
208+
self.add_stage_after(stage_id, self.tester.visualize, suffix=suffix)
151209
return self
152210

153211
def change_args(self, stage_id: str, *args, **kwargs):
154-
"""Updates the args to the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
212+
"""Updates the args to the given stage id."""
155213
pos = self.find_pos(stage_id)
156214
pipeline_stage = self._stages[pos]
157215
pipeline_stage.update(*args, **kwargs)
@@ -193,14 +251,15 @@ def __init__(
193251
compile_spec,
194252
use_to_edge_transform_and_lower,
195253
)
196-
self.add_stage(0, self.tester.quantize)
254+
self.add_stage(self.tester.quantize, pos=0)
197255
self.add_stage_after(
198256
"quantize",
199257
self.tester.check,
200258
[
201259
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
202260
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
203261
],
262+
suffix="quant_nodes",
204263
)
205264

206265
remove_quant_nodes_stage = (
@@ -215,10 +274,11 @@ def __init__(
215274
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
216275
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
217276
],
277+
suffix="quant_nodes",
218278
)
219279

220280
self.add_stage(
221-
-1, self.tester.run_method_and_compare_outputs, inputs=self.test_data
281+
self.tester.run_method_and_compare_outputs, inputs=self.test_data
222282
)
223283

224284

@@ -252,10 +312,11 @@ def __init__(
252312
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
253313
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
254314
],
315+
suffix="quant_nodes",
255316
)
256317

257318
self.add_stage(
258-
-1, self.tester.run_method_and_compare_outputs, inputs=self.test_data
319+
self.tester.run_method_and_compare_outputs, inputs=self.test_data
259320
)
260321

261322

@@ -280,14 +341,15 @@ def __init__(
280341
compile_spec,
281342
use_to_edge_transform_and_lower,
282343
)
283-
self.add_stage(0, self.tester.quantize)
344+
self.add_stage(self.tester.quantize, pos=0)
284345
self.add_stage_after(
285346
"quantize",
286347
self.tester.check,
287348
[
288349
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
289350
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
290351
],
352+
suffix="quant_nodes",
291353
)
292354

293355
remove_quant_nodes_stage = (
@@ -302,12 +364,12 @@ def __init__(
302364
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
303365
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
304366
],
367+
suffix="quant_nodes",
305368
)
306369

307370
if run_on_fvp:
308-
self.add_stage(-1, self.tester.serialize)
371+
self.add_stage(self.tester.serialize)
309372
self.add_stage(
310-
-1,
311373
self.tester.run_method_and_compare_outputs,
312374
qtol=1,
313375
inputs=self.test_data,
@@ -335,14 +397,15 @@ def __init__(
335397
compile_spec,
336398
use_to_edge_transform_and_lower,
337399
)
338-
self.add_stage(0, self.tester.quantize)
400+
self.add_stage(self.tester.quantize, pos=0)
339401
self.add_stage_after(
340402
"quantize",
341403
self.tester.check,
342404
[
343405
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
344406
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
345407
],
408+
suffix="quant_nodes",
346409
)
347410

348411
remove_quant_nodes_stage = (
@@ -357,12 +420,12 @@ def __init__(
357420
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
358421
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
359422
],
423+
suffix="quant_nodes",
360424
)
361425

362426
if run_on_fvp:
363-
self.add_stage(-1, self.tester.serialize)
427+
self.add_stage(self.tester.serialize)
364428
self.add_stage(
365-
-1,
366429
self.tester.run_method_and_compare_outputs,
367430
qtol=1,
368431
inputs=self.test_data,

0 commit comments

Comments
 (0)