
Commit 9e8c542

cccclai authored and facebook-github-bot committed
Separate lowered backend module into a separate file
Summary: Previously we defined `LoweredBackendModule` and implemented `executorch_call_delegate` in `delegate.py`. In preparation for emitting `LoweredBackendModule` in D47803806, we may need to separate them into two files because of circular dependency issues.

Reviewed By: angelayi

Differential Revision: D47809768

fbshipit-source-id: 18101f5be3819935ecd7ece2daa71670121a81fb
1 parent d601582 commit 9e8c542

21 files changed: 272 additions, 231 deletions
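
The crux of the change: once `LoweredBackendModule` moves out of `delegate.py` into `lowered_backend_module.py` (which itself depends on `:delegate`), `delegate.py` can no longer import the class for `isinstance` checks. The diffs below replace those checks with a name-based test. A minimal, standalone sketch of that pattern, lifted from the `exir/delegate.py` changes for readability:

# Sketch of the circular-import workaround added to exir/delegate.py:
# compare the class name against a string constant instead of importing
# LoweredBackendModule just to run an isinstance() check.
from typing import Any

LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"


def is_lowered_module(obj: Any) -> bool:
    # True when obj's class is named "LoweredBackendModule", without importing it.
    return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE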

backends/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -103,5 +103,6 @@ runtime.python_library(
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/exir:lowered_backend_module",
     ],
 )

backends/backend_api.py

Lines changed: 4 additions & 4 deletions
@@ -17,13 +17,13 @@
     MultiMethodExirExportedProgram,
 )

-from executorch.exir.delegate import (
+from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
+
+from executorch.exir.graph_module import get_control_flow_submodules
+from executorch.exir.lowered_backend_module import (
     create_submodule_from_nodes,
-    executorch_call_delegate,
-    get_lowered_module_name,
     LoweredBackendModule,
 )
-from executorch.exir.graph_module import get_control_flow_submodules
 from executorch.exir.pass_base import ExportPass
 from torch._export.exported_program import ExportedProgram

backends/test/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -100,6 +100,7 @@ python_unittest(
         "//executorch/exir:delegate",
         "//executorch/exir:graph_module",
         "//executorch/exir:lib",
+        "//executorch/exir:lowered_backend_module",
         "//executorch/exir:print_program",
         "//executorch/exir:schema",
         "//executorch/exir/dialects:lib",
@@ -147,6 +148,7 @@ python_unittest(
         "//executorch/exir:delegate",
         "//executorch/exir:graph_module",
         "//executorch/exir:lib",
+        "//executorch/exir:lowered_backend_module",
         "//executorch/kernels/portable:custom_ops_generated_lib",
         "//executorch/kernels/quantized:custom_ops_generated_lib",
         "//executorch/runtime/executor/test:test_backend_compiler_lib",

backends/test/test_backends.py

Lines changed: 2 additions & 1 deletion
@@ -25,9 +25,10 @@
 from executorch.backends.test.qnn_backend_demo import QnnBackend
 from executorch.exir import EdgeCompileConfig, multi_method_program_to_executorch

-from executorch.exir.delegate import executorch_call_delegate, get_lowered_submodules
+from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
+from executorch.exir.lowered_backend_module import get_lowered_submodules
 from executorch.exir.print_program import print_program
 from executorch.exir.schema import (
     BackendDelegate,

backends/test/test_backends_nested.py

Lines changed: 2 additions & 2 deletions
@@ -17,10 +17,10 @@
     AddOperatorSupport,
     MatmulOperatorSupport,
 )
+from executorch.exir.delegate import executorch_call_delegate

-from executorch.exir.delegate import executorch_call_delegate, get_lowered_submodules
 from executorch.exir.graph_module import _get_submodule, get_control_flow_submodules
-
+from executorch.exir.lowered_backend_module import get_lowered_submodules
 from functorch.experimental import control_flow
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

backends/utils.py

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 from typing import Iterable, List, Tuple

 import torch
-
-from executorch.exir.delegate import create_submodule_from_nodes
 from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.lowered_backend_module import create_submodule_from_nodes
 from torch.fx.passes.utils.source_matcher_utils import SourcePartition

 T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default

exir/TARGETS

Lines changed: 11 additions & 0 deletions
@@ -108,6 +108,17 @@ python_library(
         "delegate.py",
     ],
     deps = [
+        "//caffe2:torch",
+    ],
+)
+
+python_library(
+    name = "lowered_backend_module",
+    srcs = [
+        "lowered_backend_module.py",
+    ],
+    deps = [
+        ":delegate",
         ":graph_module",
         ":tracer",
         "//caffe2:torch",

exir/delegate.py

Lines changed: 16 additions & 183 deletions
@@ -1,101 +1,30 @@
 # pyre-strict

-import types
-from typing import List, Tuple
+from __future__ import annotations
+
+from typing import Any

 import torch
 import torch.utils._pytree as pytree
-
-from executorch.backends.compile_spec_schema import CompileSpec
-from executorch.exir.graph_module import _get_submodule
-from executorch.exir.tracer import Value
-from torch._export.exported_program import ExportedProgram
 from torch._functorch.eager_transforms import (
     _unwrap_all_tensors_from_functional,
     _wrap_all_tensors_to_functional,
 )
 from torch._ops import HigherOrderOperator
-from torch._subclasses import FakeTensor
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     disable_proxy_modes_tracing,
     get_proxy_slot,
     ProxyTorchDispatchMode,
     track_tensor_tree,
 )
-from torch.fx.passes.utils.fuser_utils import (
-    erase_nodes,
-    fuse_as_graphmodule,
-    insert_subgm,
-    legalize_graph,
-    NodeList,
-    topo_sort,
-)
 from torch.utils._python_dispatch import (
     _get_current_dispatch_mode,
     _pop_mode_temporarily,
 )
-
 from torch.utils._pytree import tree_flatten


-class LoweredBackendModule(torch.nn.Module):
-    """
-    A subclass of nn.Module that is generated for modules containing
-    delegated functions. This is can be created by calling `to_backend`.
-
-    Private Attributes:
-      * **backend_id**: The backend's name
-      * **processed_bytes**: The delegate blobs created from backend.preprocess
-      * **compile_specs**: A list of backend-specific objects with static
-        metadata to configure the "compilation" process.
-      * **original_module**: The original EXIR module
-    """
-
-    _backend_id: str
-    _processed_bytes: bytes
-    _compile_specs: List[CompileSpec]
-    _original_module: ExportedProgram
-
-    def __init__(
-        self,
-        edge_program: ExportedProgram,
-        backend_id: str,
-        processed_bytes: bytes,
-        compile_specs: List[CompileSpec],
-    ) -> None:
-        super().__init__()
-        self._original_module = edge_program
-        self._backend_id = backend_id
-        self._processed_bytes = processed_bytes
-        self._compile_specs = compile_specs
-
-    @property
-    def backend_id(self) -> str:
-        return self._backend_id
-
-    @property
-    def processed_bytes(self) -> bytes:
-        return self._processed_bytes
-
-    @property
-    def compile_specs(self) -> List[CompileSpec]:
-        return self._compile_specs
-
-    @property
-    def original_module(self) -> ExportedProgram:
-        return self._original_module
-
-    # Used to patch each delegated function with a call_delegate call
-    # @staticmethod
-    def forward(
-        self,
-        *args: Value,
-        **kwargs: Tuple[Value, ...],
-    ) -> Value:
-        return executorch_call_delegate(self, *args)
-
-
 executorch_call_delegate = HigherOrderOperator(
     "executorch_call_delegate", _deprecated_global_ns=True
 )
@@ -108,6 +37,7 @@
 # pyre-ignore
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)

+LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"

 # pyre-ignore
 def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
@@ -117,7 +47,7 @@ def _unwrap_proxy(e):
             return e
         return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)

-    if not isinstance(lowered_module, LoweredBackendModule):
+    if not is_lowered_module(lowered_module):
         raise ValueError(
             "executorch_call_delegate()'s first argument must be a LoweredBackendModule"
         )
@@ -235,8 +165,18 @@ def call_delegate_functionalize(interpreter, lowered_module, *args):
     return _wrap_all_tensors_to_functional(res, level=interpreter.level())


+# pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.
+def is_lowered_module(obj: Any) -> bool:
+    """
+    This function is added to avoid using isinstance(obj, LoweredBackendModule) as it will import LoweredBackendModule, which may cause a circular import.
+    """
+    return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE
+
+
 def get_lowered_module_name(
-    root: torch.nn.Module, lowered_module: LoweredBackendModule
+    root: torch.nn.Module,
+    # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
+    lowered_module: LOWERED_BACKEND_MODULE_TYPE,  # noqa
 ) -> str:
     """
     Adds the given lowered_module into the given root module and returns the
@@ -254,110 +194,3 @@ def get_lowered_module_name(

     root.add_module(qualname, lowered_module)
     return qualname
-
-
-# TODO(zhxchen17) Try ExportPass
-def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
-    for node in reversed(gm.graph.nodes):
-        if node.op == "output":
-            with gm.graph.inserting_before(node):
-                assert len(node.args) == 1
-                outputs = node.args[0]
-                if isinstance(outputs, torch.fx.Node):
-                    val = outputs.meta.get("val")
-                    if isinstance(val, list):
-                        # If a list is returned, in some cases it is represented as a
-                        # singular node, like `split_copy_tensor` but EXIR will return a
-                        # opened-up list like `[getitem1, getitem2]`
-                        outputs = [
-                            torch.fx.Proxy(outputs)[i].node for i in range(len(val))
-                        ]
-                returns, out_spec = pytree.tree_flatten(outputs)
-                node.args = (returns,)
-            return
-
-
-def create_submodule_from_nodes(
-    gm: torch.fx.GraphModule,
-    node_list: NodeList,
-    tag: str,
-    skip_legalize_graph: bool = False,
-) -> Tuple[torch.fx.GraphModule, torch.fx.Node]:
-    """
-    Modifies the given graph module in-place to separate out the given nodes
-    into a submodule. The given node_list should form a fully connected
-    subgraph.
-
-    Args:
-        gm: The graph module that we want to partition
-        node_list: A list of nodes that belong in the partition
-
-    Returns:
-        The submodule that has been partitioned, the call_module node in the
-        toplevel graph module calling the submodule
-    """
-    sorted_nodes = topo_sort(node_list)
-
-    submodule_name = "fused_" + tag
-    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
-        gm, sorted_nodes, submodule_name
-    )
-
-    _fixup_output_node(sub_gm)
-
-    gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
-    if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
-        # If the original output is a single tensor, it has been
-        # pytree.tree_flatten-ed to be a singleton list, so we want to replace
-        # all uses with a getitem call to the 0th index of the result
-        for node in gm.graph.nodes:
-            if node.op == "call_module":
-                with gm.graph.inserting_after(node):
-                    proxy_out = torch.fx.Proxy(node)[0].node  # type: ignore[index]
-                    node.replace_all_uses_with(proxy_out, propagate_meta=True)
-                    # Reset the args since it was overwritten in the previous line
-                    proxy_out.args = (node, 0)
-
-    erase_nodes(gm, sorted_nodes)
-
-    # Topological sort original gm with newly created sub_gm
-    # TODO : T153794167 Get rid of support for skipping legalize graph in create_submodule_from_nodes
-    # once we transition to using fuse_by_partitions.
-    if not skip_legalize_graph:
-        legalize_graph(gm)
-
-    # Get the call_module node
-    submodule_node = None
-    for node in gm.graph.nodes:
-        if node.op == "call_module" and node.target == submodule_name:
-            submodule_node = node
-        elif node.op == "call_module":
-            raise RuntimeError(
-                f"The submodule created with nodes {node_list} did not form \
-                one fully contained subgraph. Check that these nodes form a \
-                fully contained graph. Partitioned graph: {gm.graph}."
-            )
-
-    assert (
-        submodule_node is not None
-    ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"
-
-    return sub_gm, submodule_node
-
-
-def get_lowered_submodules(
-    graph_module: torch.fx.GraphModule,
-) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
-    """
-    Returns a list of lowered modules that are in the given graph (does not look
-    into submodules). Specifically, the returned value is a list containing a
-    tuple of (name of the lowered module that's stored in the graph module, the
-    lowered module itself, and the fx node that called this lowered module).
-    """
-    lowered_submodules = []
-    for node in graph_module.graph.nodes:
-        if node.op == "call_function" and node.target == executorch_call_delegate:
-            name, module, node = _get_submodule(graph_module, node, 0)
-            assert isinstance(module, LoweredBackendModule)
-            lowered_submodules.append((name, module, node))
-    return lowered_submodules
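
Taken together, the `exir/delegate.py` changes shrink that file to the `executorch_call_delegate` operator plus the name-based helpers, while the partitioning and lookup utilities move to the new module. A hedged sketch of the resulting import surface, based only on the import statements updated in this commit (the new `lowered_backend_module.py` itself is among the changed files not shown in this view):

# Where callers import these symbols from after this commit (per the updated call sites):
from executorch.exir.delegate import (
    executorch_call_delegate,
    get_lowered_module_name,
    is_lowered_module,
)
from executorch.exir.lowered_backend_module import (
    create_submodule_from_nodes,
    get_lowered_submodules,
    LoweredBackendModule,
)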

exir/emit/_emitter.py

Lines changed: 6 additions & 5 deletions
@@ -35,7 +35,7 @@
 import torch.fx
 from executorch.exir import delegate
 from executorch.exir.common import add_cursor_to_graph
-from executorch.exir.delegate import LoweredBackendModule
+from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
 from executorch.exir.dialects.backend._ops import BackendOpOverload
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.error import ExportError, ExportErrorType, InternalError
@@ -877,7 +877,7 @@ def _add_debug_handle(self, emitter_id: int, target: _Target) -> None:
         """
         # If it's a delegate call, collect the list of debug handles that are consumed by this
         # delegate call and store it in the debug handle map.
-        if target == delegate.executorch_call_delegate:
+        if target == executorch_call_delegate:
             debug_handle_list = []
             for node in self.node.graph.nodes:
                 if (
@@ -912,7 +912,8 @@ def _emit_argument(

     def _emit_delegate(
         self,
-        lowered_module: LoweredBackendModule,
+        # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
+        lowered_module: "LoweredBackendModule",  # noqa
         args: Tuple[_Argument, ...],
         kwargs: Dict[str, _Argument],
     ) -> _EmitterValue:
@@ -1201,9 +1202,9 @@ def call_function(
         elif target is torch.ops.map_impl:
             return self._emit_control_flow(target, args, kwargs)

-        elif target == delegate.executorch_call_delegate:
+        elif target == executorch_call_delegate:
             lowered_module = args[0]
-            assert isinstance(lowered_module, LoweredBackendModule)
+            assert is_lowered_module(lowered_module)
             v = self._emit_delegate(lowered_module, args[1:], kwargs)
             self._add_debug_handle(len(self.chain.instructions) - 1, target)
             return v
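
For completeness, a small sketch of the emitter-side pattern after this change; `maybe_emit_delegate` and `emit_delegate` are hypothetical stand-ins for the private `Emitter` methods, used only to illustrate the dispatch shown in the diff above:

# Recognize a delegate call by its operator and validate the lowered module
# with the name-based check, so the emitter never imports LoweredBackendModule.
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module


def maybe_emit_delegate(target, args, kwargs, emit_delegate):
    if target == executorch_call_delegate:
        lowered_module = args[0]
        assert is_lowered_module(lowered_module)
        return emit_delegate(lowered_module, args[1:], kwargs)
    return None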
