|
15 | 15 | from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
|
16 | 16 | XnnpackFloatingPointPartitioner,
|
17 | 17 | )
|
| 18 | +from executorch.backends.xnnpack.passes import XNNPACKPassManager |
18 | 19 | from executorch.backends.xnnpack.utils.configs import (
|
19 | 20 | get_xnnpack_capture_config,
|
20 | 21 | get_xnnpack_edge_compile_config,
|
|
33 | 34 | from executorch.extension.pybindings.portable_lib import ( # @manual
|
34 | 35 | _load_for_executorch_from_buffer,
|
35 | 36 | )
|
| 37 | +from torch._export.pass_base import PassType |
36 | 38 | from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
37 | 39 | from torch.ao.quantization.quantizer.quantizer import Quantizer
|
38 | 40 | from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
@@ -159,6 +161,26 @@ def graph_module(self) -> str:
|
159 | 161 | return self.edge_dialect_program.exported_program.graph_module
|
160 | 162 |
|
161 | 163 |
|
| 164 | +@register_stage |
| 165 | +class RunPasses(Stage): |
| 166 | + def __init__(self, pass_list: Optional[List[PassType]] = None): |
| 167 | + self.pass_list = pass_list |
| 168 | + self.edge_dialect_program = None |
| 169 | + |
| 170 | + def run(self, artifact: ExirExportedProgram, inputs=None) -> None: |
| 171 | + pass_manager = XNNPACKPassManager(artifact.exported_program, self.pass_list) |
| 172 | + self.edge_dialect_program = artifact |
| 173 | + self.edge_dialect_program.exported_program = pass_manager.transform() |
| 174 | + |
| 175 | + @property |
| 176 | + def artifact(self) -> ExirExportedProgram: |
| 177 | + return self.edge_dialect_program |
| 178 | + |
| 179 | + @property |
| 180 | + def graph_module(self) -> str: |
| 181 | + return self.edge_dialect_program.exported_program.graph_module |
| 182 | + |
| 183 | + |
162 | 184 | @register_stage
|
163 | 185 | class Partition(Stage):
|
164 | 186 | def __init__(self, partitioner: Optional[Partitioner] = None):
|
@@ -239,7 +261,11 @@ def __init__(
|
239 | 261 | self._stage_name(Export): [
|
240 | 262 | self._stage_name(ToEdge),
|
241 | 263 | ],
|
242 |
| - self._stage_name(ToEdge): [self._stage_name(Partition)], |
| 264 | + self._stage_name(ToEdge): [ |
| 265 | + self._stage_name(Partition), |
| 266 | + self._stage_name(RunPasses), |
| 267 | + ], |
| 268 | + self._stage_name(RunPasses): [self._stage_name(Partition)], |
243 | 269 | # TODO Make this Stage optional
|
244 | 270 | self._stage_name(Partition): [self._stage_name(ToExecutorch)],
|
245 | 271 | self._stage_name(ToExecutorch): [self._stage_name(Serialize)],
|
@@ -298,6 +324,9 @@ def export(self, export_stage: Optional[Export] = None):
|
298 | 324 | def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
|
299 | 325 | return self._run_stage(to_edge_stage or ToEdge())
|
300 | 326 |
|
| 327 | + def run_passes(self, run_passes_stage: Optional[RunPasses] = None): |
| 328 | + return self._run_stage(run_passes_stage or RunPasses()) |
| 329 | + |
301 | 330 | def partition(self, partition_stage: Optional[Partition] = None):
|
302 | 331 | return self._run_stage(partition_stage or Partition())
|
303 | 332 |
|
|
0 commit comments