6
6
7
7
import copy
8
8
import logging
9
- from contextlib import contextmanager
9
+ from contextlib import contextmanager , nullcontext
10
10
from functools import singledispatch
11
11
from typing import Generator , List
12
12
13
13
import torch
14
+ import torch .utils ._pytree as pytree
14
15
15
16
from executorch .exir .backend .backend_details import BackendDetails , PreprocessResult
16
17
from executorch .exir .backend .compile_spec_schema import CompileSpec
25
26
26
27
from executorch .exir .graph_module import get_control_flow_submodules
27
28
from executorch .exir .lowered_backend_module import (
28
- _get_new_signature ,
29
+ _unsafe_adjust_original_program ,
29
30
create_exported_program_from_submodule ,
30
31
create_submodule_from_nodes ,
31
32
LoweredBackendModule ,
32
33
)
33
- from executorch .exir .pass_base import ExportPass
34
34
from executorch .exir .program ._fake_program import (
35
35
get_fake_program ,
36
36
update_to_real_program ,
@@ -193,6 +193,7 @@ def _partition_and_lower_one_graph_module(
193
193
tagged_graph_module : torch .fx .GraphModule ,
194
194
partition_result : PartitionResult ,
195
195
owning_program : ExportedProgram ,
196
+ is_submodule : bool ,
196
197
) -> torch .fx .GraphModule :
197
198
"""
198
199
Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.
@@ -210,21 +211,40 @@ def _partition_and_lower_one_graph_module(
210
211
211
212
logging .debug (f"For tag { tag } , found nodes { node_list } " )
212
213
# Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs)
213
- submodule , call_module_node = create_submodule_from_nodes (
214
- tagged_graph_module , node_list , tag
214
+
215
+ replace_ctx = (
216
+ tagged_graph_module ._set_replace_hook (
217
+ owning_program .graph_signature .get_replace_hook ()
218
+ )
219
+ if not is_submodule
220
+ else nullcontext ()
215
221
)
222
+ with replace_ctx :
223
+ submodule , call_module_node = create_submodule_from_nodes (
224
+ tagged_graph_module , node_list , tag
225
+ )
226
+
216
227
tagged_graph_module_output_node = [
217
228
node for node in tagged_graph_module .graph .nodes if node .op == "output"
218
- ]
229
+ ][ 0 ]
219
230
submodule_output_node = [
220
231
node for node in submodule .graph .nodes if node .op == "output"
221
- ]
222
- # Copy the output node meta from the original output node, because create_submodule_from_nodes doesn't cover the meta field
223
- submodule_output_node [0 ].meta = tagged_graph_module_output_node [0 ].meta
232
+ ][0 ]
233
+ # Copy the output node meta from the original output node, because
234
+ # create_submodule_from_nodes doesn't cover the meta field
235
+ submodule_output_node .meta = tagged_graph_module_output_node .meta
224
236
logging .debug (f"Partitioned graph module: { tagged_graph_module } " )
225
237
226
- submodule_program = create_exported_program_from_submodule (
227
- submodule , owning_program , tag
238
+ (
239
+ submodule_program ,
240
+ toplevel_input_specs_to_delete ,
241
+ toplevel_output_specs_to_delete ,
242
+ ) = create_exported_program_from_submodule (
243
+ submodule ,
244
+ owning_program ,
245
+ tag ,
246
+ call_module_node ,
247
+ is_submodule ,
228
248
)
229
249
230
250
lowered_submodule = to_backend (
@@ -257,64 +277,48 @@ def _partition_and_lower_one_graph_module(
257
277
call_delegate_node .meta ["debug_handle" ] = len (
258
278
tagged_graph_module .graph .nodes
259
279
)
280
+ call_delegate_node .meta ["val" ] = submodule_output_node .meta ["val" ]
260
281
call_module_node .replace_all_uses_with (call_delegate_node )
261
282
tagged_graph_module .graph .erase_node (call_module_node )
262
283
263
- # Delete all parameters/buffers consumed by the created exported program
264
- toplevel_signature = owning_program .graph_signature
265
- for node in tagged_graph_module .graph .nodes :
266
- # Find placeholders consumed by the delegate
267
- if node .op != "placeholder" or len (node .users ) != 0 :
268
- continue
269
-
270
- if node .name in toplevel_signature .inputs_to_buffers :
271
- # Delete the consumed buffers
272
- buffer_name = toplevel_signature .inputs_to_buffers .get (node .name )
273
- if buffer_name in owning_program .state_dict :
274
- owning_program .state_dict .pop (buffer_name )
275
- else :
276
- owning_program .constants .pop (buffer_name )
277
- tagged_graph_module .graph .erase_node (node )
278
- elif node .name in toplevel_signature .inputs_to_parameters :
279
- # Delete the consumed parameters
280
- param_name = toplevel_signature .inputs_to_parameters .get (node .name )
281
- owning_program .state_dict .pop (param_name )
282
- tagged_graph_module .graph .erase_node (node )
283
-
284
- tagged_graph_module .recompile ()
284
+ if is_submodule :
285
+ assert len (toplevel_input_specs_to_delete ) == 0
286
+ assert len (toplevel_output_specs_to_delete ) == 0
287
+ elif (
288
+ len (toplevel_input_specs_to_delete ) > 0
289
+ or len (toplevel_output_specs_to_delete ) > 0
290
+ ):
291
+ _unsafe_adjust_original_program (
292
+ owning_program ,
293
+ call_delegate_node ,
294
+ toplevel_input_specs_to_delete ,
295
+ toplevel_output_specs_to_delete ,
296
+ )
297
+
285
298
return tagged_graph_module
286
299
287
300
288
301
def _partition_and_lower (
289
302
tagged_graph_module : torch .fx .GraphModule ,
290
303
partition_result : PartitionResult ,
291
304
owning_program : ExportedProgram ,
305
+ is_submodule : bool = False ,
292
306
) -> torch .fx .GraphModule :
293
307
"""
294
308
Partitions the graph module into submodules based on tags, and then lowered the nodes with the same tag as one lowered module, including the submodule from control flow
295
309
"""
296
310
297
311
partitioned_module = _partition_and_lower_one_graph_module (
298
- tagged_graph_module , partition_result , owning_program
312
+ tagged_graph_module , partition_result , owning_program , is_submodule
299
313
)
300
314
301
315
# Recursively partition and lower for submodules
302
316
for name , submod , _node in get_control_flow_submodules (partitioned_module ):
303
317
partitioned_submodule = _partition_and_lower (
304
- submod , partition_result , owning_program
318
+ submod , partition_result , owning_program , is_submodule = True
305
319
)
306
320
tagged_graph_module .add_module (name , partitioned_submodule )
307
321
308
- # Run the export pass over the graph module so that the call delegate
309
- # nodes will match Edge dialect
310
- # TODO(angelayi): ExportPass will rerun the graph, however all we need
311
- # here is to add metadata to the call delegate nodes to preserve Edge
312
- # dialect. There's work going on to generate a random tensor from a
313
- # fake tensor and possibly it can help to address the issue.
314
- res = ExportPass ()(tagged_graph_module )
315
- assert res is not None
316
- tagged_graph_module = res .graph_module
317
-
318
322
return tagged_graph_module
319
323
320
324
@@ -349,6 +353,8 @@ def to_backend(
349
353
Returns:
350
354
ExportedProgram: The input program, with some portions targeted for delegation.
351
355
"""
356
+ edge_program ._validate ()
357
+
352
358
# Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
353
359
# Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
354
360
try :
@@ -377,26 +383,22 @@ def to_backend(
377
383
update_to_real_program (tagged_exported_program , edge_program )
378
384
379
385
for tag , _ in partitioner_result .partition_tags .items ():
380
- _maybe_duplicate_constant_nodes (tagged_exported_program , tag , edge_program )
386
+ _maybe_duplicate_constant_nodes (tagged_exported_program , tag )
381
387
382
388
tagged_graph_module = _partition_and_lower (
383
- tagged_exported_program .graph_module , partitioner_result , edge_program
389
+ tagged_exported_program .graph_module ,
390
+ partitioner_result ,
391
+ tagged_exported_program ,
384
392
)
385
393
386
- # TODO(angelayi): Update this signature in a less manual way (maybe through
387
- # retracing)
388
- new_signature , new_state_dict , new_constants = _get_new_signature (
389
- edge_program ,
390
- tagged_graph_module ,
391
- )
392
394
return ExportedProgram (
393
395
root = tagged_graph_module ,
394
396
graph = tagged_graph_module .graph ,
395
- graph_signature = new_signature ,
396
- state_dict = new_state_dict ,
397
- range_constraints = copy .deepcopy (edge_program .range_constraints ),
398
- module_call_graph = copy .deepcopy (edge_program .module_call_graph ),
397
+ graph_signature = tagged_exported_program . graph_signature ,
398
+ state_dict = tagged_exported_program . state_dict ,
399
+ range_constraints = copy .deepcopy (tagged_exported_program .range_constraints ),
400
+ module_call_graph = copy .deepcopy (tagged_exported_program .module_call_graph ),
399
401
example_inputs = None ,
400
- constants = new_constants ,
401
- verifiers = [edge_program .verifier ],
402
+ constants = tagged_exported_program . constants ,
403
+ verifiers = [tagged_exported_program .verifier ],
402
404
)
0 commit comments