12
12
from executorch .backends .partitioner import Partitioner , TPartitioner
13
13
from executorch .backends .utils import is_identical_graph
14
14
from executorch .exir import (
15
- attach_export_graph_metadata ,
16
15
CallSpec ,
17
- ExirExportedProgram ,
18
16
ExportGraphSignature ,
19
- get_exir_meta ,
20
17
MultiMethodExirExportedProgram ,
21
- pytree ,
22
18
)
23
19
24
20
from executorch .exir .delegate import (
25
21
create_submodule_from_nodes ,
26
22
executorch_call_delegate ,
27
23
get_lowered_module_name ,
28
24
LoweredBackendModule ,
29
- patch_lowered_functions ,
30
25
)
31
- from executorch .exir .graph_module import get_control_flow_submodules
26
+ from executorch .exir .graph_module import (
27
+ attach_export_graph_metadata ,
28
+ ExirMetadata ,
29
+ get_control_flow_submodules ,
30
+ )
32
31
from executorch .exir .pass_base import ExportPass
32
+ from torch ._export .exported_program import ExportedProgram
33
33
34
34
35
35
@singledispatch
@@ -39,7 +39,7 @@ def to_backend(args):
39
39
40
40
def to_backend(
41
41
backend_id: str,
42
- edge_graph_module: torch.fx.GraphModule ,
42
+ edge_graph_module: ExportedProgram ,
43
43
compile_specs: List[CompileSpec],
44
44
) -> LoweredBackendModule:
45
45
@@ -58,23 +58,24 @@ def to_backend(
58
58
@to_backend .register
59
59
def _ (
60
60
backend_id : str ,
61
- edge_graph_module : torch . fx . GraphModule ,
61
+ edge_program : ExportedProgram ,
62
62
compile_specs : List [CompileSpec ],
63
63
) -> LoweredBackendModule :
64
64
"""
65
65
Add overloaded implementations for to_backend:
66
66
def to_backend(
67
67
backend_id: str,
68
- edge_graph_module: torch.fx.GraphModule ,
68
+ edge_program: ExportedProgram ,
69
69
compile_specs: List[CompileSpec],
70
70
) -> LoweredBackendModule:
71
- Requires the passed in Module in Edge dialect to be executed in the backend identified
72
- by backend_id. The forward method of the given edge_graph_module will be
73
- targeted for execution.
71
+ Requires the passed in exported program in Edge dialect to be executed in
72
+ the backend identified by backend_id. The forward method of the given
73
+ edge_graph_module will be targeted for execution.
74
74
75
75
Args:
76
76
backend_id: The backend identifier.
77
- edge_graph_module: A module in Edge dialect to target for lowering to the backend.
77
+ exported_program: An exported program in Edge dialect to target for
78
+ lowering to the backend.
78
79
compile_specs: A list of backend-specific objects with static
79
80
metadata to configure the "compilation" process (e.g. it could be
80
81
another dictionary itself).
@@ -83,7 +84,7 @@ def to_backend(
83
84
LoweredBackendModule: A Module that has been lowered to the target backend.
84
85
Internally, the lowered Module contains these special attributes:
85
86
backend_id (str: backend id), __processed_module__ (str: a compiled module)
86
- compile_spec, original_module (original exported module )
87
+ compile_spec, original_module (original exported program )
87
88
88
89
Raises:
89
90
NotImplementedError: The backend is not implemented (e.g. it was not found).
@@ -93,18 +94,17 @@ def to_backend(
93
94
# All backend implementation are final, so we don't need to consider nested subclasses.
94
95
for cls in BackendDetails .__subclasses__ ():
95
96
if backend_id == cls .__name__ :
96
- copied_graph_module = copy .deepcopy (edge_graph_module )
97
+ copied_edge_program = copy .deepcopy (edge_program )
97
98
processed_bytes = cls .preprocess (
98
- copied_graph_module ,
99
+ copied_edge_program ,
99
100
compile_specs ,
100
101
)
101
102
lowered_module = LoweredBackendModule (
102
- edge_graph_module ,
103
+ edge_program ,
103
104
backend_id ,
104
105
processed_bytes ,
105
106
compile_specs ,
106
107
)
107
- patch_lowered_functions (lowered_module )
108
108
return lowered_module
109
109
raise NotImplementedError (f"Backend { backend_id } was not found." )
110
110
@@ -156,9 +156,26 @@ def _partition_and_lower(
156
156
)
157
157
logging .debug (f"Partitioned graph module: { tagged_graph_module } " )
158
158
159
+ # TODO(T158558782): Update the metadata once we migrate to torch.export
160
+ submodule_program = ExportedProgram (
161
+ submodule ,
162
+ submodule .graph ,
163
+ ExportGraphSignature ([], [], [], [], {}, {}, {}, None ),
164
+ CallSpec (None , None ),
165
+ {},
166
+ {},
167
+ [],
168
+ )
169
+ meta = ExirMetadata (
170
+ in_spec = None ,
171
+ out_spec = None ,
172
+ update_spec = 0 ,
173
+ )
174
+ attach_export_graph_metadata (submodule_program .graph_module , meta )
175
+
159
176
lowered_submodule = to_backend (
160
177
delegation_spec .backend_id ,
161
- submodule ,
178
+ submodule_program ,
162
179
delegation_spec .compile_specs ,
163
180
)
164
181
@@ -199,22 +216,22 @@ def _partition_and_lower(
199
216
200
217
@to_backend .register
201
218
def _ (
202
- edge_graph_module : torch . fx . GraphModule ,
219
+ edge_program : ExportedProgram ,
203
220
partitioner : Type [TPartitioner ],
204
- ) -> torch . fx . GraphModule :
221
+ ) -> ExportedProgram :
205
222
"""
206
223
Add overloaded implementations for to_backend:
207
224
def to_backend(
208
- edge_graph_module: torch.fx.GraphModule ,
225
+ edge_program: ExportedProgram ,
209
226
partitioner: Type[TPartitioner],
210
- ) -> torch.fx.GraphModule
227
+ ) -> ExportedProgram:
211
228
212
229
Returns a semantically-equivalent program to the one given as input (represented
213
230
as a graph module in Edge dialect), but with portions of the program targeted for
214
231
delegation as determined by the partitioner.
215
232
216
233
Args:
217
- torch.fx.GraphModule : Program in Edge dialect.
234
+ ExportedProgram : Program in Edge dialect.
218
235
219
236
partitioner: An instance of the Partitioner class type, in charge with tagging
220
237
portions of the input program for delegation. A valid partitioner must have
@@ -224,8 +241,9 @@ def to_backend(
224
241
225
242
226
243
Returns:
227
- torch.fx.GraphModule : The input program, with some portions targeted for delegation.
244
+ ExportedProgram : The input program, with some portions targeted for delegation.
228
245
"""
246
+ edge_graph_module = edge_program .graph_module
229
247
copied_graph_module = copy .deepcopy (edge_graph_module )
230
248
# Call the partitioner on the given graph module
231
249
partitioner_instance : Partitioner = partitioner ()
@@ -249,7 +267,8 @@ def to_backend(
249
267
tagged_graph_module , partitioner_instance
250
268
)
251
269
252
- return tagged_graph_module
270
+ edge_program .graph_module = tagged_graph_module
271
+ return edge_program
253
272
254
273
255
274
def to_backend_multiple (
@@ -287,35 +306,18 @@ def to_backend_multiple(
287
306
+ "partitioner subclass, or a partitioner subclass."
288
307
)
289
308
290
- method_name_to_delegated_gm = {}
309
+ method_name_to_delegated_program = {}
291
310
for method_name , prog in multi_method_program .methods ().items ():
292
- gm = prog .graph_module
293
311
if isinstance (partitioner , dict ):
294
312
if method_name in partitioner :
295
- method_name_to_delegated_gm [method_name ] = to_backend (
296
- gm , partitioner [method_name ]
313
+ method_name_to_delegated_program [method_name ] = to_backend (
314
+ prog , partitioner [method_name ]
297
315
)
298
316
else :
299
- method_name_to_delegated_gm [method_name ] = gm
317
+ method_name_to_delegated_program [method_name ] = prog
300
318
else :
301
- method_name_to_delegated_gm [method_name ] = to_backend (gm , partitioner )
302
-
303
- def gm_to_program (gm : torch .fx .GraphModule ):
304
- ep = ExirExportedProgram (
305
- gm ,
306
- gm .graph ,
307
- ExportGraphSignature ([], [], [], [], {}, {}, {}, None ),
308
- CallSpec (None , None ),
309
- {},
310
- {},
311
- [],
312
- True ,
313
- )
314
- ep .graph_module .meta .update (gm .meta )
315
- attach_export_graph_metadata (ep .graph_module , get_exir_meta (gm ))
316
- return ep
319
+ method_name_to_delegated_program [method_name ] = to_backend (
320
+ prog , partitioner
321
+ )
317
322
318
- method_name_to_delegated_program = pytree .tree_map (
319
- gm_to_program , method_name_to_delegated_gm
320
- )
321
323
return MultiMethodExirExportedProgram (method_name_to_delegated_program )
0 commit comments