Skip to content

Commit 8127edd

Browse files
authored
migrate passes calls to pass utils
Differential Revision: D65464704 Pull Request resolved: #6721
1 parent 74bb5ff commit 8127edd

File tree

3 files changed

+108
-44
lines changed

3 files changed

+108
-44
lines changed

backends/cadence/aot/TARGETS

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ python_library(
4848
],
4949
)
5050

51+
5152
python_library(
52-
name = "passes",
53+
name = "pass_utils",
5354
srcs = [
54-
"_passes.py",
55+
"pass_utils.py",
5556
],
5657
deps = [
5758
":utils",
@@ -64,9 +65,9 @@ python_library(
6465
)
6566

6667
python_library(
67-
name = "pass_utils",
68+
name = "passes",
6869
srcs = [
69-
"pass_utils.py",
70+
"passes.py",
7071
],
7172
deps = [
7273
":utils",

backends/cadence/aot/compiler.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,41 +8,33 @@
88

99
import logging
1010
from pathlib import Path
11-
from typing import Optional
11+
from typing import Callable, cast, Optional
1212

1313
import torch
1414

15-
from executorch.backends.cadence.aot._passes import (
16-
InitializePipeline,
17-
RemoveNopExpandOpPass,
18-
RemoveZeroSizedCatArgsPass,
19-
ReplaceLogicalNotBooleanWhereWithWherePass,
20-
ReplacePT2DequantWithCadenceDequantPass,
21-
ReplacePT2QuantWithCadenceQuantPass,
22-
ReplaceSafeSoftmaxWithSoftmax,
23-
ReplaceScalarTensorWithFullPass,
24-
ReplaceSqueezeAndUnsqueezeWithViewPass,
25-
)
15+
from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
2616
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
2717
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
2818
from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
2919
from executorch.backends.transforms.decompose_sdpa import (
3020
DecomposeScaledDotProductAttention,
3121
)
32-
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
3322
from executorch.devtools import generate_etrecord
3423
from executorch.exir import (
3524
EdgeCompileConfig,
3625
EdgeProgramManager,
3726
ExecutorchProgramManager,
3827
to_edge,
3928
)
29+
from executorch.exir.pass_base import PassResult
4030
from torch.ao.quantization.pt2e.export_utils import model_is_exported
4131
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
4232

4333
from torch.export import export
4434
from torch.export.exported_program import ExportedProgram
4535

36+
from .passes import get_cadence_passes
37+
4638
from .utils import print_ops_info
4739

4840

@@ -209,22 +201,16 @@ def export_to_cadence_edge_executorch(
209201
inputs: tuple[object, ...],
210202
dump_graphs: bool = False,
211203
output_dir: Optional[str] = None,
204+
opt_level: int = 1,
212205
) -> ExecutorchProgramManager:
213206
edge_prog_manager = export_to_edge(model, inputs)
207+
cadence_passes = get_cadence_passes(opt_level)
214208

215209
# Run a couple required passes for quant/dequant ops
216210
cadence_prog_manager = edge_prog_manager.transform(
217-
[
218-
InitializePipeline(),
219-
RemoveZeroSizedCatArgsPass(),
220-
ReplaceLogicalNotBooleanWhereWithWherePass(),
221-
ReplaceScalarTensorWithFullPass(),
222-
RemoveCloneOpsTransform(),
223-
RemoveNopExpandOpPass(),
224-
ReplaceSqueezeAndUnsqueezeWithViewPass(),
225-
ReplacePT2QuantWithCadenceQuantPass(),
226-
ReplacePT2DequantWithCadenceDequantPass(),
227-
]
211+
cast(
212+
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
213+
)
228214
)
229215

230216
# Print some information to terminal

backends/cadence/aot/_passes.py renamed to backends/cadence/aot/passes.py

Lines changed: 93 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,74 @@
66

77
# pyre-strict
88

9-
from typing import Any, cast, Dict, Sequence, Tuple
9+
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type
1010

1111
import torch
12+
import torch.fx
13+
import torch.utils._pytree as pytree
14+
from executorch.backends.cadence.aot.pass_utils import (
15+
CadencePassAttribute,
16+
create_cadence_pass_filter,
17+
register_cadence_pass,
18+
)
1219
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
20+
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
1321
from executorch.exir.dialects._ops import ops as exir_ops
1422
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
23+
from executorch.exir.pass_manager import PassManager, PassType
1524
from executorch.exir.passes import dead_code_elimination_pass
25+
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
1626
from executorch.exir.passes.spec_prop_pass import SpecPropPass
1727
from torch._subclasses import FakeTensor
1828
from torch.utils._pytree import tree_map_only
1929

30+
31+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
32+
class RemoveCloneOpsTransformImported(ExportPass):
33+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
34+
finalize_passes: List[PassType] = [
35+
RemoveCloneOpsTransform(),
36+
]
37+
result = PassManager(passes=finalize_passes)(graph_module)
38+
dead_code_elimination_pass(result.graph_module)
39+
return result
40+
41+
42+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
43+
class InitializePipeline(ExportPass):
44+
"""
45+
Initialize the Jarvis pipeline. This should invariably be the first pass to
46+
run.
47+
"""
48+
49+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
50+
dead_code_elimination_pass(graph_module)
51+
result = SpecPropPass()(graph_module)
52+
assert result is not None
53+
return result
54+
55+
56+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
57+
class FinalizePipeline(ExportPass):
58+
"""
59+
The final cleanup pass after running the Jarvis pipeline.
60+
"""
61+
62+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
63+
finalize_passes: List[PassType] = [
64+
ScalarToTensorPass(),
65+
SpecPropPass(),
66+
]
67+
result = PassManager(passes=finalize_passes)(graph_module)
68+
dead_code_elimination_pass(result.graph_module)
69+
return result
70+
71+
2072
# Similar to what's done in executorch/exir/pass_base.py
2173
Argument = Any # pyre-ignore
2274

2375

76+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
2477
class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
2578
"""
2679
Replace the pt2 quantization ops with custom cadence quantization ops.
@@ -44,6 +97,7 @@ def call_operator(
4497
)
4598

4699

100+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
47101
class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
48102
"""
49103
Replace the pt2 dequantization ops with custom cadence dequantization ops.
@@ -67,6 +121,7 @@ def call_operator(
67121
)
68122

69123

124+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
70125
class ReplaceScalarTensorWithFullPass(ExportPass):
71126
"""
72127
aten.scalar_tensor can be replaced by aten.full with a shape of [1].
@@ -96,6 +151,7 @@ def call_operator(
96151
)
97152

98153

154+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
99155
class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
100156
"""
101157
When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
@@ -131,7 +187,8 @@ def call_operator(
131187
)
132188

133189

134-
class RemoveZeroSizedCatArgsPass(ExportPass):
190+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
191+
class RemoveZeroSizedCatArgsPass(ExportPass): # is this the latest?
135192
def call_operator(
136193
self,
137194
op, # pyre-ignore
@@ -176,6 +233,7 @@ def call_operator(
176233
return super().call_operator(op, args, kwargs, meta)
177234

178235

236+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
179237
class RemoveNopExpandOpPass(ExportPass):
180238
"""
181239
For an expand op, if the operator shape matches the expand shape, then the
@@ -205,6 +263,7 @@ def call_operator(
205263
return super().call_operator(op, args, kwargs, meta)
206264

207265

266+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
208267
class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
209268
"""
210269
A where op with a logical_not and a boolean tensor can be replaced
@@ -255,20 +314,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
255314
return result
256315

257316

258-
class InitializePipeline(ExportPass):
259-
"""
260-
Initialize the Jarvis pipeline. This should invariably be the first pass to
261-
run.
262-
"""
263-
264-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
265-
dead_code_elimination_pass(graph_module)
266-
result = SpecPropPass()(graph_module)
267-
assert result is not None
268-
return result
269-
270-
271-
class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
317+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
318+
class ReplaceSafeSoftmaxWithSoftmax(ExportPass): # keep
272319
"""
273320
Replace _safe_softmax with _softmax
274321
"""
@@ -292,3 +339,33 @@ def call_operator(
292339
kwargs,
293340
meta,
294341
)
342+
343+
344+
def get_passes_in_default_order() -> List[Type[PassType]]:
345+
passes = [
346+
InitializePipeline,
347+
RemoveZeroSizedCatArgsPass,
348+
ReplaceLogicalNotBooleanWhereWithWherePass,
349+
ReplaceScalarTensorWithFullPass,
350+
RemoveCloneOpsTransformImported,
351+
RemoveNopExpandOpPass,
352+
ReplaceSqueezeAndUnsqueezeWithViewPass,
353+
ReplacePT2QuantWithCadenceQuantPass,
354+
ReplacePT2DequantWithCadenceDequantPass,
355+
# TODO: add the rest of the passes here.
356+
]
357+
return pytree.tree_flatten(passes)[0]
358+
359+
360+
def get_cadence_passes(
361+
opt_level: int,
362+
) -> List[Optional[PassResult]]:
363+
passes = get_passes_in_default_order()
364+
pass_filter = create_cadence_pass_filter(opt_level)
365+
filtered_passes = [
366+
# pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
367+
filtered_pass()
368+
# pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
369+
for filtered_pass in list(filter(pass_filter, passes))
370+
]
371+
return filtered_passes

0 commit comments

Comments
 (0)