6
6
7
7
import copy
8
8
import logging
9
- from contextlib import contextmanager
9
+ from contextlib import contextmanager , nullcontext
10
10
from functools import singledispatch
11
11
from typing import Generator , List
12
12
13
13
import torch
14
+ import torch .utils ._pytree as pytree
14
15
15
16
from executorch .exir .backend .backend_details import BackendDetails , PreprocessResult
16
17
from executorch .exir .backend .compile_spec_schema import CompileSpec
25
26
26
27
from executorch .exir .graph_module import get_control_flow_submodules
27
28
from executorch .exir .lowered_backend_module import (
28
- _get_new_signature ,
29
+ _unsafe_adjust_original_program ,
29
30
create_exported_program_from_submodule ,
30
31
create_submodule_from_nodes ,
31
32
LoweredBackendModule ,
32
33
)
33
- from executorch .exir .pass_base import ExportPass
34
34
from executorch .exir .program ._fake_program import (
35
35
get_fake_program ,
36
36
update_to_real_program ,
@@ -193,6 +193,7 @@ def _partition_and_lower_one_graph_module(
193
193
tagged_graph_module : torch .fx .GraphModule ,
194
194
partition_result : PartitionResult ,
195
195
owning_program : ExportedProgram ,
196
+ is_submodule : bool ,
196
197
) -> torch .fx .GraphModule :
197
198
"""
198
199
Partitions and lowers the graph module based on the partition tags; this handles a single graph module.
@@ -210,21 +211,44 @@ def _partition_and_lower_one_graph_module(
210
211
211
212
logging .debug (f"For tag { tag } , found nodes { node_list } " )
212
213
# Tag the nodes that are params as buffers, so we can order the submodule as (Params + Buffers) (User Inputs)
213
- submodule , call_module_node = create_submodule_from_nodes (
214
- tagged_graph_module , node_list , tag
214
+
215
+ replace_ctx = (
216
+ tagged_graph_module ._set_replace_hook (
217
+ owning_program .graph_signature .get_replace_hook ()
218
+ )
219
+ if not is_submodule
220
+ else nullcontext ()
215
221
)
222
+ with replace_ctx :
223
+ submodule , call_module_node = create_submodule_from_nodes (
224
+ tagged_graph_module , node_list , tag
225
+ )
226
+
216
227
tagged_graph_module_output_node = [
217
228
node for node in tagged_graph_module .graph .nodes if node .op == "output"
218
- ]
229
+ ][ 0 ]
219
230
submodule_output_node = [
220
231
node for node in submodule .graph .nodes if node .op == "output"
221
- ]
222
- # Copy the output node meta from the original output node, because create_submodule_from_nodes doesn't cover the meta field
223
- submodule_output_node [0 ].meta = tagged_graph_module_output_node [0 ].meta
232
+ ][0 ]
233
+ # Copy the output node meta from the original output node, because
234
+ # create_submodule_from_nodes doesn't cover the meta field
235
+ submodule_output_node .meta = tagged_graph_module_output_node .meta
236
+ submodule_output_node .meta ["val" ] = pytree .tree_map (
237
+ lambda arg : arg .meta .get ("val" ) if isinstance (arg , torch .fx .Node ) else arg ,
238
+ submodule_output_node .args ,
239
+ )
224
240
logging .debug (f"Partitioned graph module: { tagged_graph_module } " )
225
241
226
- submodule_program = create_exported_program_from_submodule (
227
- submodule , owning_program , tag
242
+ (
243
+ submodule_program ,
244
+ toplevel_input_specs_to_delete ,
245
+ toplevel_output_specs_to_delete ,
246
+ ) = create_exported_program_from_submodule (
247
+ submodule ,
248
+ owning_program ,
249
+ tag ,
250
+ call_module_node ,
251
+ is_submodule ,
228
252
)
229
253
230
254
lowered_submodule = to_backend (
@@ -257,64 +281,48 @@ def _partition_and_lower_one_graph_module(
257
281
call_delegate_node .meta ["debug_handle" ] = len (
258
282
tagged_graph_module .graph .nodes
259
283
)
284
+ call_delegate_node .meta ["val" ] = submodule_output_node .meta ["val" ]
260
285
call_module_node .replace_all_uses_with (call_delegate_node )
261
286
tagged_graph_module .graph .erase_node (call_module_node )
262
287
263
- # Delete all parameters/buffers consumed by the created exported program
264
- toplevel_signature = owning_program .graph_signature
265
- for node in tagged_graph_module .graph .nodes :
266
- # Find placeholders consumed by the delegate
267
- if node .op != "placeholder" or len (node .users ) != 0 :
268
- continue
269
-
270
- if node .name in toplevel_signature .inputs_to_buffers :
271
- # Delete the consumed buffers
272
- buffer_name = toplevel_signature .inputs_to_buffers .get (node .name )
273
- if buffer_name in owning_program .state_dict :
274
- owning_program .state_dict .pop (buffer_name )
275
- else :
276
- owning_program .constants .pop (buffer_name )
277
- tagged_graph_module .graph .erase_node (node )
278
- elif node .name in toplevel_signature .inputs_to_parameters :
279
- # Delete the consumed parameters
280
- param_name = toplevel_signature .inputs_to_parameters .get (node .name )
281
- owning_program .state_dict .pop (param_name )
282
- tagged_graph_module .graph .erase_node (node )
283
-
284
- tagged_graph_module .recompile ()
288
+ if is_submodule :
289
+ assert len (toplevel_input_specs_to_delete ) == 0
290
+ assert len (toplevel_output_specs_to_delete ) == 0
291
+ elif (
292
+ len (toplevel_input_specs_to_delete ) > 0
293
+ or len (toplevel_output_specs_to_delete ) > 0
294
+ ):
295
+ _unsafe_adjust_original_program (
296
+ owning_program ,
297
+ call_delegate_node ,
298
+ toplevel_input_specs_to_delete ,
299
+ toplevel_output_specs_to_delete ,
300
+ )
301
+
285
302
return tagged_graph_module
286
303
287
304
288
305
def _partition_and_lower (
289
306
tagged_graph_module : torch .fx .GraphModule ,
290
307
partition_result : PartitionResult ,
291
308
owning_program : ExportedProgram ,
309
+ is_submodule : bool = False ,
292
310
) -> torch .fx .GraphModule :
293
311
"""
294
312
Partitions the graph module into submodules based on tags, and then lowers the nodes with the same tag as one lowered module, including the submodules from control flow
295
313
"""
296
314
297
315
partitioned_module = _partition_and_lower_one_graph_module (
298
- tagged_graph_module , partition_result , owning_program
316
+ tagged_graph_module , partition_result , owning_program , is_submodule
299
317
)
300
318
301
319
# Recursively partition and lower for submodules
302
320
for name , submod , _node in get_control_flow_submodules (partitioned_module ):
303
321
partitioned_submodule = _partition_and_lower (
304
- submod , partition_result , owning_program
322
+ submod , partition_result , owning_program , is_submodule = True
305
323
)
306
324
tagged_graph_module .add_module (name , partitioned_submodule )
307
325
308
- # Run the export pass over the graph module so that the call delegate
309
- # nodes will match Edge dialect
310
- # TODO(angelayi): ExportPass will rerun the graph, however all we need
311
- # here is to add metadata to the call delegate nodes to preserve Edge
312
- # dialect. There's work going on to generate a random tensor from a
313
- # fake tensor and possibly it can help to address the issue.
314
- res = ExportPass ()(tagged_graph_module )
315
- assert res is not None
316
- tagged_graph_module = res .graph_module
317
-
318
326
return tagged_graph_module
319
327
320
328
@@ -349,6 +357,8 @@ def to_backend(
349
357
Returns:
350
358
ExportedProgram: The input program, with some portions targeted for delegation.
351
359
"""
360
+ edge_program ._validate ()
361
+
352
362
# Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
353
363
# Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
354
364
try :
@@ -377,26 +387,22 @@ def to_backend(
377
387
update_to_real_program (tagged_exported_program , edge_program )
378
388
379
389
for tag , _ in partitioner_result .partition_tags .items ():
380
- _maybe_duplicate_constant_nodes (tagged_exported_program , tag , edge_program )
390
+ _maybe_duplicate_constant_nodes (tagged_exported_program , tag )
381
391
382
392
tagged_graph_module = _partition_and_lower (
383
- tagged_exported_program .graph_module , partitioner_result , edge_program
393
+ tagged_exported_program .graph_module ,
394
+ partitioner_result ,
395
+ tagged_exported_program ,
384
396
)
385
397
386
- # TODO(angelayi): Update this signature in a less manual way (maybe through
387
- # retracing)
388
- new_signature , new_state_dict , new_constants = _get_new_signature (
389
- edge_program ,
390
- tagged_graph_module ,
391
- )
392
398
return ExportedProgram (
393
399
root = tagged_graph_module ,
394
400
graph = tagged_graph_module .graph ,
395
- graph_signature = new_signature ,
396
- state_dict = new_state_dict ,
397
- range_constraints = copy .deepcopy (edge_program .range_constraints ),
398
- module_call_graph = copy .deepcopy (edge_program .module_call_graph ),
401
+ graph_signature = tagged_exported_program . graph_signature ,
402
+ state_dict = tagged_exported_program . state_dict ,
403
+ range_constraints = copy .deepcopy (tagged_exported_program .range_constraints ),
404
+ module_call_graph = copy .deepcopy (tagged_exported_program .module_call_graph ),
399
405
example_inputs = None ,
400
- constants = new_constants ,
401
- verifiers = [edge_program .verifier ],
406
+ constants = tagged_exported_program . constants ,
407
+ verifiers = [tagged_exported_program .verifier ],
402
408
)
0 commit comments