Skip to content

Commit bb6be76

Browse files
zonglinpengfacebook-github-bot
authored andcommitted
migrate passes calls to pass utils (#6721)
Summary: keep the same calling order in oss, but calling by pass util pass register Reviewed By: hsharma35 Differential Revision: D65464704
1 parent dc41596 commit bb6be76

File tree

3 files changed

+89
-44
lines changed

3 files changed

+89
-44
lines changed

backends/cadence/aot/TARGETS

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ python_library(
4848
],
4949
)
5050

51+
5152
python_library(
52-
name = "passes",
53+
name = "pass_utils",
5354
srcs = [
54-
"_passes.py",
55+
"pass_utils.py",
5556
],
5657
deps = [
5758
":utils",
@@ -64,9 +65,9 @@ python_library(
6465
)
6566

6667
python_library(
67-
name = "pass_utils",
68+
name = "passes",
6869
srcs = [
69-
"pass_utils.py",
70+
"passes.py",
7071
],
7172
deps = [
7273
":utils",

backends/cadence/aot/compiler.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,41 +8,33 @@
88

99
import logging
1010
from pathlib import Path
11-
from typing import Optional
11+
from typing import Callable, cast, Optional
1212

1313
import torch
1414

15-
from executorch.backends.cadence.aot._passes import (
16-
InitializePipeline,
17-
RemoveNopExpandOpPass,
18-
RemoveZeroSizedCatArgsPass,
19-
ReplaceLogicalNotBooleanWhereWithWherePass,
20-
ReplacePT2DequantWithCadenceDequantPass,
21-
ReplacePT2QuantWithCadenceQuantPass,
22-
ReplaceSafeSoftmaxWithSoftmax,
23-
ReplaceScalarTensorWithFullPass,
24-
ReplaceSqueezeAndUnsqueezeWithViewPass,
25-
)
15+
from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
2616
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
2717
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
2818
from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
2919
from executorch.backends.transforms.decompose_sdpa import (
3020
DecomposeScaledDotProductAttention,
3121
)
32-
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
3322
from executorch.devtools import generate_etrecord
3423
from executorch.exir import (
3524
EdgeCompileConfig,
3625
EdgeProgramManager,
3726
ExecutorchProgramManager,
3827
to_edge,
3928
)
29+
from executorch.exir.pass_base import PassResult
4030
from torch.ao.quantization.pt2e.export_utils import model_is_exported
4131
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
4232

4333
from torch.export import export
4434
from torch.export.exported_program import ExportedProgram
4535

36+
from .passes import get_cadence_passes
37+
4638
from .utils import print_ops_info
4739

4840

@@ -209,22 +201,16 @@ def export_to_cadence_edge_executorch(
209201
inputs: tuple[object, ...],
210202
dump_graphs: bool = False,
211203
output_dir: Optional[str] = None,
204+
opt_level: int = 1,
212205
) -> ExecutorchProgramManager:
213206
edge_prog_manager = export_to_edge(model, inputs)
207+
cadence_passes = get_cadence_passes(opt_level)
214208

215209
# Run a couple required passes for quant/dequant ops
216210
cadence_prog_manager = edge_prog_manager.transform(
217-
[
218-
InitializePipeline(),
219-
RemoveZeroSizedCatArgsPass(),
220-
ReplaceLogicalNotBooleanWhereWithWherePass(),
221-
ReplaceScalarTensorWithFullPass(),
222-
RemoveCloneOpsTransform(),
223-
RemoveNopExpandOpPass(),
224-
ReplaceSqueezeAndUnsqueezeWithViewPass(),
225-
ReplacePT2QuantWithCadenceQuantPass(),
226-
ReplacePT2DequantWithCadenceDequantPass(),
227-
]
211+
cast(
212+
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
213+
)
228214
)
229215

230216
# Print some information to terminal

backends/cadence/aot/_passes.py renamed to backends/cadence/aot/passes.py

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,58 @@
66

77
# pyre-strict
88

9-
from typing import Any, cast, Dict, Sequence, Tuple
9+
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type
1010

1111
import torch
12+
import torch.fx
13+
import torch.utils._pytree as pytree
14+
from executorch.backends.cadence.aot.pass_utils import (
15+
CadencePassAttribute,
16+
create_cadence_pass_filter,
17+
register_cadence_pass,
18+
)
1219
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
20+
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
1321
from executorch.exir.dialects._ops import ops as exir_ops
1422
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
23+
from executorch.exir.pass_manager import PassManager, PassType
1524
from executorch.exir.passes import dead_code_elimination_pass
25+
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
1626
from executorch.exir.passes.spec_prop_pass import SpecPropPass
1727
from torch._subclasses import FakeTensor
1828
from torch.utils._pytree import tree_map_only
1929

30+
31+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
32+
class InitializePipeline(ExportPass):
33+
"""
34+
Initialize the Jarvis pipeline. This should invariably be the first pass to
35+
run.
36+
"""
37+
38+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
39+
dead_code_elimination_pass(graph_module)
40+
result = SpecPropPass()(graph_module)
41+
assert result is not None
42+
return result
43+
44+
45+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
46+
class FinalizePipeline(ExportPass):
47+
"""
48+
The final cleanup pass after running the Jarvis pipeline.
49+
"""
50+
51+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
52+
finalize_passes: List[PassType] = [
53+
ScalarToTensorPass(),
54+
SpecPropPass(),
55+
]
56+
result = PassManager(passes=finalize_passes)(graph_module)
57+
dead_code_elimination_pass(result.graph_module)
58+
return result
59+
60+
2061
# Similar to what's done in executorch/exir/pass_base.py
2162
Argument = Any # pyre-ignore
2263

@@ -131,7 +172,7 @@ def call_operator(
131172
)
132173

133174

134-
class RemoveZeroSizedCatArgsPass(ExportPass):
175+
class RemoveZeroSizedCatArgsPass(ExportPass): # is this the latest?
135176
def call_operator(
136177
self,
137178
op, # pyre-ignore
@@ -255,20 +296,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
255296
return result
256297

257298

258-
class InitializePipeline(ExportPass):
259-
"""
260-
Initialize the Jarvis pipeline. This should invariably be the first pass to
261-
run.
262-
"""
263-
264-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
265-
dead_code_elimination_pass(graph_module)
266-
result = SpecPropPass()(graph_module)
267-
assert result is not None
268-
return result
269-
270-
271-
class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
299+
class ReplaceSafeSoftmaxWithSoftmax(ExportPass): # keep
272300
"""
273301
Replace _safe_softmax with _softmax
274302
"""
@@ -292,3 +320,33 @@ def call_operator(
292320
kwargs,
293321
meta,
294322
)
323+
324+
325+
def get_passes_in_default_order() -> List[Type[PassType]]:
326+
passes = [
327+
InitializePipeline,
328+
RemoveZeroSizedCatArgsPass,
329+
ReplaceLogicalNotBooleanWhereWithWherePass,
330+
ReplaceScalarTensorWithFullPass,
331+
RemoveCloneOpsTransform,
332+
RemoveNopExpandOpPass,
333+
ReplaceSqueezeAndUnsqueezeWithViewPass,
334+
ReplacePT2QuantWithCadenceQuantPass,
335+
ReplacePT2DequantWithCadenceDequantPass,
336+
# TODO: add the rest of the passes here.
337+
]
338+
return pytree.tree_flatten(passes)[0]
339+
340+
341+
def get_cadence_passes(
342+
opt_level: int,
343+
) -> List[Optional[PassResult]]:
344+
passes = get_passes_in_default_order()
345+
pass_filter = create_cadence_pass_filter(opt_level)
346+
filtered_passes = [
347+
# pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
348+
filtered_pass()
349+
# pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
350+
for filtered_pass in list(filter(pass_filter, passes))
351+
]
352+
return filtered_passes

0 commit comments

Comments
 (0)