6
6
7
7
# pyre-strict
8
8
9
- from typing import Any , cast , Dict , Sequence , Tuple
9
+ from typing import Any , cast , Dict , List , Optional , Sequence , Tuple , Type
10
10
11
11
import torch
12
+ import torch .fx
13
+ import torch .utils ._pytree as pytree
14
+ from executorch .backends .cadence .aot .pass_utils import (
15
+ CadencePassAttribute ,
16
+ create_cadence_pass_filter ,
17
+ register_cadence_pass ,
18
+ )
12
19
from executorch .backends .cadence .aot .utils import get_edge_overload_packet
20
+ from executorch .backends .transforms .remove_clone_ops import RemoveCloneOpsTransform
13
21
from executorch .exir .dialects ._ops import ops as exir_ops
14
22
from executorch .exir .pass_base import ExportPass , NodeMetadata , PassResult , ProxyValue
23
+ from executorch .exir .pass_manager import PassManager , PassType
15
24
from executorch .exir .passes import dead_code_elimination_pass
25
+ from executorch .exir .passes .scalar_to_tensor_pass import ScalarToTensorPass
16
26
from executorch .exir .passes .spec_prop_pass import SpecPropPass
17
27
from torch ._subclasses import FakeTensor
18
28
from torch .utils ._pytree import tree_map_only
19
29
30
+
31
+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
32
+ class InitializePipeline (ExportPass ):
33
+ """
34
+ Initialize the Jarvis pipeline. This should invariably be the first pass to
35
+ run.
36
+ """
37
+
38
+ def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
39
+ dead_code_elimination_pass (graph_module )
40
+ result = SpecPropPass ()(graph_module )
41
+ assert result is not None
42
+ return result
43
+
44
+
45
+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
46
+ class FinalizePipeline (ExportPass ):
47
+ """
48
+ The final cleanup pass after running the Jarvis pipeline.
49
+ """
50
+
51
+ def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
52
+ finalize_passes : List [PassType ] = [
53
+ ScalarToTensorPass (),
54
+ SpecPropPass (),
55
+ ]
56
+ result = PassManager (passes = finalize_passes )(graph_module )
57
+ dead_code_elimination_pass (result .graph_module )
58
+ return result
59
+
60
+
20
61
# Similar to what's done in executorch/exir/pass_base.py
21
62
Argument = Any # pyre-ignore
22
63
@@ -131,7 +172,7 @@ def call_operator(
131
172
)
132
173
133
174
134
- class RemoveZeroSizedCatArgsPass (ExportPass ):
175
+ class RemoveZeroSizedCatArgsPass (ExportPass ): # is this the latest?
135
176
def call_operator (
136
177
self ,
137
178
op , # pyre-ignore
@@ -255,20 +296,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
255
296
return result
256
297
257
298
258
- class InitializePipeline (ExportPass ):
259
- """
260
- Initialize the Jarvis pipeline. This should invariably be the first pass to
261
- run.
262
- """
263
-
264
- def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
265
- dead_code_elimination_pass (graph_module )
266
- result = SpecPropPass ()(graph_module )
267
- assert result is not None
268
- return result
269
-
270
-
271
- class ReplaceSafeSoftmaxWithSoftmax (ExportPass ):
299
+ class ReplaceSafeSoftmaxWithSoftmax (ExportPass ): # keep
272
300
"""
273
301
Replace _safe_softmax with _softmax
274
302
"""
@@ -292,3 +320,33 @@ def call_operator(
292
320
kwargs ,
293
321
meta ,
294
322
)
323
+
324
+
325
+ def get_passes_in_default_order () -> List [Type [PassType ]]:
326
+ passes = [
327
+ InitializePipeline ,
328
+ RemoveZeroSizedCatArgsPass ,
329
+ ReplaceLogicalNotBooleanWhereWithWherePass ,
330
+ ReplaceScalarTensorWithFullPass ,
331
+ RemoveCloneOpsTransform ,
332
+ RemoveNopExpandOpPass ,
333
+ ReplaceSqueezeAndUnsqueezeWithViewPass ,
334
+ ReplacePT2QuantWithCadenceQuantPass ,
335
+ ReplacePT2DequantWithCadenceDequantPass ,
336
+ # TODO: add the rest of the passes here.
337
+ ]
338
+ return pytree .tree_flatten (passes )[0 ]
339
+
340
+
341
+ def get_cadence_passes (
342
+ opt_level : int ,
343
+ ) -> List [Optional [PassResult ]]:
344
+ passes = get_passes_in_default_order ()
345
+ pass_filter = create_cadence_pass_filter (opt_level )
346
+ filtered_passes = [
347
+ # pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
348
+ filtered_pass ()
349
+ # pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
350
+ for filtered_pass in list (filter (pass_filter , passes ))
351
+ ]
352
+ return filtered_passes
0 commit comments