Skip to content

Commit d5459d6

Browse files
Zonglin Pengfacebook-github-bot
authored andcommitted
migrate pass utils
Differential Revision: D65447532
1 parent 437168e commit d5459d6

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

backends/cadence/aot/TARGETS

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,21 @@ python_library(
6262
],
6363
)
6464

65+
python_library(
66+
name = "pass_utils",
67+
srcs = [
68+
"pass_utils.py",
69+
],
70+
deps = [
71+
":utils",
72+
"//caffe2:torch",
73+
"//executorch/exir:pass_base",
74+
"//executorch/exir/dialects:lib",
75+
"//executorch/exir/passes:lib",
76+
"//executorch/exir/passes:spec_prop_pass",
77+
],
78+
)
79+
6580
python_library(
6681
name = "ops_registrations",
6782
srcs = [

backends/cadence/aot/pass_utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from dataclasses import dataclass
6+
from typing import Callable, Optional, Set, Union
7+
8+
import torch
9+
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
10+
11+
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
12+
13+
from executorch.exir.pass_base import ExportPass
14+
from torch._ops import OpOverloadPacket
15+
16+
17+
# Is an overlap in tensor lifetime and storage allowed at the current opt level?
18+
# We allow overlap at opt level >= 2.
19+
def allow_lifetime_and_storage_overlap(opt_level: int) -> bool:
20+
return opt_level >= 2
21+
22+
23+
# A dataclass that stores the attributes of an ExportPass.
24+
@dataclass
25+
class CadencePassAttribute:
26+
opt_level: Optional[int] = None
27+
debug_pass: bool = False
28+
29+
30+
# A dictionary that maps an ExportPass to its attributes.
31+
_ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
32+
33+
34+
def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
35+
return _ALL_CADENCE_PASSES[p]
36+
37+
38+
# A decorator that registers a pass.
39+
def register_cadence_pass(
40+
pass_attribute: CadencePassAttribute,
41+
) -> Callable[[ExportPass], ExportPass]:
42+
def wrapper(cls: ExportPass) -> ExportPass:
43+
_ALL_CADENCE_PASSES[cls] = pass_attribute
44+
return cls
45+
46+
return wrapper
47+
48+
49+
def get_all_available_cadence_passes() -> Set[ExportPass]:
50+
return set(_ALL_CADENCE_PASSES.keys())
51+
52+
53+
# Create a new filter to filter out relevant passes from all Jarvis passes.
54+
def create_cadence_pass_filter(
55+
opt_level: int, debug: bool = False
56+
) -> Callable[[ExportPass], bool]:
57+
def _filter(p: ExportPass) -> bool:
58+
pass_attribute = get_cadence_pass_attribute(p)
59+
return (
60+
pass_attribute.opt_level is not None
61+
and pass_attribute.opt_level <= opt_level
62+
and (not pass_attribute.debug_pass or debug)
63+
)
64+
65+
return _filter
66+
67+
68+
# Return the overload packet for the edge or torch op.
69+
def get_overload_packet(
70+
op: Union[Callable[..., str], str],
71+
) -> Union[OpOverloadPacket, EdgeOpOverloadPacket, None]:
72+
return (
73+
get_edge_overload_packet(op)
74+
if isinstance(op, EdgeOpOverload)
75+
else getattr(op, "overloadpacket", None)
76+
)
77+
78+
79+
# Get the list of node names in a graph module (only for "call_function" ops and
80+
# EdgeOpOverload targets). This should be used only after to_edge is called.
81+
def get_node_names_list_from_gm(
82+
graph_module: torch.fx.GraphModule,
83+
) -> list[torch.fx.Node]:
84+
graph_nodes = []
85+
for node in graph_module.graph.nodes:
86+
if node.op != "call_function":
87+
continue
88+
if not isinstance(node.target, EdgeOpOverload):
89+
continue
90+
graph_nodes.append(node.name)
91+
return graph_nodes

0 commit comments

Comments
 (0)