Skip to content

Commit ee651da

Browse files
committed
Refactor XNNPACK tester to extract base class
1 parent 120eb85 commit ee651da

21 files changed

+1206
-790
lines changed

backends/arm/test/misc/test_lifted_tensor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TosaPipelineBI,
1313
TosaPipelineMI,
1414
)
15+
from executorch.backends.test.harness.stages import StageType
1516
from executorch.backends.xnnpack.test.tester import ToEdge
1617

1718

@@ -72,9 +73,8 @@ def test_partition_lifted_tensor_tosa_MI(test_data: input_t1):
7273
use_to_edge_transform_and_lower=False,
7374
)
7475
pipeline.run()
75-
to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
7676
signature = (
77-
pipeline.tester.stages[to_edge_stage_name]
77+
pipeline.tester.stages[StageType.TO_EDGE]
7878
.artifact.exported_program()
7979
.graph_signature
8080
)
@@ -94,9 +94,8 @@ def test_partition_lifted_tensor_tosa_BI(test_data: input_t1):
9494
use_to_edge_transform_and_lower=False,
9595
)
9696
pipeline.run()
97-
to_edge_stage_name = pipeline.tester.stage_name(ToEdge)
9897
signature = (
99-
pipeline.tester.stages[to_edge_stage_name]
98+
pipeline.tester.stages[StageType.TO_EDGE]
10099
.artifact.exported_program()
101100
.graph_signature
102101
)

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_output_quantization_params,
1515
)
1616

17+
from executorch.backends.test.harness.stages import StageType
1718
from executorch.backends.xnnpack.test.tester.tester import Export, Quantize
1819

1920
logger = logging.getLogger(__name__)
@@ -236,8 +237,8 @@ def dump_error_output(
236237
if path_to_tosa_files is None:
237238
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
238239

239-
export_stage = tester.stages.get(tester.stage_name(Export), None)
240-
quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
240+
export_stage = tester.stages.get(StageType.EXPORT, None)
241+
quantize_stage = tester.stages.get(StageType.QUANTIZE, None)
241242
if export_stage is not None and quantize_stage is not None:
242243
output_nodes = get_output_nodes(export_stage.artifact)
243244
qp_input = get_input_quantization_params(export_stage.artifact)

backends/arm/test/tester/arm_tester.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
5151
from executorch.backends.arm.tosa_specification import TosaSpecification
5252

53+
from executorch.backends.test.harness.stages import StageType
5354
from executorch.backends.xnnpack.test.tester import Tester
5455
from executorch.devtools.backend_debug import get_delegation_info
5556

@@ -284,13 +285,13 @@ def __init__(
284285
self.constant_methods = constant_methods
285286
self.compile_spec = compile_spec
286287
super().__init__(model, example_inputs, dynamic_shapes)
287-
self.pipeline[self.stage_name(InitialModel)] = [
288-
self.stage_name(tester.Quantize),
289-
self.stage_name(tester.Export),
288+
self.pipeline[StageType.INITIAL_MODEL] = [
289+
StageType.QUANTIZE,
290+
StageType.EXPORT,
290291
]
291292

292293
# Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
293-
self.stages[self.stage_name(InitialModel)] = None
294+
self.stages[StageType.INTIIAL_MODEL] = None
294295
self._run_stage(InitialModel(self.original_module))
295296

296297
def quantize(self, quantize_stage: Optional[tester.Quantize] = None):
@@ -385,7 +386,7 @@ def serialize(
385386
return super().serialize(serialize_stage)
386387

387388
def is_quantized(self) -> bool:
388-
return self.stages[self.stage_name(tester.Quantize)] is not None
389+
return self.stages[StageType.QUANTIZE] is not None
389390

390391
def run_method_and_compare_outputs(
391392
self,
@@ -414,18 +415,16 @@ def run_method_and_compare_outputs(
414415
"""
415416

416417
if not run_eager_mode:
417-
edge_stage = self.stages[self.stage_name(tester.ToEdge)]
418+
edge_stage = self.stages[StageType.TO_EDGE]
418419
if edge_stage is None:
419-
edge_stage = self.stages[
420-
self.stage_name(tester.ToEdgeTransformAndLower)
421-
]
420+
edge_stage = self.stages[StageType.TO_EDGE_TRANSFORM_AND_LOWER]
422421
assert (
423422
edge_stage is not None
424423
), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
425424
else:
426425
# Run models in eager mode. We do this when we want to check that the passes
427426
# are numerically accurate and the exported graph is correct.
428-
export_stage = self.stages[self.stage_name(tester.Export)]
427+
export_stage = self.stages[StageType.EXPORT]
429428
assert (
430429
export_stage is not None
431430
), "To compare outputs in eager mode, the model must be at Export stage"
@@ -435,11 +434,11 @@ def run_method_and_compare_outputs(
435434
is_quantized = self.is_quantized()
436435

437436
if is_quantized:
438-
reference_stage = self.stages[self.stage_name(tester.Quantize)]
437+
reference_stage = self.stages[StageType.QUANTIZE]
439438
else:
440-
reference_stage = self.stages[self.stage_name(InitialModel)]
439+
reference_stage = self.stages[StageType.INITIAL_MODEL]
441440

442-
exported_program = self.stages[self.stage_name(tester.Export)].artifact
441+
exported_program = self.stages[StageType.EXPORT].artifact
443442
output_nodes = get_output_nodes(exported_program)
444443

445444
output_qparams = get_output_quantization_params(output_nodes)
@@ -449,7 +448,7 @@ def run_method_and_compare_outputs(
449448
quantization_scales.append(getattr(output_qparams[node], "scale", None))
450449

451450
logger.info(
452-
f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'"
451+
f"Comparing Stage '{test_stage.stage_type()}' with Stage '{reference_stage.stage_type()}'"
453452
)
454453

455454
# Loop inputs and compare reference stage with the compared stage.
@@ -500,14 +499,12 @@ def get_graph(self, stage: str | None = None) -> Graph:
500499
stage = self.cur
501500
artifact = self.get_artifact(stage)
502501
if (
503-
self.cur == self.stage_name(tester.ToEdge)
504-
or self.cur == self.stage_name(Partition)
505-
or self.cur == self.stage_name(ToEdgeTransformAndLower)
502+
self.cur == StageType.TO_EDGE
503+
or self.cur == StageType.PARTITION
504+
or self.cur == StageType.TO_EDGE_TRANSFORM_AND_LOWER
506505
):
507506
graph = artifact.exported_program().graph
508-
elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name(
509-
tester.Quantize
510-
):
507+
elif self.cur == StageType.EXPORT or self.cur == StageType.QUANTIZE:
511508
graph = artifact.graph
512509
else:
513510
raise RuntimeError(
@@ -533,8 +530,8 @@ def dump_operator_distribution(
533530
if (
534531
self.cur
535532
in (
536-
self.stage_name(tester.Partition),
537-
self.stage_name(ToEdgeTransformAndLower),
533+
StageType.PARTITION,
534+
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
538535
)
539536
and print_table
540537
):
@@ -625,7 +622,7 @@ def run_transform_for_annotation_pipeline(
625622
stage = self.cur
626623
# We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
627624
artifact = self.get_artifact(stage)
628-
if self.cur == self.stage_name(tester.Export):
625+
if self.cur == StageType.EXPORT:
629626
new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline( # type: ignore[arg-type]
630627
graph_module=artifact.graph_module
631628
)

backends/test/harness/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .tester import Tester
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .export import Export
2+
from .partition import Partition
3+
from .quantize import Quantize
4+
from .run_passes import RunPasses
5+
from .serialize import Serialize
6+
from .stage import Stage, StageType
7+
from .to_edge import ToEdge
8+
from .to_edge_transform_and_lower import ToEdgeTransformAndLower
9+
from .to_executorch import ToExecutorch
10+
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Any, Optional, Sequence, Tuple
2+
3+
import torch
4+
5+
from executorch.backends.test.harness.stages.stage import Stage, StageType
6+
from torch.export import export, ExportedProgram
7+
8+
class Export(Stage):
9+
def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None):
10+
self.exported_program = None
11+
self.dynamic_shapes = dynamic_shapes
12+
13+
def stage_type(self) -> StageType:
14+
return StageType.EXPORT
15+
16+
def run(
17+
self,
18+
artifact: torch.nn.Module,
19+
inputs: Tuple[torch.Tensor],
20+
) -> None:
21+
self.exported_program = export(
22+
artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
23+
)
24+
25+
@property
26+
def artifact(self) -> ExportedProgram:
27+
return self.exported_program
28+
29+
@property
30+
def graph_module(self) -> str:
31+
return self.exported_program.graph_module
32+
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from executorch.backends.test.harness.stages.stage import Stage, StageType
2+
from executorch.exir import (
3+
EdgeProgramManager,
4+
)
5+
from executorch.exir.backend.backend_api import validation_disabled
6+
from executorch.exir.backend.partitioner import Partitioner
7+
8+
class Partition(Stage):
9+
def __init__(self, partitioner: Partitioner):
10+
self.partitioner = partitioner
11+
self.delegate_module = None
12+
13+
def stage_type(self) -> StageType:
14+
return StageType.PARTITION
15+
16+
def run(self, artifact: EdgeProgramManager, inputs=None):
17+
with validation_disabled():
18+
self.delegate_module = artifact
19+
self.delegate_module = self.delegate_module.to_backend(self.partitioner)
20+
21+
@property
22+
def artifact(self) -> EdgeProgramManager:
23+
return self.delegate_module
24+
25+
@property
26+
def graph_module(self) -> str:
27+
return self.delegate_module.exported_program().graph_module
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from typing import Any, Optional, Sequence, Tuple
2+
3+
import torch
4+
5+
from executorch.backends.test.harness.stages.stage import Stage, StageType
6+
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
7+
DuplicateDynamicQuantChainPass,
8+
)
9+
10+
from torch.export import export_for_training
11+
12+
from torchao.quantization.pt2e.quantize_pt2e import (
13+
convert_pt2e,
14+
prepare_pt2e,
15+
prepare_qat_pt2e,
16+
)
17+
from torchao.quantization.pt2e.quantizer import Quantizer
18+
19+
class Quantize(Stage):
20+
def __init__(
21+
self,
22+
quantizer: Optional[Quantizer] = None,
23+
quantization_config: Optional[Any] = None,
24+
calibrate: bool = True,
25+
calibration_samples: Optional[Sequence[Any]] = None,
26+
is_qat: Optional[bool] = False,
27+
):
28+
self.quantizer = quantizer
29+
self.quantization_config = quantization_config
30+
self.calibrate = calibrate
31+
self.calibration_samples = calibration_samples
32+
33+
self.quantizer.set_global(self.quantization_config)
34+
35+
self.converted_graph = None
36+
self.is_qat = is_qat
37+
38+
def stage_type(self) -> str:
39+
return StageType.QUANTIZE
40+
41+
def run(
42+
self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
43+
) -> None:
44+
assert inputs is not None
45+
if self.is_qat:
46+
artifact.train()
47+
captured_graph = export_for_training(artifact, inputs, strict=True).module()
48+
49+
assert isinstance(captured_graph, torch.fx.GraphModule)
50+
51+
if self.is_qat:
52+
prepared = prepare_qat_pt2e(captured_graph, self.quantizer)
53+
else:
54+
prepared = prepare_pt2e(captured_graph, self.quantizer)
55+
56+
if self.calibrate:
57+
# Calibrate prepared model to provide data to quantization observers.
58+
if self.calibration_samples is not None:
59+
for inp in self.calibration_samples:
60+
prepared(*inp)
61+
else:
62+
prepared(*inputs)
63+
64+
converted = convert_pt2e(prepared)
65+
DuplicateDynamicQuantChainPass()(converted)
66+
67+
self.converted_graph = converted
68+
69+
@property
70+
def artifact(self) -> torch.fx.GraphModule:
71+
return self.converted_graph
72+
73+
@property
74+
def graph_module(self) -> str:
75+
return self.converted_graph
76+
77+
def run_artifact(self, inputs):
78+
return self.converted_graph.forward(*inputs)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Any, Callable, List, Optional, Sequence, Type, Tuple, Union
2+
3+
import torch
4+
5+
from executorch.backends.test.harness.stages.stage import Stage, StageType
6+
from executorch.exir import (
7+
EdgeCompileConfig,
8+
EdgeProgramManager,
9+
)
10+
from executorch.exir.backend.partitioner import Partitioner
11+
from torch._export.pass_base import PassType
12+
from torch.export import ExportedProgram
13+
14+
class RunPasses(Stage):
15+
def __init__(
16+
self,
17+
pass_manager_cls: Type,
18+
pass_list: Optional[List[Type[PassType]]] = None,
19+
pass_functions: Optional[List[Callable]] = None,
20+
):
21+
self.pass_manager_cls = pass_manager_cls
22+
self.pass_list = pass_list
23+
self.pass_functions = pass_functions
24+
self.edge_or_aten_program = None
25+
26+
def stage_type(self) -> StageType:
27+
return StageType.RUN_PASSES
28+
29+
def run(
30+
self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None
31+
) -> None:
32+
if isinstance(artifact, EdgeProgramManager):
33+
self.edge_or_aten_program = artifact
34+
if self.pass_list:
35+
pass_manager = self.pass_manager_cls(
36+
artifact.exported_program(), self.pass_list
37+
)
38+
self.edge_or_aten_program._edge_programs["forward"] = (
39+
pass_manager.transform()
40+
)
41+
if self.pass_functions:
42+
assert isinstance(self.pass_functions, list)
43+
for pass_function in self.pass_functions:
44+
self.edge_or_aten_program._edge_programs["forward"] = pass_function(
45+
self.edge_or_aten_program.exported_program()
46+
)
47+
else:
48+
transformed_ep = artifact
49+
if self.pass_list:
50+
assert isinstance(self.pass_list, list)
51+
for pass_ in self.pass_list:
52+
transformed_ep = _transform(transformed_ep, pass_())
53+
54+
if self.pass_functions:
55+
assert isinstance(self.pass_functions, list)
56+
for pass_function in self.pass_functions:
57+
transformed_ep = pass_function(transformed_ep)
58+
59+
self.edge_or_aten_program = transformed_ep
60+
61+
@property
62+
def artifact(self) -> Union[EdgeProgramManager, ExportedProgram]:
63+
return self.edge_or_aten_program
64+
65+
@property
66+
def graph_module(self) -> str:
67+
if isinstance(self.edge_or_aten_program, EdgeProgramManager):
68+
return self.edge_or_aten_program.exported_program().graph_module
69+
else:
70+
return self.edge_or_aten_program.graph_module

0 commit comments

Comments
 (0)