Skip to content

Commit c17e2c6

Browse files
committed
Refactor XNNPACK tester to extract base classes (#11596)
Summary: Refactor the XNNPACK tester to split out reusable base components from XNNPACK-specific parts. I've relocated the base classes to backends/test/harness. I've kept the tester structure pretty much unchanged, except for replacing stage names with an enum. It looks like Arm tests are directly importing for XNNPACK's tester currently. Ideally, we'll want to refactor to have their own stage implementations, but I've left that as a follow-up to minimize changes for the initial refactor. Pull Request resolved: #11596 Test Plan: CI Rollback Plan: Differential Revision: D76547310 Pulled By: GregoryComer
1 parent 67b6009 commit c17e2c6

28 files changed

+1150
-795
lines changed

backends/__init__.py

Whitespace-only changes.

backends/arm/test/misc/test_lifted_tensor.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
TosaPipelineBI,
1313
TosaPipelineMI,
1414
)
15-
from executorch.backends.xnnpack.test.tester import ToEdge
15+
from executorch.backends.test.harness.stages import StageType
1616

1717

1818
input_t1 = Tuple[torch.Tensor]
@@ -72,9 +72,8 @@ def test_partition_lifted_tensor_tosa_MI(test_data: input_t1):
7272
use_to_edge_transform_and_lower=False,
7373
)
7474
pipeline.run()
75-
to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
7675
signature = (
77-
pipeline.tester.stages[to_edge_stage_name]
76+
pipeline.tester.stages[StageType.TO_EDGE]
7877
.artifact.exported_program()
7978
.graph_signature
8079
)
@@ -94,9 +93,8 @@ def test_partition_lifted_tensor_tosa_BI(test_data: input_t1):
9493
use_to_edge_transform_and_lower=False,
9594
)
9695
pipeline.run()
97-
to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
9896
signature = (
99-
pipeline.tester.stages[to_edge_stage_name]
97+
pipeline.tester.stages[StageType.TO_EDGE]
10098
.artifact.exported_program()
10199
.graph_signature
102100
)

backends/arm/test/passes/test_cast_int64_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from executorch.backends.arm.test import common
1212
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1313

14+
from executorch.backends.test.harness.stages import StageType
15+
1416
input_t = Tuple[torch.Tensor] # Input x
1517

1618

@@ -40,6 +42,8 @@ def test_int64_model(test_data: input_t):
4042
)
4143
pipeline.run()
4244

43-
exported_program = pipeline.tester.get_artifact("RunPasses").exported_program()
45+
exported_program = pipeline.tester.get_artifact(
46+
StageType.RUN_PASSES
47+
).exported_program()
4448
for state in exported_program.state_dict:
4549
assert exported_program.state_dict[state].dtype == torch.int32

backends/arm/test/quantizer/test_generic_annotater.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
from executorch.backends.arm.quantizer import is_annotated
1111
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI
12+
from executorch.backends.test.harness.stages import StageType
1213

1314
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1415

@@ -36,7 +37,7 @@ def check_annotation(model):
3637
pipeline.pop_stage("run_method_and_compare_outputs")
3738
pipeline.run()
3839

39-
artifact = pipeline.tester.get_artifact("Quantize")
40+
artifact = pipeline.tester.get_artifact(StageType.QUANTIZE)
4041

4142
partitions = get_source_partitions(artifact.graph, [model.op])
4243
partitions = list(itertools.chain.from_iterable(partitions.values()))

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
get_output_quantization_params,
1515
)
1616

17-
from executorch.backends.xnnpack.test.tester.tester import Export, Quantize
17+
from executorch.backends.test.harness.stages import StageType
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -238,8 +238,8 @@ def dump_error_output(
238238
if path_to_tosa_files is None:
239239
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
240240

241-
export_stage = tester.stages.get(tester.stage_name(Export), None)
242-
quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
241+
export_stage = tester.stages.get(StageType.EXPORT, None)
242+
quantize_stage = tester.stages.get(StageType.QUANTIZE, None)
243243
if export_stage is not None and quantize_stage is not None:
244244
output_nodes = get_output_nodes(export_stage.artifact)
245245
qp_input = get_input_quantization_params(export_stage.artifact)

backends/arm/test/tester/arm_tester.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
6262
from executorch.backends.arm.tosa_specification import TosaSpecification
6363

64+
from executorch.backends.test.harness.stages import Stage, StageType
6465
from executorch.backends.xnnpack.test.tester import Tester
6566
from executorch.devtools.backend_debug import get_delegation_info
6667

@@ -259,10 +260,13 @@ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram:
259260
super().run(artifact, inputs)
260261

261262

262-
class InitialModel(tester.Stage):
263+
class InitialModel(Stage):
263264
def __init__(self, model: torch.nn.Module):
264265
self.model = model
265266

267+
def stage_type(self) -> StageType:
268+
return StageType.INITIAL_MODEL
269+
266270
def run(self, artifact, inputs=None) -> None:
267271
pass
268272

@@ -305,13 +309,13 @@ def __init__(
305309
self.constant_methods = constant_methods
306310
self.compile_spec = compile_spec
307311
super().__init__(model, example_inputs, dynamic_shapes)
308-
self.pipeline[self.stage_name(InitialModel)] = [
309-
self.stage_name(tester.Quantize),
310-
self.stage_name(tester.Export),
312+
self.pipeline[StageType.INITIAL_MODEL] = [
313+
StageType.QUANTIZE,
314+
StageType.EXPORT,
311315
]
312316

313317
# Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
314-
self.stages[self.stage_name(InitialModel)] = None
318+
self.stages[StageType.INITIAL_MODEL] = None
315319
self._run_stage(InitialModel(self.original_module))
316320

317321
def quantize(self, quantize_stage: Optional[tester.Quantize] = None):
@@ -410,7 +414,7 @@ def serialize(
410414
return super().serialize(serialize_stage)
411415

412416
def is_quantized(self) -> bool:
413-
return self.stages[self.stage_name(tester.Quantize)] is not None
417+
return self.stages[StageType.QUANTIZE] is not None
414418

415419
def run_method_and_compare_outputs(
416420
self,
@@ -439,18 +443,16 @@ def run_method_and_compare_outputs(
439443
"""
440444

441445
if not run_eager_mode:
442-
edge_stage = self.stages[self.stage_name(tester.ToEdge)]
446+
edge_stage = self.stages[StageType.TO_EDGE]
443447
if edge_stage is None:
444-
edge_stage = self.stages[
445-
self.stage_name(tester.ToEdgeTransformAndLower)
446-
]
448+
edge_stage = self.stages[StageType.TO_EDGE_TRANSFORM_AND_LOWER]
447449
assert (
448450
edge_stage is not None
449451
), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
450452
else:
451453
# Run models in eager mode. We do this when we want to check that the passes
452454
# are numerically accurate and the exported graph is correct.
453-
export_stage = self.stages[self.stage_name(tester.Export)]
455+
export_stage = self.stages[StageType.EXPORT]
454456
assert (
455457
export_stage is not None
456458
), "To compare outputs in eager mode, the model must be at Export stage"
@@ -460,11 +462,11 @@ def run_method_and_compare_outputs(
460462
is_quantized = self.is_quantized()
461463

462464
if is_quantized:
463-
reference_stage = self.stages[self.stage_name(tester.Quantize)]
465+
reference_stage = self.stages[StageType.QUANTIZE]
464466
else:
465-
reference_stage = self.stages[self.stage_name(InitialModel)]
467+
reference_stage = self.stages[StageType.INITIAL_MODEL]
466468

467-
exported_program = self.stages[self.stage_name(tester.Export)].artifact
469+
exported_program = self.stages[StageType.EXPORT].artifact
468470
output_nodes = get_output_nodes(exported_program)
469471

470472
output_qparams = get_output_quantization_params(output_nodes)
@@ -474,7 +476,7 @@ def run_method_and_compare_outputs(
474476
quantization_scales.append(getattr(output_qparams[node], "scale", None))
475477

476478
logger.info(
477-
f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'"
479+
f"Comparing Stage '{test_stage.stage_type()}' with Stage '{reference_stage.stage_type()}'"
478480
)
479481

480482
# Loop inputs and compare reference stage with the compared stage.
@@ -525,14 +527,12 @@ def get_graph(self, stage: str | None = None) -> Graph:
525527
stage = self.cur
526528
artifact = self.get_artifact(stage)
527529
if (
528-
self.cur == self.stage_name(tester.ToEdge)
529-
or self.cur == self.stage_name(Partition)
530-
or self.cur == self.stage_name(ToEdgeTransformAndLower)
530+
self.cur == StageType.TO_EDGE
531+
or self.cur == StageType.PARTITION
532+
or self.cur == StageType.TO_EDGE_TRANSFORM_AND_LOWER
531533
):
532534
graph = artifact.exported_program().graph
533-
elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name(
534-
tester.Quantize
535-
):
535+
elif self.cur == StageType.EXPORT or self.cur == StageType.QUANTIZE:
536536
graph = artifact.graph
537537
else:
538538
raise RuntimeError(
@@ -553,13 +553,13 @@ def dump_operator_distribution(
553553
Returns self for daisy-chaining.
554554
"""
555555
line = "#" * 10
556-
to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n"
556+
to_print = f"{line} {self.cur} Operator Distribution {line}\n"
557557

558558
if (
559559
self.cur
560560
in (
561-
self.stage_name(tester.Partition),
562-
self.stage_name(ToEdgeTransformAndLower),
561+
StageType.PARTITION,
562+
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
563563
)
564564
and print_table
565565
):
@@ -599,9 +599,7 @@ def dump_dtype_distribution(
599599
"""
600600

601601
line = "#" * 10
602-
to_print = (
603-
f"{line} {self.cur.capitalize()} Placeholder Dtype Distribution {line}\n"
604-
)
602+
to_print = f"{line} {self.cur} Placeholder Dtype Distribution {line}\n"
605603

606604
graph = self.get_graph(self.cur)
607605
tosa_spec = get_tosa_spec(self.compile_spec)
@@ -650,7 +648,7 @@ def run_transform_for_annotation_pipeline(
650648
stage = self.cur
651649
# We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
652650
artifact = self.get_artifact(stage)
653-
if self.cur == self.stage_name(tester.Export):
651+
if self.cur == StageType.EXPORT:
654652
new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline( # type: ignore[arg-type]
655653
graph_module=artifact.graph_module
656654
)

backends/test/harness/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "tester",
7+
srcs = [
8+
"__init__.py",
9+
"tester.py",
10+
] + native.glob(["stages/*.py"]),
11+
visibility = [
12+
"//executorch/...",
13+
"@EXECUTORCH_CLIENTS",
14+
],
15+
deps = [
16+
"//executorch/exir:graph_module",
17+
],
18+
)

backends/test/harness/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .tester import Tester
2+
3+
__all__ = ["Tester"]
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from .export import Export
2+
from .partition import Partition
3+
from .quantize import Quantize
4+
from .run_passes import RunPasses
5+
from .serialize import Serialize
6+
from .stage import Stage, StageType
7+
from .to_edge import ToEdge
8+
from .to_edge_transform_and_lower import ToEdgeTransformAndLower
9+
from .to_executorch import ToExecutorch
10+
11+
__all__ = [
12+
"Export",
13+
"Partition",
14+
"Quantize",
15+
"RunPasses",
16+
"Serialize",
17+
"Stage",
18+
"StageType",
19+
"ToEdge",
20+
"ToEdgeTransformAndLower",
21+
"ToExecutorch",
22+
]
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Any, Optional, Tuple
2+
3+
import torch
4+
5+
from executorch.backends.test.harness.stages.stage import Stage, StageType
6+
from torch.export import export, ExportedProgram
7+
8+
9+
class Export(Stage):
10+
def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None):
11+
self.exported_program = None
12+
self.dynamic_shapes = dynamic_shapes
13+
14+
def stage_type(self) -> StageType:
15+
return StageType.EXPORT
16+
17+
def run(
18+
self,
19+
artifact: torch.nn.Module,
20+
inputs: Tuple[torch.Tensor],
21+
) -> None:
22+
self.exported_program = export(
23+
artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
24+
)
25+
26+
@property
27+
def artifact(self) -> ExportedProgram:
28+
return self.exported_program
29+
30+
@property
31+
def graph_module(self) -> str:
32+
return self.exported_program.graph_module
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from executorch.backends.test.harness.stages.stage import Stage, StageType
2+
from executorch.exir import EdgeProgramManager
3+
from executorch.exir.backend.backend_api import validation_disabled
4+
from executorch.exir.backend.partitioner import Partitioner
5+
6+
7+
class Partition(Stage):
8+
def __init__(self, partitioner: Partitioner):
9+
self.partitioner = partitioner
10+
self.delegate_module = None
11+
12+
def stage_type(self) -> StageType:
13+
return StageType.PARTITION
14+
15+
def run(self, artifact: EdgeProgramManager, inputs=None):
16+
with validation_disabled():
17+
self.delegate_module = artifact
18+
self.delegate_module = self.delegate_module.to_backend(self.partitioner)
19+
20+
@property
21+
def artifact(self) -> EdgeProgramManager:
22+
return self.delegate_module
23+
24+
@property
25+
def graph_module(self) -> str:
26+
return self.delegate_module.exported_program().graph_module

0 commit comments

Comments
 (0)