Skip to content

Commit b3f192c

Browse files
committed
Refactor XNNPACK tester to extract base classes (#11596)
Summary: Refactor the XNNPACK tester to split out reusable base components from XNNPACK-specific parts. I've relocated the base classes to backends/test/harness. I've kept the tester structure pretty much unchanged, except for replacing stage names with an enum. It looks like Arm tests are directly importing for XNNPACK's tester currently. Ideally, we'll want to refactor to have their own stage implementations, but I've left that as a follow-up to minimize changes for the initial refactor. Pull Request resolved: #11596 Test Plan: CI Rollback Plan: Differential Revision: D76547310 Pulled By: GregoryComer
1 parent 7503bb3 commit b3f192c

26 files changed

+1076
-795
lines changed

backends/__init__.py

Whitespace-only changes.

backends/arm/test/misc/test_lifted_tensor.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
TosaPipelineBI,
1313
TosaPipelineMI,
1414
)
15-
from executorch.backends.xnnpack.test.tester import ToEdge
15+
from executorch.backends.test.harness.stages import StageType
1616

1717

1818
input_t1 = Tuple[torch.Tensor]
@@ -72,9 +72,8 @@ def test_partition_lifted_tensor_tosa_MI(test_data: input_t1):
7272
use_to_edge_transform_and_lower=False,
7373
)
7474
pipeline.run()
75-
to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
7675
signature = (
77-
pipeline.tester.stages[to_edge_stage_name]
76+
pipeline.tester.stages[StageType.TO_EDGE]
7877
.artifact.exported_program()
7978
.graph_signature
8079
)
@@ -94,9 +93,8 @@ def test_partition_lifted_tensor_tosa_BI(test_data: input_t1):
9493
use_to_edge_transform_and_lower=False,
9594
)
9695
pipeline.run()
97-
to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
9896
signature = (
99-
pipeline.tester.stages[to_edge_stage_name]
97+
pipeline.tester.stages[StageType.TO_EDGE]
10098
.artifact.exported_program()
10199
.graph_signature
102100
)

backends/arm/test/passes/test_cast_int64_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from executorch.backends.arm.test import common
1212
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1313

14+
from executorch.backends.test.harness.stages import StageType
15+
1416
input_t = Tuple[torch.Tensor] # Input x
1517

1618

@@ -40,6 +42,8 @@ def test_int64_model(test_data: input_t):
4042
)
4143
pipeline.run()
4244

43-
exported_program = pipeline.tester.get_artifact("RunPasses").exported_program()
45+
exported_program = pipeline.tester.get_artifact(
46+
StageType.RUN_PASSES
47+
).exported_program()
4448
for state in exported_program.state_dict:
4549
assert exported_program.state_dict[state].dtype == torch.int32

backends/arm/test/quantizer/test_generic_annotater.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
from executorch.backends.arm.quantizer import is_annotated
1111
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI
12+
from executorch.backends.test.harness.stages import StageType
1213

1314
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1415

@@ -36,7 +37,7 @@ def check_annotation(model):
3637
pipeline.pop_stage("run_method_and_compare_outputs")
3738
pipeline.run()
3839

39-
artifact = pipeline.tester.get_artifact("Quantize")
40+
artifact = pipeline.tester.get_artifact(StageType.QUANTIZE)
4041

4142
partitions = get_source_partitions(artifact.graph, [model.op])
4243
partitions = list(itertools.chain.from_iterable(partitions.values()))

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
get_output_quantization_params,
1515
)
1616

17-
from executorch.backends.xnnpack.test.tester.tester import Export, Quantize
17+
from executorch.backends.test.harness.stages import StageType
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -238,8 +238,8 @@ def dump_error_output(
238238
if path_to_tosa_files is None:
239239
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
240240

241-
export_stage = tester.stages.get(tester.stage_name(Export), None)
242-
quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
241+
export_stage = tester.stages.get(StageType.EXPORT, None)
242+
quantize_stage = tester.stages.get(StageType.QUANTIZE, None)
243243
if export_stage is not None and quantize_stage is not None:
244244
output_nodes = get_output_nodes(export_stage.artifact)
245245
qp_input = get_input_quantization_params(export_stage.artifact)

backends/arm/test/tester/arm_tester.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
6262
from executorch.backends.arm.tosa_specification import TosaSpecification
6363

64+
from executorch.backends.test.harness.stages import Stage, StageType
6465
from executorch.backends.xnnpack.test.tester import Tester
6566
from executorch.devtools.backend_debug import get_delegation_info
6667

@@ -259,10 +260,13 @@ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram:
259260
super().run(artifact, inputs)
260261

261262

262-
class InitialModel(tester.Stage):
263+
class InitialModel(Stage):
263264
def __init__(self, model: torch.nn.Module):
264265
self.model = model
265266

267+
def stage_type(self) -> StageType:
268+
return StageType.INITIAL_MODEL
269+
266270
def run(self, artifact, inputs=None) -> None:
267271
pass
268272

@@ -305,13 +309,13 @@ def __init__(
305309
self.constant_methods = constant_methods
306310
self.compile_spec = compile_spec
307311
super().__init__(model, example_inputs, dynamic_shapes)
308-
self.pipeline[self.stage_name(InitialModel)] = [
309-
self.stage_name(tester.Quantize),
310-
self.stage_name(tester.Export),
312+
self.pipeline[StageType.INITIAL_MODEL] = [
313+
StageType.QUANTIZE,
314+
StageType.EXPORT,
311315
]
312316

313317
# Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
314-
self.stages[self.stage_name(InitialModel)] = None
318+
self.stages[StageType.INITIAL_MODEL] = None
315319
self._run_stage(InitialModel(self.original_module))
316320

317321
def quantize(
@@ -413,7 +417,7 @@ def serialize(
413417
return super().serialize(serialize_stage)
414418

415419
def is_quantized(self) -> bool:
416-
return self.stages[self.stage_name(tester.Quantize)] is not None
420+
return self.stages[StageType.QUANTIZE] is not None
417421

418422
def run_method_and_compare_outputs(
419423
self,
@@ -442,18 +446,16 @@ def run_method_and_compare_outputs(
442446
"""
443447

444448
if not run_eager_mode:
445-
edge_stage = self.stages[self.stage_name(tester.ToEdge)]
449+
edge_stage = self.stages[StageType.TO_EDGE]
446450
if edge_stage is None:
447-
edge_stage = self.stages[
448-
self.stage_name(tester.ToEdgeTransformAndLower)
449-
]
451+
edge_stage = self.stages[StageType.TO_EDGE_TRANSFORM_AND_LOWER]
450452
assert (
451453
edge_stage is not None
452454
), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
453455
else:
454456
# Run models in eager mode. We do this when we want to check that the passes
455457
# are numerically accurate and the exported graph is correct.
456-
export_stage = self.stages[self.stage_name(tester.Export)]
458+
export_stage = self.stages[StageType.EXPORT]
457459
assert (
458460
export_stage is not None
459461
), "To compare outputs in eager mode, the model must be at Export stage"
@@ -463,11 +465,11 @@ def run_method_and_compare_outputs(
463465
is_quantized = self.is_quantized()
464466

465467
if is_quantized:
466-
reference_stage = self.stages[self.stage_name(tester.Quantize)]
468+
reference_stage = self.stages[StageType.QUANTIZE]
467469
else:
468-
reference_stage = self.stages[self.stage_name(InitialModel)]
470+
reference_stage = self.stages[StageType.INITIAL_MODEL]
469471

470-
exported_program = self.stages[self.stage_name(tester.Export)].artifact
472+
exported_program = self.stages[StageType.EXPORT].artifact
471473
output_nodes = get_output_nodes(exported_program)
472474

473475
output_qparams = get_output_quantization_params(output_nodes)
@@ -477,7 +479,7 @@ def run_method_and_compare_outputs(
477479
quantization_scales.append(getattr(output_qparams[node], "scale", None))
478480

479481
logger.info(
480-
f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'"
482+
f"Comparing Stage '{test_stage.stage_type()}' with Stage '{reference_stage.stage_type()}'"
481483
)
482484

483485
# Loop inputs and compare reference stage with the compared stage.
@@ -528,14 +530,12 @@ def get_graph(self, stage: str | None = None) -> Graph:
528530
stage = self.cur
529531
artifact = self.get_artifact(stage)
530532
if (
531-
self.cur == self.stage_name(tester.ToEdge)
532-
or self.cur == self.stage_name(Partition)
533-
or self.cur == self.stage_name(ToEdgeTransformAndLower)
533+
self.cur == StageType.TO_EDGE
534+
or self.cur == StageType.PARTITION
535+
or self.cur == StageType.TO_EDGE_TRANSFORM_AND_LOWER
534536
):
535537
graph = artifact.exported_program().graph
536-
elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name(
537-
tester.Quantize
538-
):
538+
elif self.cur == StageType.EXPORT or self.cur == StageType.QUANTIZE:
539539
graph = artifact.graph
540540
else:
541541
raise RuntimeError(
@@ -556,13 +556,13 @@ def dump_operator_distribution(
556556
Returns self for daisy-chaining.
557557
"""
558558
line = "#" * 10
559-
to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n"
559+
to_print = f"{line} {self.cur} Operator Distribution {line}\n"
560560

561561
if (
562562
self.cur
563563
in (
564-
self.stage_name(tester.Partition),
565-
self.stage_name(ToEdgeTransformAndLower),
564+
StageType.PARTITION,
565+
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
566566
)
567567
and print_table
568568
):
@@ -602,9 +602,7 @@ def dump_dtype_distribution(
602602
"""
603603

604604
line = "#" * 10
605-
to_print = (
606-
f"{line} {self.cur.capitalize()} Placeholder Dtype Distribution {line}\n"
607-
)
605+
to_print = f"{line} {self.cur} Placeholder Dtype Distribution {line}\n"
608606

609607
graph = self.get_graph(self.cur)
610608
tosa_spec = get_tosa_spec(self.compile_spec)
@@ -653,7 +651,7 @@ def run_transform_for_annotation_pipeline(
653651
stage = self.cur
654652
# We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
655653
artifact = self.get_artifact(stage)
656-
if self.cur == self.stage_name(tester.Export):
654+
if self.cur == StageType.EXPORT:
657655
new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline( # type: ignore[arg-type]
658656
graph_module=artifact.graph_module
659657
)

backends/test/harness/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "tester",
7+
srcs = [
8+
"__init__.py",
9+
"tester.py",
10+
] + native.glob(["stages/*.py"]),
11+
visibility = [
12+
"//executorch/...",
13+
"@EXECUTORCH_CLIENTS",
14+
],
15+
deps = [
16+
"//executorch/exir:graph_module",
17+
],
18+
)

backends/test/harness/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .tester import Tester
2+
3+
__all__ = ["Tester"]
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from .export import Export
2+
from .partition import Partition
3+
from .quantize import Quantize
4+
from .run_passes import RunPasses
5+
from .serialize import Serialize
6+
from .stage import Stage, StageType
7+
from .to_edge import ToEdge
8+
from .to_edge_transform_and_lower import ToEdgeTransformAndLower
9+
from .to_executorch import ToExecutorch
10+
11+
__all__ = [
12+
"Export",
13+
"Partition",
14+
"Quantize",
15+
"RunPasses",
16+
"Serialize",
17+
"Stage",
18+
"StageType",
19+
"ToEdge",
20+
"ToEdgeTransformAndLower",
21+
"ToExecutorch",
22+
]
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Any, Optional, Tuple
2+
3+
import torch
4+
5+
from executorch.backends.test.harness.stages.stage import Stage, StageType
6+
from torch.export import export, ExportedProgram
7+
8+
9+
class Export(Stage):
10+
def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None):
11+
self.exported_program = None
12+
self.dynamic_shapes = dynamic_shapes
13+
14+
def stage_type(self) -> StageType:
15+
return StageType.EXPORT
16+
17+
def run(
18+
self,
19+
artifact: torch.nn.Module,
20+
inputs: Tuple[torch.Tensor],
21+
) -> None:
22+
self.exported_program = export(
23+
artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
24+
)
25+
26+
@property
27+
def artifact(self) -> ExportedProgram:
28+
return self.exported_program
29+
30+
@property
31+
def graph_module(self) -> str:
32+
return self.exported_program.graph_module
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from executorch.backends.test.harness.stages.stage import Stage, StageType
2+
from executorch.exir import EdgeProgramManager
3+
from executorch.exir.backend.backend_api import validation_disabled
4+
from executorch.exir.backend.partitioner import Partitioner
5+
6+
7+
class Partition(Stage):
8+
def __init__(self, partitioner: Partitioner):
9+
self.partitioner = partitioner
10+
self.delegate_module = None
11+
12+
def stage_type(self) -> StageType:
13+
return StageType.PARTITION
14+
15+
def run(self, artifact: EdgeProgramManager, inputs=None):
16+
with validation_disabled():
17+
self.delegate_module = artifact
18+
self.delegate_module = self.delegate_module.to_backend(self.partitioner)
19+
20+
@property
21+
def artifact(self) -> EdgeProgramManager:
22+
return self.delegate_module
23+
24+
@property
25+
def graph_module(self) -> str:
26+
return self.delegate_module.exported_program().graph_module

0 commit comments

Comments
 (0)