# pyre-strict

- import types
- from typing import List, Tuple
+ from __future__ import annotations
+
+ from typing import Any

import torch
import torch.utils._pytree as pytree
-
- from executorch.backends.compile_spec_schema import CompileSpec
- from executorch.exir.graph_module import _get_submodule
- from executorch.exir.tracer import Value
- from torch._export.exported_program import ExportedProgram
from torch._functorch.eager_transforms import (
    _unwrap_all_tensors_from_functional,
    _wrap_all_tensors_to_functional,
)
from torch._ops import HigherOrderOperator
- from torch._subclasses import FakeTensor
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    get_proxy_slot,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
- from torch.fx.passes.utils.fuser_utils import (
-     erase_nodes,
-     fuse_as_graphmodule,
-     insert_subgm,
-     legalize_graph,
-     NodeList,
-     topo_sort,
- )
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _pop_mode_temporarily,
)
-
from torch.utils._pytree import tree_flatten


- class LoweredBackendModule(torch.nn.Module):
-     """
-     A subclass of nn.Module that is generated for modules containing
-     delegated functions. This is can be created by calling `to_backend`.
-
-     Private Attributes:
-         * **backend_id**: The backend's name
-         * **processed_bytes**: The delegate blobs created from backend.preprocess
-         * **compile_specs**: A list of backend-specific objects with static
-             metadata to configure the "compilation" process.
-         * **original_module**: The original EXIR module
-     """
-
-     _backend_id: str
-     _processed_bytes: bytes
-     _compile_specs: List[CompileSpec]
-     _original_module: ExportedProgram
-
-     def __init__(
-         self,
-         edge_program: ExportedProgram,
-         backend_id: str,
-         processed_bytes: bytes,
-         compile_specs: List[CompileSpec],
-     ) -> None:
-         super().__init__()
-         self._original_module = edge_program
-         self._backend_id = backend_id
-         self._processed_bytes = processed_bytes
-         self._compile_specs = compile_specs
-
-     @property
-     def backend_id(self) -> str:
-         return self._backend_id
-
-     @property
-     def processed_bytes(self) -> bytes:
-         return self._processed_bytes
-
-     @property
-     def compile_specs(self) -> List[CompileSpec]:
-         return self._compile_specs
-
-     @property
-     def original_module(self) -> ExportedProgram:
-         return self._original_module
-
-     # Used to patch each delegated function with a call_delegate call
-     # @staticmethod
-     def forward(
-         self,
-         *args: Value,
-         **kwargs: Tuple[Value, ...],
-     ) -> Value:
-         return executorch_call_delegate(self, *args)
-
-
executorch_call_delegate = HigherOrderOperator(
    "executorch_call_delegate", _deprecated_global_ns=True
)
@@ -108,6 +37,7 @@ def forward(
# pyre-ignore
executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)

+ LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"

# pyre-ignore
def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
@@ -117,7 +47,7 @@ def _unwrap_proxy(e):
            return e
        return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)

-     if not isinstance(lowered_module, LoweredBackendModule):
+     if not is_lowered_module(lowered_module):
        raise ValueError(
            "executorch_call_delegate()'s first argument must be a LoweredBackendModule"
        )
@@ -235,8 +165,18 @@ def call_delegate_functionalize(interpreter, lowered_module, *args):
    return _wrap_all_tensors_to_functional(res, level=interpreter.level())


+ # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre
+ def is_lowered_module(obj: Any) -> bool:
+     """
+     This function is added to avoid using isinstance(obj, LoweredBackendModule) as it will import LoweredBackendModule, which may cause a circular import.
+     """
+     return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE
+
+
def get_lowered_module_name(
-     root: torch.nn.Module, lowered_module: LoweredBackendModule
+     root: torch.nn.Module,
+     # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
+     lowered_module: LOWERED_BACKEND_MODULE_TYPE,  # noqa
) -> str:
    """
    Adds the given lowered_module into the given root module and returns the
@@ -254,110 +194,3 @@ def get_lowered_module_name(

    root.add_module(qualname, lowered_module)
    return qualname
-
-
- # TODO(zhxchen17) Try ExportPass
- def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
-     for node in reversed(gm.graph.nodes):
-         if node.op == "output":
-             with gm.graph.inserting_before(node):
-                 assert len(node.args) == 1
-                 outputs = node.args[0]
-                 if isinstance(outputs, torch.fx.Node):
-                     val = outputs.meta.get("val")
-                     if isinstance(val, list):
-                         # If a list is returned, in some cases it is represented as a
-                         # singular node, like `split_copy_tensor` but EXIR will return a
-                         # opened-up list like `[getitem1, getitem2]`
-                         outputs = [
-                             torch.fx.Proxy(outputs)[i].node for i in range(len(val))
-                         ]
-                 returns, out_spec = pytree.tree_flatten(outputs)
-                 node.args = (returns,)
-             return
-
-
- def create_submodule_from_nodes(
-     gm: torch.fx.GraphModule,
-     node_list: NodeList,
-     tag: str,
-     skip_legalize_graph: bool = False,
- ) -> Tuple[torch.fx.GraphModule, torch.fx.Node]:
-     """
-     Modifies the given graph module in-place to separate out the given nodes
-     into a submodule. The given node_list should form a fully connected
-     subgraph.
-
-     Args:
-         gm: The graph module that we want to partition
-         node_list: A list of nodes that belong in the partition
-
-     Returns:
-         The submodule that has been partitioned, the call_module node in the
-         toplevel graph module calling the submodule
-     """
-     sorted_nodes = topo_sort(node_list)
-
-     submodule_name = "fused_" + tag
-     sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
-         gm, sorted_nodes, submodule_name
-     )
-
-     _fixup_output_node(sub_gm)
-
-     gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
-     if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
-         # If the original output is a single tensor, it has been
-         # pytree.tree_flatten-ed to be a singleton list, so we want to replace
-         # all uses with a getitem call to the 0th index of the result
-         for node in gm.graph.nodes:
-             if node.op == "call_module":
-                 with gm.graph.inserting_after(node):
-                     proxy_out = torch.fx.Proxy(node)[0].node  # type: ignore[index]
-                     node.replace_all_uses_with(proxy_out, propagate_meta=True)
-                     # Reset the args since it was overwritten in the previous line
-                     proxy_out.args = (node, 0)
-
-     erase_nodes(gm, sorted_nodes)
-
-     # Topological sort original gm with newly created sub_gm
-     # TODO: T153794167 Get rid of support for skipping legalize graph in create_submodule_from_nodes
-     # once we transition to using fuse_by_partitions.
-     if not skip_legalize_graph:
-         legalize_graph(gm)
-
-     # Get the call_module node
-     submodule_node = None
-     for node in gm.graph.nodes:
-         if node.op == "call_module" and node.target == submodule_name:
-             submodule_node = node
-         elif node.op == "call_module":
-             raise RuntimeError(
-                 f"The submodule created with nodes {node_list} did not form \
-                 one fully contained subgraph. Check that these nodes form a \
-                 fully contained graph. Partitioned graph: {gm.graph}."
-             )
-
-     assert (
-         submodule_node is not None
-     ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"
-
-     return sub_gm, submodule_node
-
-
- def get_lowered_submodules(
-     graph_module: torch.fx.GraphModule,
- ) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
-     """
-     Returns a list of lowered modules that are in the given graph (does not look
-     into submodules). Specifically, the returned value is a list containing a
-     tuple of (name of the lowered module that's stored in the graph module, the
-     lowered module itself, and the fx node that called this lowered module).
-     """
-     lowered_submodules = []
-     for node in graph_module.graph.nodes:
-         if node.op == "call_function" and node.target == executorch_call_delegate:
-             name, module, node = _get_submodule(graph_module, node, 0)
-             assert isinstance(module, LoweredBackendModule)
-             lowered_submodules.append((name, module, node))
-     return lowered_submodules
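
For reference, a minimal self-contained sketch of the name-based check that this diff substitutes for the isinstance() test; the LoweredBackendModule class defined below is a hypothetical stand-in used only to exercise the check, not the real ExecuTorch class:

from typing import Any

LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"


def is_lowered_module(obj: Any) -> bool:
    # Compare the class name instead of calling isinstance(), so this module
    # never needs to import LoweredBackendModule (avoiding the circular import).
    return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE


class LoweredBackendModule:  # hypothetical stand-in for the real class
    pass


assert is_lowered_module(LoweredBackendModule())
assert not is_lowered_module(object())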