# pyre-strict

-from typing import Any, cast, Dict, Sequence, Tuple
+from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type

import torch
+import torch.fx
+import torch.utils._pytree as pytree
+from executorch.backends.cadence.aot.pass_utils import (
+    CadencePassAttribute,
+    create_cadence_pass_filter,
+    register_cadence_pass,
+)
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
+from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
+from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes import dead_code_elimination_pass
+from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch._subclasses import FakeTensor
from torch.utils._pytree import tree_map_only
+
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class RemoveCloneOpsTransformImported(ExportPass):
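+    """
+    Wrap the imported RemoveCloneOpsTransform in an ExportPass, then eliminate
+    any dead code left behind after clone removal.
+    """
+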
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        finalize_passes: List[PassType] = [
+            RemoveCloneOpsTransform(),
+        ]
+        result = PassManager(passes=finalize_passes)(graph_module)
+        dead_code_elimination_pass(result.graph_module)
+        return result
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class InitializePipeline(ExportPass):
+    """
+    Initialize the Jarvis pipeline. This should invariably be the first pass to
+    run.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
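+        # Prune dead code first, then propagate tensor specs over the cleaned graph.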
+        dead_code_elimination_pass(graph_module)
+        result = SpecPropPass()(graph_module)
+        assert result is not None
+        return result
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class FinalizePipeline(ExportPass):
+    """
+    The final cleanup pass after running the Jarvis pipeline.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        finalize_passes: List[PassType] = [
+            ScalarToTensorPass(),
+            SpecPropPass(),
+        ]
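+        # Run the cleanup passes in order, then eliminate any dead code they expose.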
+        result = PassManager(passes=finalize_passes)(graph_module)
+        dead_code_elimination_pass(result.graph_module)
+        return result
+
+
# Similar to what's done in executorch/exir/pass_base.py
Argument = Any  # pyre-ignore


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
    """
    Replace the pt2 quantization ops with custom cadence quantization ops.
@@ -44,6 +97,7 @@ def call_operator(
        )


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
    """
    Replace the pt2 dequantization ops with custom cadence dequantization ops.
@@ -67,6 +121,7 @@ def call_operator(
        )


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceScalarTensorWithFullPass(ExportPass):
    """
    aten.scalar_tensor can be replaced by aten.full with a shape of [1].
@@ -96,6 +151,7 @@ def call_operator(
        )


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
    """
    When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
@@ -131,7 +187,8 @@ def call_operator(
        )


-class RemoveZeroSizedCatArgsPass(ExportPass):
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class RemoveZeroSizedCatArgsPass(ExportPass):
    def call_operator(
        self,
        op,  # pyre-ignore
@@ -176,6 +233,7 @@ def call_operator(
        return super().call_operator(op, args, kwargs, meta)


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveNopExpandOpPass(ExportPass):
    """
    For an expand op, if the operator shape matches the expand shape, then the
@@ -205,6 +263,7 @@ def call_operator(
        return super().call_operator(op, args, kwargs, meta)


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
    """
    A where op with a logical_not and a boolean tensor can be replaced
@@ -255,20 +314,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        return result


-class InitializePipeline(ExportPass):
-    """
-    Initialize the Jarvis pipeline. This should invariably be the first pass to
-    run.
-    """
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        dead_code_elimination_pass(graph_module)
-        result = SpecPropPass()(graph_module)
-        assert result is not None
-        return result
-
-
-class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
    """
    Replace _safe_softmax with _softmax
    """
@@ -292,3 +339,33 @@ def call_operator(
            kwargs,
            meta,
        )
+
+
+def get_passes_in_default_order() -> List[Type[PassType]]:
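+    # The list below is the default execution order for the Cadence pipeline.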
+    passes = [
+        InitializePipeline,
+        RemoveZeroSizedCatArgsPass,
+        ReplaceLogicalNotBooleanWhereWithWherePass,
+        ReplaceScalarTensorWithFullPass,
+        RemoveCloneOpsTransformImported,
+        RemoveNopExpandOpPass,
+        ReplaceSqueezeAndUnsqueezeWithViewPass,
+        ReplacePT2QuantWithCadenceQuantPass,
+        ReplacePT2DequantWithCadenceDequantPass,
+        # TODO: add the rest of the passes here.
+    ]
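+    # Flatten with pytree (a no-op for a flat list; presumably future-proofing
+    # for nested groups of passes).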
+    return pytree.tree_flatten(passes)[0]
+
+
+def get_cadence_passes(
+    opt_level: int,
+) -> List[Optional[PassResult]]:
+    passes = get_passes_in_default_order()
+    pass_filter = create_cadence_pass_filter(opt_level)
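+    # Keep only the passes enabled at this opt_level, then instantiate each one.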
+    filtered_passes = [
+        # pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
+        filtered_pass()
+        # pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
+        for filtered_pass in list(filter(pass_filter, passes))
+    ]
+    return filtered_passes
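+
+
+# Minimal usage sketch (illustrative only; assumes an exported `graph_module`):
+#
+#     for cadence_pass in get_cadence_passes(opt_level=0):
+#         cadence_pass(graph_module)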