Skip to content

Refactor XNNPACK tester to extract delegate-independent tester classes #11596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added backends/__init__.py
Empty file.
8 changes: 3 additions & 5 deletions backends/arm/test/misc/test_lifted_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
TosaPipelineBI,
TosaPipelineMI,
)
from executorch.backends.xnnpack.test.tester import ToEdge
from executorch.backends.test.harness.stages import StageType


input_t1 = Tuple[torch.Tensor]
Expand Down Expand Up @@ -72,9 +72,8 @@ def test_partition_lifted_tensor_tosa_MI(test_data: input_t1):
use_to_edge_transform_and_lower=False,
)
pipeline.run()
to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
signature = (
pipeline.tester.stages[to_edge_stage_name]
pipeline.tester.stages[StageType.TO_EDGE]
.artifact.exported_program()
.graph_signature
)
Expand All @@ -94,9 +93,8 @@ def test_partition_lifted_tensor_tosa_BI(test_data: input_t1):
use_to_edge_transform_and_lower=False,
)
pipeline.run()
to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
signature = (
pipeline.tester.stages[to_edge_stage_name]
pipeline.tester.stages[StageType.TO_EDGE]
.artifact.exported_program()
.graph_signature
)
Expand Down
6 changes: 5 additions & 1 deletion backends/arm/test/passes/test_cast_int64_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline

from executorch.backends.test.harness.stages import StageType

input_t = Tuple[torch.Tensor] # Input x


Expand Down Expand Up @@ -40,6 +42,8 @@ def test_int64_model(test_data: input_t):
)
pipeline.run()

exported_program = pipeline.tester.get_artifact("RunPasses").exported_program()
exported_program = pipeline.tester.get_artifact(
StageType.RUN_PASSES
).exported_program()
for state in exported_program.state_dict:
assert exported_program.state_dict[state].dtype == torch.int32
3 changes: 2 additions & 1 deletion backends/arm/test/quantizer/test_generic_annotater.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from executorch.backends.arm.quantizer import is_annotated
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI
from executorch.backends.test.harness.stages import StageType

from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

Expand Down Expand Up @@ -36,7 +37,7 @@ def check_annotation(model):
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()

artifact = pipeline.tester.get_artifact("Quantize")
artifact = pipeline.tester.get_artifact(StageType.QUANTIZE)

partitions = get_source_partitions(artifact.graph, [model.op])
partitions = list(itertools.chain.from_iterable(partitions.values()))
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/test/tester/analyze_output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
get_output_quantization_params,
)

from executorch.backends.xnnpack.test.tester.tester import Export, Quantize
from executorch.backends.test.harness.stages import StageType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -238,8 +238,8 @@ def dump_error_output(
if path_to_tosa_files is None:
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")

export_stage = tester.stages.get(tester.stage_name(Export), None)
quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
export_stage = tester.stages.get(StageType.EXPORT, None)
quantize_stage = tester.stages.get(StageType.QUANTIZE, None)
if export_stage is not None and quantize_stage is not None:
output_nodes = get_output_nodes(export_stage.artifact)
qp_input = get_input_quantization_params(export_stage.artifact)
Expand Down
54 changes: 26 additions & 28 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
from executorch.backends.arm.tosa_specification import TosaSpecification

from executorch.backends.test.harness.stages import Stage, StageType
from executorch.backends.xnnpack.test.tester import Tester
from executorch.devtools.backend_debug import get_delegation_info

Expand Down Expand Up @@ -259,10 +260,13 @@ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram:
super().run(artifact, inputs)


class InitialModel(tester.Stage):
class InitialModel(Stage):
def __init__(self, model: torch.nn.Module):
self.model = model

def stage_type(self) -> StageType:
return StageType.INITIAL_MODEL

def run(self, artifact, inputs=None) -> None:
pass

Expand Down Expand Up @@ -305,13 +309,13 @@ def __init__(
self.constant_methods = constant_methods
self.compile_spec = compile_spec
super().__init__(model, example_inputs, dynamic_shapes)
self.pipeline[self.stage_name(InitialModel)] = [
self.stage_name(tester.Quantize),
self.stage_name(tester.Export),
self.pipeline[StageType.INITIAL_MODEL] = [
StageType.QUANTIZE,
StageType.EXPORT,
]

# Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
self.stages[self.stage_name(InitialModel)] = None
self.stages[StageType.INITIAL_MODEL] = None
self._run_stage(InitialModel(self.original_module))

def quantize(
Expand Down Expand Up @@ -413,7 +417,7 @@ def serialize(
return super().serialize(serialize_stage)

def is_quantized(self) -> bool:
return self.stages[self.stage_name(tester.Quantize)] is not None
return self.stages[StageType.QUANTIZE] is not None

def run_method_and_compare_outputs(
self,
Expand Down Expand Up @@ -442,18 +446,16 @@ def run_method_and_compare_outputs(
"""

if not run_eager_mode:
edge_stage = self.stages[self.stage_name(tester.ToEdge)]
edge_stage = self.stages[StageType.TO_EDGE]
if edge_stage is None:
edge_stage = self.stages[
self.stage_name(tester.ToEdgeTransformAndLower)
]
edge_stage = self.stages[StageType.TO_EDGE_TRANSFORM_AND_LOWER]
assert (
edge_stage is not None
), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
else:
# Run models in eager mode. We do this when we want to check that the passes
# are numerically accurate and the exported graph is correct.
export_stage = self.stages[self.stage_name(tester.Export)]
export_stage = self.stages[StageType.EXPORT]
assert (
export_stage is not None
), "To compare outputs in eager mode, the model must be at Export stage"
Expand All @@ -463,11 +465,11 @@ def run_method_and_compare_outputs(
is_quantized = self.is_quantized()

if is_quantized:
reference_stage = self.stages[self.stage_name(tester.Quantize)]
reference_stage = self.stages[StageType.QUANTIZE]
else:
reference_stage = self.stages[self.stage_name(InitialModel)]
reference_stage = self.stages[StageType.INITIAL_MODEL]

exported_program = self.stages[self.stage_name(tester.Export)].artifact
exported_program = self.stages[StageType.EXPORT].artifact
output_nodes = get_output_nodes(exported_program)

output_qparams = get_output_quantization_params(output_nodes)
Expand All @@ -477,7 +479,7 @@ def run_method_and_compare_outputs(
quantization_scales.append(getattr(output_qparams[node], "scale", None))

logger.info(
f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'"
f"Comparing Stage '{test_stage.stage_type()}' with Stage '{reference_stage.stage_type()}'"
)

# Loop inputs and compare reference stage with the compared stage.
Expand Down Expand Up @@ -528,14 +530,12 @@ def get_graph(self, stage: str | None = None) -> Graph:
stage = self.cur
artifact = self.get_artifact(stage)
if (
self.cur == self.stage_name(tester.ToEdge)
or self.cur == self.stage_name(Partition)
or self.cur == self.stage_name(ToEdgeTransformAndLower)
self.cur == StageType.TO_EDGE
or self.cur == StageType.PARTITION
or self.cur == StageType.TO_EDGE_TRANSFORM_AND_LOWER
):
graph = artifact.exported_program().graph
elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name(
tester.Quantize
):
elif self.cur == StageType.EXPORT or self.cur == StageType.QUANTIZE:
graph = artifact.graph
else:
raise RuntimeError(
Expand All @@ -556,13 +556,13 @@ def dump_operator_distribution(
Returns self for daisy-chaining.
"""
line = "#" * 10
to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n"
to_print = f"{line} {self.cur} Operator Distribution {line}\n"

if (
self.cur
in (
self.stage_name(tester.Partition),
self.stage_name(ToEdgeTransformAndLower),
StageType.PARTITION,
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
)
and print_table
):
Expand Down Expand Up @@ -602,9 +602,7 @@ def dump_dtype_distribution(
"""

line = "#" * 10
to_print = (
f"{line} {self.cur.capitalize()} Placeholder Dtype Distribution {line}\n"
)
to_print = f"{line} {self.cur} Placeholder Dtype Distribution {line}\n"

graph = self.get_graph(self.cur)
tosa_spec = get_tosa_spec(self.compile_spec)
Expand Down Expand Up @@ -653,7 +651,7 @@ def run_transform_for_annotation_pipeline(
stage = self.cur
# We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
artifact = self.get_artifact(stage)
if self.cur == self.stage_name(tester.Export):
if self.cur == StageType.EXPORT:
new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline( # type: ignore[arg-type]
graph_module=artifact.graph_module
)
Expand Down
18 changes: 18 additions & 0 deletions backends/test/harness/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "tester",
srcs = [
"__init__.py",
"tester.py",
] + native.glob(["stages/*.py"]),
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//executorch/exir:graph_module",
],
)
3 changes: 3 additions & 0 deletions backends/test/harness/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .tester import Tester

__all__ = ["Tester"]
22 changes: 22 additions & 0 deletions backends/test/harness/stages/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from .export import Export
from .partition import Partition
from .quantize import Quantize
from .run_passes import RunPasses
from .serialize import Serialize
from .stage import Stage, StageType
from .to_edge import ToEdge
from .to_edge_transform_and_lower import ToEdgeTransformAndLower
from .to_executorch import ToExecutorch

__all__ = [
"Export",
"Partition",
"Quantize",
"RunPasses",
"Serialize",
"Stage",
"StageType",
"ToEdge",
"ToEdgeTransformAndLower",
"ToExecutorch",
]
32 changes: 32 additions & 0 deletions backends/test/harness/stages/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Any, Optional, Tuple

import torch

from executorch.backends.test.harness.stages.stage import Stage, StageType
from torch.export import export, ExportedProgram


class Export(Stage):
    """Tester stage that captures an eager ``torch.nn.Module`` via ``torch.export``.

    The resulting ``ExportedProgram`` is kept as the stage artifact so later
    stages (run-passes, to-edge, ...) can consume it.
    """

    def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None):
        # Populated by run(); None until the stage has executed.
        self.exported_program = None
        # Forwarded verbatim to torch.export.export.
        self.dynamic_shapes = dynamic_shapes

    def stage_type(self) -> StageType:
        return StageType.EXPORT

    def run(
        self,
        artifact: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
    ) -> None:
        """Export ``artifact`` with ``inputs`` as example args (strict mode)."""
        self.exported_program = export(
            artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
        )

    @property
    def artifact(self) -> ExportedProgram:
        """The ExportedProgram produced by run()."""
        return self.exported_program

    @property
    def graph_module(self) -> "torch.fx.GraphModule":
        # Fix: annotation previously claimed ``str``, but
        # ExportedProgram.graph_module is a torch.fx.GraphModule.
        return self.exported_program.graph_module
26 changes: 26 additions & 0 deletions backends/test/harness/stages/partition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from executorch.backends.test.harness.stages.stage import Stage, StageType
from executorch.exir import EdgeProgramManager
from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.backend.partitioner import Partitioner


class Partition(Stage):
    """Tester stage that lowers an edge program to a backend via a partitioner."""

    def __init__(self, partitioner: Partitioner):
        self.partitioner = partitioner
        # Populated by run(); None until the stage has executed.
        self.delegate_module = None

    def stage_type(self) -> StageType:
        return StageType.PARTITION

    def run(self, artifact: EdgeProgramManager, inputs=None) -> None:
        """Delegate ``artifact`` using the configured partitioner.

        Validation is disabled while lowering, matching the original
        behavior (partitioned programs may transiently violate the default
        exported-program invariants).
        """
        with validation_disabled():
            # Single assignment replaces the redundant two-step
            # assign-then-overwrite of the original.
            self.delegate_module = artifact.to_backend(self.partitioner)

    @property
    def artifact(self) -> EdgeProgramManager:
        """The partitioned (delegated) edge program manager."""
        return self.delegate_module

    @property
    def graph_module(self) -> "torch.fx.GraphModule":
        # Fix: annotation previously claimed ``str``; this returns the fx
        # GraphModule of the partitioned exported program.
        return self.delegate_module.exported_program().graph_module
Loading
Loading