Skip to content

Commit 39c729a

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
(WIP) Refactor XNNPACK tester to extract base class (#11596)
Summary: Refactor the XNNPACK tester to split out reusable base components from XNNPACK-specific parts. I've relocated the base classes to backends/test/harness. I've kept the tester structure pretty much unchanged, except for replacing stage names with an enum. It looks like Arm tests are directly importing for XNNPACK's tester currently. Ideally, we'll want to refactor to have their own stage implementations, but I've left that as a follow-up to minimize changes for the initial refactor. Pull Request resolved: #11596 Test Plan: CI Rollback Plan: Differential Revision: D76547310 Pulled By: GregoryComer
1 parent 057558f commit 39c729a

23 files changed

+1061
-792
lines changed

backends/arm/test/misc/test_lifted_tensor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TosaPipelineBI,
1313
TosaPipelineMI,
1414
)
15+
from executorch.backends.test.harness.stages import StageType
1516
from executorch.backends.xnnpack.test.tester import ToEdge
1617

1718

@@ -72,9 +73,8 @@ def test_partition_lifted_tensor_tosa_MI(test_data: input_t1):
7273
use_to_edge_transform_and_lower=False,
7374
)
7475
pipeline.run()
75-
to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
7676
signature = (
77-
pipeline.tester.stages[to_edge_stage_name]
77+
pipeline.tester.stages[StageType.TO_EDGE]
7878
.artifact.exported_program()
7979
.graph_signature
8080
)
@@ -94,9 +94,8 @@ def test_partition_lifted_tensor_tosa_BI(test_data: input_t1):
9494
use_to_edge_transform_and_lower=False,
9595
)
9696
pipeline.run()
97-
to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
9897
signature = (
99-
pipeline.tester.stages[to_edge_stage_name]
98+
pipeline.tester.stages[StageType.TO_EDGE]
10099
.artifact.exported_program()
101100
.graph_signature
102101
)

backends/arm/test/passes/test_cast_int64_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from executorch.backends.arm.test import common
1212
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1313

14+
from executorch.backends.test.harness.stages import StageType
15+
1416
input_t = Tuple[torch.Tensor] # Input x
1517

1618

@@ -40,6 +42,6 @@ def test_int64_model(test_data: input_t):
4042
)
4143
pipeline.run()
4244

43-
exported_program = pipeline.tester.get_artifact("RunPasses").exported_program()
45+
exported_program = pipeline.tester.get_artifact(StageType.RUN_PASSES).exported_program()
4446
for state in exported_program.state_dict:
4547
assert exported_program.state_dict[state].dtype == torch.int32

backends/arm/test/quantizer/test_generic_annotater.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
from executorch.backends.arm.quantizer import is_annotated
1111
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI
12+
from executorch.backends.test.harness.stages import StageType
1213

1314
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1415

@@ -36,7 +37,7 @@ def check_annotation(model):
3637
pipeline.pop_stage("run_method_and_compare_outputs")
3738
pipeline.run()
3839

39-
artifact = pipeline.tester.get_artifact("Quantize")
40+
artifact = pipeline.tester.get_artifact(StageType.QUANTIZE)
4041

4142
partitions = get_source_partitions(artifact.graph, [model.op])
4243
partitions = list(itertools.chain.from_iterable(partitions.values()))

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_output_quantization_params,
1515
)
1616

17+
from executorch.backends.test.harness.stages import StageType
1718
from executorch.backends.xnnpack.test.tester.tester import Export, Quantize
1819

1920
logger = logging.getLogger(__name__)
@@ -238,8 +239,8 @@ def dump_error_output(
238239
if path_to_tosa_files is None:
239240
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
240241

241-
export_stage = tester.stages.get(tester.stage_name(Export), None)
242-
quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
242+
export_stage = tester.stages.get(StageType.EXPORT, None)
243+
quantize_stage = tester.stages.get(StageType.QUANTIZE, None)
243244
if export_stage is not None and quantize_stage is not None:
244245
output_nodes = get_output_nodes(export_stage.artifact)
245246
qp_input = get_input_quantization_params(export_stage.artifact)

backends/arm/test/tester/arm_tester.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
6262
from executorch.backends.arm.tosa_specification import TosaSpecification
6363

64+
from executorch.backends.test.harness.stages import Stage, StageType
6465
from executorch.backends.xnnpack.test.tester import Tester
6566
from executorch.devtools.backend_debug import get_delegation_info
6667

@@ -259,9 +260,12 @@ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram:
259260
super().run(artifact, inputs)
260261

261262

262-
class InitialModel(tester.Stage):
263+
class InitialModel(Stage):
263264
def __init__(self, model: torch.nn.Module):
264265
self.model = model
266+
267+
def stage_type(self) -> StageType:
268+
return StageType.INITIAL_MODEL
265269

266270
def run(self, artifact, inputs=None) -> None:
267271
pass
@@ -305,13 +309,13 @@ def __init__(
305309
self.constant_methods = constant_methods
306310
self.compile_spec = compile_spec
307311
super().__init__(model, example_inputs, dynamic_shapes)
308-
self.pipeline[self.stage_name(InitialModel)] = [
309-
self.stage_name(tester.Quantize),
310-
self.stage_name(tester.Export),
312+
self.pipeline[StageType.INITIAL_MODEL] = [
313+
StageType.QUANTIZE,
314+
StageType.EXPORT,
311315
]
312316

313317
# Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
314-
self.stages[self.stage_name(InitialModel)] = None
318+
self.stages[StageType.INITIAL_MODEL] = None
315319
self._run_stage(InitialModel(self.original_module))
316320

317321
def quantize(self, quantize_stage: Optional[tester.Quantize] = None):
@@ -410,7 +414,7 @@ def serialize(
410414
return super().serialize(serialize_stage)
411415

412416
def is_quantized(self) -> bool:
413-
return self.stages[self.stage_name(tester.Quantize)] is not None
417+
return self.stages[StageType.QUANTIZE] is not None
414418

415419
def run_method_and_compare_outputs(
416420
self,
@@ -439,18 +443,16 @@ def run_method_and_compare_outputs(
439443
"""
440444

441445
if not run_eager_mode:
442-
edge_stage = self.stages[self.stage_name(tester.ToEdge)]
446+
edge_stage = self.stages[StageType.TO_EDGE]
443447
if edge_stage is None:
444-
edge_stage = self.stages[
445-
self.stage_name(tester.ToEdgeTransformAndLower)
446-
]
448+
edge_stage = self.stages[StageType.TO_EDGE_TRANSFORM_AND_LOWER]
447449
assert (
448450
edge_stage is not None
449451
), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
450452
else:
451453
# Run models in eager mode. We do this when we want to check that the passes
452454
# are numerically accurate and the exported graph is correct.
453-
export_stage = self.stages[self.stage_name(tester.Export)]
455+
export_stage = self.stages[StageType.EXPORT]
454456
assert (
455457
export_stage is not None
456458
), "To compare outputs in eager mode, the model must be at Export stage"
@@ -460,11 +462,11 @@ def run_method_and_compare_outputs(
460462
is_quantized = self.is_quantized()
461463

462464
if is_quantized:
463-
reference_stage = self.stages[self.stage_name(tester.Quantize)]
465+
reference_stage = self.stages[StageType.QUANTIZE]
464466
else:
465-
reference_stage = self.stages[self.stage_name(InitialModel)]
467+
reference_stage = self.stages[StageType.INITIAL_MODEL]
466468

467-
exported_program = self.stages[self.stage_name(tester.Export)].artifact
469+
exported_program = self.stages[StageType.EXPORT].artifact
468470
output_nodes = get_output_nodes(exported_program)
469471

470472
output_qparams = get_output_quantization_params(output_nodes)
@@ -474,7 +476,7 @@ def run_method_and_compare_outputs(
474476
quantization_scales.append(getattr(output_qparams[node], "scale", None))
475477

476478
logger.info(
477-
f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'"
479+
f"Comparing Stage '{test_stage.stage_type()}' with Stage '{reference_stage.stage_type()}'"
478480
)
479481

480482
# Loop inputs and compare reference stage with the compared stage.
@@ -525,14 +527,12 @@ def get_graph(self, stage: str | None = None) -> Graph:
525527
stage = self.cur
526528
artifact = self.get_artifact(stage)
527529
if (
528-
self.cur == self.stage_name(tester.ToEdge)
529-
or self.cur == self.stage_name(Partition)
530-
or self.cur == self.stage_name(ToEdgeTransformAndLower)
530+
self.cur == StageType.TO_EDGE
531+
or self.cur == StageType.PARTITION
532+
or self.cur == StageType.TO_EDGE_TRANSFORM_AND_LOWER
531533
):
532534
graph = artifact.exported_program().graph
533-
elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name(
534-
tester.Quantize
535-
):
535+
elif self.cur == StageType.EXPORT or self.cur == StageType.QUANTIZE:
536536
graph = artifact.graph
537537
else:
538538
raise RuntimeError(
@@ -553,13 +553,13 @@ def dump_operator_distribution(
553553
Returns self for daisy-chaining.
554554
"""
555555
line = "#" * 10
556-
to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n"
556+
to_print = f"{line} {self.cur} Operator Distribution {line}\n"
557557

558558
if (
559559
self.cur
560560
in (
561-
self.stage_name(tester.Partition),
562-
self.stage_name(ToEdgeTransformAndLower),
561+
StageType.PARTITION,
562+
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
563563
)
564564
and print_table
565565
):
@@ -600,7 +600,7 @@ def dump_dtype_distribution(
600600

601601
line = "#" * 10
602602
to_print = (
603-
f"{line} {self.cur.capitalize()} Placeholder Dtype Distribution {line}\n"
603+
f"{line} {self.cur} Placeholder Dtype Distribution {line}\n"
604604
)
605605

606606
graph = self.get_graph(self.cur)
@@ -650,7 +650,7 @@ def run_transform_for_annotation_pipeline(
650650
stage = self.cur
651651
# We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
652652
artifact = self.get_artifact(stage)
653-
if self.cur == self.stage_name(tester.Export):
653+
if self.cur == StageType.EXPORT:
654654
new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline( # type: ignore[arg-type]
655655
graph_module=artifact.graph_module
656656
)

backends/test/harness/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "tester",
7+
srcs = [
8+
"__init__.py",
9+
"tester.py",
10+
] + native.glob(["stages/*.py"]),
11+
visibility = [
12+
"//executorch/...",
13+
"@EXECUTORCH_CLIENTS",
14+
],
15+
deps = [
16+
"//executorch/exir:graph_module",
17+
],
18+
)

backends/test/harness/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .tester import Tester
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .export import Export
2+
from .partition import Partition
3+
from .quantize import Quantize
4+
from .run_passes import RunPasses
5+
from .serialize import Serialize
6+
from .stage import Stage, StageType
7+
from .to_edge import ToEdge
8+
from .to_edge_transform_and_lower import ToEdgeTransformAndLower
9+
from .to_executorch import ToExecutorch
10+
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Any, Optional, Sequence, Tuple
2+
3+
import torch
4+
5+
from executorch.backends.test.harness.stages.stage import Stage, StageType
6+
from torch.export import export, ExportedProgram
7+
8+
class Export(Stage):
9+
def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None):
10+
self.exported_program = None
11+
self.dynamic_shapes = dynamic_shapes
12+
13+
def stage_type(self) -> StageType:
14+
return StageType.EXPORT
15+
16+
def run(
17+
self,
18+
artifact: torch.nn.Module,
19+
inputs: Tuple[torch.Tensor],
20+
) -> None:
21+
self.exported_program = export(
22+
artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
23+
)
24+
25+
@property
26+
def artifact(self) -> ExportedProgram:
27+
return self.exported_program
28+
29+
@property
30+
def graph_module(self) -> str:
31+
return self.exported_program.graph_module
32+
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from executorch.backends.test.harness.stages.stage import Stage, StageType
2+
from executorch.exir import (
3+
EdgeProgramManager,
4+
)
5+
from executorch.exir.backend.backend_api import validation_disabled
6+
from executorch.exir.backend.partitioner import Partitioner
7+
8+
class Partition(Stage):
9+
def __init__(self, partitioner: Partitioner):
10+
self.partitioner = partitioner
11+
self.delegate_module = None
12+
13+
def stage_type(self) -> StageType:
14+
return StageType.PARTITION
15+
16+
def run(self, artifact: EdgeProgramManager, inputs=None):
17+
with validation_disabled():
18+
self.delegate_module = artifact
19+
self.delegate_module = self.delegate_module.to_backend(self.partitioner)
20+
21+
@property
22+
def artifact(self) -> EdgeProgramManager:
23+
return self.delegate_module
24+
25+
@property
26+
def graph_module(self) -> str:
27+
return self.delegate_module.exported_program().graph_module

0 commit comments

Comments
 (0)