
Commit cc8ac82

GregoryComer authored and facebook-github-bot committed
(WIP) Refactor XNNPACK tester to extract base class (#11596)
Summary:
Refactor the XNNPACK tester to split the reusable base components out from the XNNPACK-specific parts. The base classes now live in backends/test/harness. The tester structure is mostly unchanged, except that string stage names are replaced with an enum.

The Arm tests currently import directly from XNNPACK's tester. Ideally they should get their own stage implementations, but that is left as a follow-up to keep this initial refactor small.

Pull Request resolved: #11596

Test Plan: CI

Rollback Plan:

Differential Revision: D76547310

Pulled By: GregoryComer
1 parent 045b3a5 commit cc8ac82

22 files changed (+1056 −789 lines)
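
The heart of the change shows in the call sites below: stage registries were keyed by class-name strings and are now keyed by a StageType enum. A minimal sketch of the pattern, using a hypothetical stand-in for the real enum (which lives in backends/test/harness/stages):

    from enum import Enum, auto

    class StageType(Enum):  # hypothetical stand-in for the harness enum
        QUANTIZE = auto()
        EXPORT = auto()
        TO_EDGE = auto()

    stages = {}
    # Before: stages[tester.stage_name(ToEdge)] resolved to a string key like "ToEdge".
    # After: keys are enum members, so a mistyped stage fails at attribute lookup
    # rather than silently missing the dict.
    stages[StageType.TO_EDGE] = object()  # placeholder artifact
    assert StageType.TO_EDGE in stages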

backends/arm/test/misc/test_lifted_tensor.py

Lines changed: 3 additions & 4 deletions

@@ -12,6 +12,7 @@
     TosaPipelineBI,
     TosaPipelineMI,
 )
+from executorch.backends.test.harness.stages import StageType
 from executorch.backends.xnnpack.test.tester import ToEdge


@@ -72,9 +73,8 @@ def test_partition_lifted_tensor_tosa_MI(test_data: input_t1):
         use_to_edge_transform_and_lower=False,
     )
     pipeline.run()
-    to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
     signature = (
-        pipeline.tester.stages[to_edge_stage_name]
+        pipeline.tester.stages[StageType.TO_EDGE]
         .artifact.exported_program()
         .graph_signature
     )
@@ -94,9 +94,8 @@ def test_partition_lifted_tensor_tosa_BI(test_data: input_t1):
         use_to_edge_transform_and_lower=False,
     )
     pipeline.run()
-    to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
     signature = (
-        pipeline.tester.stages[to_edge_stage_name]
+        pipeline.tester.stages[StageType.TO_EDGE]
         .artifact.exported_program()
         .graph_signature
     )

backends/arm/test/quantizer/test_generic_annotater.py

Lines changed: 2 additions & 1 deletion

@@ -9,6 +9,7 @@
 import torch
 from executorch.backends.arm.quantizer import is_annotated
 from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI
+from executorch.backends.test.harness.stages import StageType

 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

@@ -36,7 +37,7 @@ def check_annotation(model):
     pipeline.pop_stage("run_method_and_compare_outputs")
     pipeline.run()

-    artifact = pipeline.tester.get_artifact("Quantize")
+    artifact = pipeline.tester.get_artifact(StageType.QUANTIZE)

     partitions = get_source_partitions(artifact.graph, [model.op])
     partitions = list(itertools.chain.from_iterable(partitions.values()))

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 3 additions & 2 deletions

@@ -14,6 +14,7 @@
     get_output_quantization_params,
 )

+from executorch.backends.test.harness.stages import StageType
 from executorch.backends.xnnpack.test.tester.tester import Export, Quantize

 logger = logging.getLogger(__name__)
@@ -236,8 +237,8 @@ def dump_error_output(
     if path_to_tosa_files is None:
         path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")

-    export_stage = tester.stages.get(tester.stage_name(Export), None)
-    quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
+    export_stage = tester.stages.get(StageType.EXPORT, None)
+    quantize_stage = tester.stages.get(StageType.QUANTIZE, None)
     if export_stage is not None and quantize_stage is not None:
         output_nodes = get_output_nodes(export_stage.artifact)
         qp_input = get_input_quantization_params(export_stage.artifact)

backends/arm/test/tester/arm_tester.py

Lines changed: 24 additions & 24 deletions

@@ -50,6 +50,7 @@
 from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
 from executorch.backends.arm.tosa_specification import TosaSpecification

+from executorch.backends.test.harness.stages import Stage, StageType
 from executorch.backends.xnnpack.test.tester import Tester
 from executorch.devtools.backend_debug import get_delegation_info

@@ -242,9 +243,12 @@ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram:
         super().run(artifact, inputs)


-class InitialModel(tester.Stage):
+class InitialModel(Stage):
     def __init__(self, model: torch.nn.Module):
         self.model = model
+
+    def stage_type(self) -> StageType:
+        return StageType.INITIAL_MODEL

     def run(self, artifact, inputs=None) -> None:
         pass
@@ -284,13 +288,13 @@ def __init__(
         self.constant_methods = constant_methods
         self.compile_spec = compile_spec
         super().__init__(model, example_inputs, dynamic_shapes)
-        self.pipeline[self.stage_name(InitialModel)] = [
-            self.stage_name(tester.Quantize),
-            self.stage_name(tester.Export),
+        self.pipeline[StageType.INITIAL_MODEL] = [
+            StageType.QUANTIZE,
+            StageType.EXPORT,
         ]

         # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
-        self.stages[self.stage_name(InitialModel)] = None
+        self.stages[StageType.INITIAL_MODEL] = None
         self._run_stage(InitialModel(self.original_module))

     def quantize(self, quantize_stage: Optional[tester.Quantize] = None):
@@ -385,7 +389,7 @@ def serialize(
         return super().serialize(serialize_stage)

     def is_quantized(self) -> bool:
-        return self.stages[self.stage_name(tester.Quantize)] is not None
+        return self.stages[StageType.QUANTIZE] is not None

     def run_method_and_compare_outputs(
         self,
@@ -414,18 +418,16 @@ def run_method_and_compare_outputs(
         """

         if not run_eager_mode:
-            edge_stage = self.stages[self.stage_name(tester.ToEdge)]
+            edge_stage = self.stages[StageType.TO_EDGE]
             if edge_stage is None:
-                edge_stage = self.stages[
-                    self.stage_name(tester.ToEdgeTransformAndLower)
-                ]
+                edge_stage = self.stages[StageType.TO_EDGE_TRANSFORM_AND_LOWER]
             assert (
                 edge_stage is not None
             ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
         else:
             # Run models in eager mode. We do this when we want to check that the passes
             # are numerically accurate and the exported graph is correct.
-            export_stage = self.stages[self.stage_name(tester.Export)]
+            export_stage = self.stages[StageType.EXPORT]
             assert (
                 export_stage is not None
             ), "To compare outputs in eager mode, the model must be at Export stage"
@@ -435,11 +437,11 @@ def run_method_and_compare_outputs(
         is_quantized = self.is_quantized()

         if is_quantized:
-            reference_stage = self.stages[self.stage_name(tester.Quantize)]
+            reference_stage = self.stages[StageType.QUANTIZE]
         else:
-            reference_stage = self.stages[self.stage_name(InitialModel)]
+            reference_stage = self.stages[StageType.INITIAL_MODEL]

-        exported_program = self.stages[self.stage_name(tester.Export)].artifact
+        exported_program = self.stages[StageType.EXPORT].artifact
         output_nodes = get_output_nodes(exported_program)

         output_qparams = get_output_quantization_params(output_nodes)
@@ -449,7 +451,7 @@
             quantization_scales.append(getattr(output_qparams[node], "scale", None))

         logger.info(
-            f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'"
+            f"Comparing Stage '{test_stage.stage_type()}' with Stage '{reference_stage.stage_type()}'"
         )

         # Loop inputs and compare reference stage with the compared stage.
@@ -500,14 +502,12 @@ def get_graph(self, stage: str | None = None) -> Graph:
             stage = self.cur
         artifact = self.get_artifact(stage)
         if (
-            self.cur == self.stage_name(tester.ToEdge)
-            or self.cur == self.stage_name(Partition)
-            or self.cur == self.stage_name(ToEdgeTransformAndLower)
+            self.cur == StageType.TO_EDGE
+            or self.cur == StageType.PARTITION
+            or self.cur == StageType.TO_EDGE_TRANSFORM_AND_LOWER
         ):
             graph = artifact.exported_program().graph
-        elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name(
-            tester.Quantize
-        ):
+        elif self.cur == StageType.EXPORT or self.cur == StageType.QUANTIZE:
             graph = artifact.graph
         else:
             raise RuntimeError(
@@ -533,8 +533,8 @@ def dump_operator_distribution(
         if (
             self.cur
             in (
-                self.stage_name(tester.Partition),
-                self.stage_name(ToEdgeTransformAndLower),
+                StageType.PARTITION,
+                StageType.TO_EDGE_TRANSFORM_AND_LOWER,
             )
             and print_table
         ):
@@ -625,7 +625,7 @@ def run_transform_for_annotation_pipeline(
             stage = self.cur
         # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
         artifact = self.get_artifact(stage)
-        if self.cur == self.stage_name(tester.Export):
+        if self.cur == StageType.EXPORT:
             new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline(  # type: ignore[arg-type]
                 graph_module=artifact.graph_module
             )

backends/test/harness/TARGETS

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_library(
+    name = "tester",
+    srcs = [
+        "__init__.py",
+        "tester.py",
+    ] + native.glob(["stages/*.py"]),
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//executorch/exir:graph_module",
+    ],
+)

backends/test/harness/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .tester import Tester

backends/test/harness/stages/__init__.py

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+from .export import Export
+from .partition import Partition
+from .quantize import Quantize
+from .run_passes import RunPasses
+from .serialize import Serialize
+from .stage import Stage, StageType
+from .to_edge import ToEdge
+from .to_edge_transform_and_lower import ToEdgeTransformAndLower
+from .to_executorch import ToExecutorch
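
The Stage base class and StageType enum imported throughout this diff are defined in stages/stage.py, which is not shown here. Below is a plausible sketch inferred from the call sites above; the member names are the ones used in this diff plus one per stage module, so treat it as an assumption rather than the real file.

    from abc import ABC, abstractmethod
    from enum import Enum, auto

    class StageType(Enum):  # inferred members; the real enum may differ
        INITIAL_MODEL = auto()
        QUANTIZE = auto()
        EXPORT = auto()
        RUN_PASSES = auto()
        TO_EDGE = auto()
        TO_EDGE_TRANSFORM_AND_LOWER = auto()
        PARTITION = auto()
        TO_EXECUTORCH = auto()
        SERIALIZE = auto()

    class Stage(ABC):
        @abstractmethod
        def stage_type(self) -> StageType:
            """Identify which pipeline stage this object implements."""

        @abstractmethod
        def run(self, artifact, inputs=None) -> None:
            """Consume the previous stage's artifact and produce this stage's."""

        @property
        def artifact(self):
            raise NotImplementedError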
backends/test/harness/stages/export.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+from typing import Any, Optional, Sequence, Tuple
+
+import torch
+
+from executorch.backends.test.harness.stages.stage import Stage, StageType
+from torch.export import export, ExportedProgram
+
+class Export(Stage):
+    def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None):
+        self.exported_program = None
+        self.dynamic_shapes = dynamic_shapes
+
+    def stage_type(self) -> StageType:
+        return StageType.EXPORT
+
+    def run(
+        self,
+        artifact: torch.nn.Module,
+        inputs: Tuple[torch.Tensor],
+    ) -> None:
+        self.exported_program = export(
+            artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
+        )
+
+    @property
+    def artifact(self) -> ExportedProgram:
+        return self.exported_program
+
+    @property
+    def graph_module(self) -> str:
+        return self.exported_program.graph_module
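
As a quick sanity check of the new stage API, here is a minimal sketch driving the Export stage standalone (the toy module is ours; the import path is the one this commit creates):

    import torch

    from executorch.backends.test.harness.stages import Export

    class AddOne(torch.nn.Module):
        def forward(self, x):
            return x + 1

    stage = Export()
    # run() traces the module with torch.export using the given example inputs.
    stage.run(AddOne(), (torch.randn(2),))
    print(stage.artifact)  # the resulting ExportedProgram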
backends/test/harness/stages/partition.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+from executorch.backends.test.harness.stages.stage import Stage, StageType
+from executorch.exir import (
+    EdgeProgramManager,
+)
+from executorch.exir.backend.backend_api import validation_disabled
+from executorch.exir.backend.partitioner import Partitioner
+
+class Partition(Stage):
+    def __init__(self, partitioner: Partitioner):
+        self.partitioner = partitioner
+        self.delegate_module = None
+
+    def stage_type(self) -> StageType:
+        return StageType.PARTITION
+
+    def run(self, artifact: EdgeProgramManager, inputs=None):
+        with validation_disabled():
+            self.delegate_module = artifact
+            self.delegate_module = self.delegate_module.to_backend(self.partitioner)
+
+    @property
+    def artifact(self) -> EdgeProgramManager:
+        return self.delegate_module
+
+    @property
+    def graph_module(self) -> str:
+        return self.delegate_module.exported_program().graph_module
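
Partition consumes an EdgeProgramManager, so a standalone sketch has to build one first. Something like the following should work, assuming the exir.to_edge entry point and the XNNPACK partitioner import path, both of which sit outside this diff:

    import torch

    from executorch.backends.test.harness.stages import Export, Partition
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackPartitioner,
    )
    from executorch.exir import to_edge

    class AddMul(torch.nn.Module):
        def forward(self, x):
            return (x + x) * x

    export_stage = Export()
    export_stage.run(AddMul(), (torch.randn(2, 2),))

    # Lower the exported program to edge dialect, then delegate to XNNPACK.
    partition = Partition(XnnpackPartitioner())
    partition.run(to_edge(export_stage.artifact))
    print(partition.graph_module)  # graph now contains lowered-module call nodes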
backends/test/harness/stages/quantize.py

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+from typing import Any, Optional, Sequence, Tuple
+
+import torch
+
+from executorch.backends.test.harness.stages.stage import Stage, StageType
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    DuplicateDynamicQuantChainPass,
+)
+
+from torch.export import export_for_training
+
+from torchao.quantization.pt2e.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
+from torchao.quantization.pt2e.quantizer import Quantizer
+
+class Quantize(Stage):
+    def __init__(
+        self,
+        quantizer: Optional[Quantizer] = None,
+        quantization_config: Optional[Any] = None,
+        calibrate: bool = True,
+        calibration_samples: Optional[Sequence[Any]] = None,
+        is_qat: Optional[bool] = False,
+    ):
+        self.quantizer = quantizer
+        self.quantization_config = quantization_config
+        self.calibrate = calibrate
+        self.calibration_samples = calibration_samples
+
+        self.quantizer.set_global(self.quantization_config)
+
+        self.converted_graph = None
+        self.is_qat = is_qat
+
+    def stage_type(self) -> str:
+        return StageType.QUANTIZE
+
+    def run(
+        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
+    ) -> None:
+        assert inputs is not None
+        if self.is_qat:
+            artifact.train()
+        captured_graph = export_for_training(artifact, inputs, strict=True).module()
+
+        assert isinstance(captured_graph, torch.fx.GraphModule)
+
+        if self.is_qat:
+            prepared = prepare_qat_pt2e(captured_graph, self.quantizer)
+        else:
+            prepared = prepare_pt2e(captured_graph, self.quantizer)
+
+        if self.calibrate:
+            # Calibrate prepared model to provide data to quantization observers.
+            if self.calibration_samples is not None:
+                for inp in self.calibration_samples:
+                    prepared(*inp)
+            else:
+                prepared(*inputs)
+
+        converted = convert_pt2e(prepared)
+        DuplicateDynamicQuantChainPass()(converted)
+
+        self.converted_graph = converted
+
+    @property
+    def artifact(self) -> torch.fx.GraphModule:
+        return self.converted_graph
+
+    @property
+    def graph_module(self) -> str:
+        return self.converted_graph
+
+    def run_artifact(self, inputs):
+        return self.converted_graph.forward(*inputs)
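
Note that Quantize calls self.quantizer.set_global(...) unconditionally in its constructor, so in practice both a quantizer and a config must be passed despite the Optional annotations. A usage sketch with the XNNPACK quantizer, whose import path is assumed here since it sits outside this diff:

    import torch

    from executorch.backends.test.harness.stages import Quantize
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )

    class SmallLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    inputs = (torch.randn(1, 4),)
    stage = Quantize(
        quantizer=XNNPACKQuantizer(),
        quantization_config=get_symmetric_quantization_config(),
    )
    # run() exports for training, inserts observers, calibrates on `inputs`,
    # and converts to a quantized torch.fx.GraphModule.
    stage.run(SmallLinear(), inputs)
    print(type(stage.artifact))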
