Skip to content

Commit e1a20b5

Browse files
mcr229facebook-github-bot
authored andcommitted
add RunPasses Stage to XNNPACK Tester
Differential Revision: D48764114 fbshipit-source-id: 245c0e68a622a4b235fb0b65ed4c35b132a4924c
1 parent 0ab5d95 commit e1a20b5

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

backends/xnnpack/test/tester/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Export,
1010
Partition,
1111
Quantize,
12+
RunPasses,
1213
Serialize,
1314
Tester,
1415
ToEdge,
@@ -21,6 +22,7 @@
2122
Quantize,
2223
Export,
2324
ToEdge,
25+
RunPasses,
2426
ToExecutorch,
2527
Serialize,
2628
]

backends/xnnpack/test/tester/tester.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
1616
XnnpackFloatingPointPartitioner,
1717
)
18+
from executorch.backends.xnnpack.passes import XNNPACKPassManager
1819
from executorch.backends.xnnpack.utils.configs import (
1920
get_xnnpack_capture_config,
2021
get_xnnpack_edge_compile_config,
@@ -33,6 +34,7 @@
3334
from executorch.extension.pybindings.portable_lib import ( # @manual
3435
_load_for_executorch_from_buffer,
3536
)
37+
from torch._export.pass_base import PassType
3638
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
3739
from torch.ao.quantization.quantizer.quantizer import Quantizer
3840
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
@@ -159,6 +161,26 @@ def graph_module(self) -> str:
159161
return self.edge_dialect_program.exported_program.graph_module
160162

161163

164+
@register_stage
165+
class RunPasses(Stage):
166+
def __init__(self, pass_list: Optional[List[PassType]] = None):
167+
self.pass_list = pass_list
168+
self.edge_dialect_program = None
169+
170+
def run(self, artifact: ExirExportedProgram, inputs=None) -> None:
171+
pass_manager = XNNPACKPassManager(artifact.exported_program, self.pass_list)
172+
self.edge_dialect_program = artifact
173+
self.edge_dialect_program.exported_program = pass_manager.transform()
174+
175+
@property
176+
def artifact(self) -> ExirExportedProgram:
177+
return self.edge_dialect_program
178+
179+
@property
180+
def graph_module(self) -> str:
181+
return self.edge_dialect_program.exported_program.graph_module
182+
183+
162184
@register_stage
163185
class Partition(Stage):
164186
def __init__(self, partitioner: Optional[Partitioner] = None):
@@ -239,7 +261,11 @@ def __init__(
239261
self._stage_name(Export): [
240262
self._stage_name(ToEdge),
241263
],
242-
self._stage_name(ToEdge): [self._stage_name(Partition)],
264+
self._stage_name(ToEdge): [
265+
self._stage_name(Partition),
266+
self._stage_name(RunPasses),
267+
],
268+
self._stage_name(RunPasses): [self._stage_name(Partition)],
243269
# TODO Make this Stage optional
244270
self._stage_name(Partition): [self._stage_name(ToExecutorch)],
245271
self._stage_name(ToExecutorch): [self._stage_name(Serialize)],
@@ -298,6 +324,9 @@ def export(self, export_stage: Optional[Export] = None):
298324
def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
299325
return self._run_stage(to_edge_stage or ToEdge())
300326

327+
def run_passes(self, run_passes_stage: Optional[RunPasses] = None):
328+
return self._run_stage(run_passes_stage or RunPasses())
329+
301330
def partition(self, partition_stage: Optional[Partition] = None):
302331
return self._run_stage(partition_stage or Partition())
303332

0 commit comments

Comments
 (0)