27
27
28
28
from torch ._subclasses import FakeTensor
29
29
from torch .export .exported_program import (
30
+ ConstantArgument ,
30
31
ExportedProgram ,
31
32
ExportGraphSignature ,
32
33
InputKind ,
@@ -422,7 +423,7 @@ def arrange_graph_placeholders(
422
423
423
424
424
425
# TODO Don't regenerate new signature manually.
425
- def _get_new_signature (
426
+ def _get_new_signature ( # noqa: C901
426
427
original_program : ExportedProgram ,
427
428
gm : torch .fx .GraphModule ,
428
429
tag : Optional [str ] = None ,
@@ -431,6 +432,18 @@ def _get_new_signature(
431
432
Dict [str , Union [torch .Tensor , torch .nn .Parameter ]],
432
433
Dict [str , Union [torch .Tensor , torch .ScriptObject ]],
433
434
]:
435
+ """
436
+ Args:
437
+ tag: If tag is None, this means that we are constructing the graph
438
+ signature for the toplevel graph, after delegation. We need to do this
439
+ because sometimes delegates will swallow some parameters/buffers, so we
440
+ need to update the graph signature/state dict to reflect these changes.
441
+ Otherwise, if tag is not None, this means we are constructing the graph
442
+ signature for the delegated modules. In this case, we need to look
443
+ through the input nodes and see which ones were originally
444
+ parameters/buffers, and lower them down to the delegate.
445
+ """
446
+
434
447
old_signature = original_program .graph_signature
435
448
436
449
input_specs = []
@@ -441,84 +454,104 @@ def _get_new_signature(
441
454
new_state_dict = {}
442
455
new_constants = {}
443
456
444
- non_persistent_buffers = set (old_signature .non_persistent_buffers )
457
+ input_tensor_node_to_sig = {
458
+ input_spec .arg .name : input_spec
459
+ for input_spec in old_signature .input_specs
460
+ if isinstance (input_spec .arg , TensorArgument )
461
+ }
445
462
446
463
for node in gm .graph .nodes :
447
- is_tagged = node .meta .get ("delegation_tag" , None ) == tag
464
+ is_tagged = tag is None or node .meta .get ("delegation_tag" , None ) == tag
448
465
if node .op == "placeholder" :
449
- if node .name in old_signature .inputs_to_parameters and is_tagged :
450
- parameter_name = old_signature .inputs_to_parameters [node .name ]
451
- # add param to graph signature
452
- input_specs .append (
453
- InputSpec (
454
- kind = InputKind .PARAMETER ,
455
- arg = TensorArgument (name = node .name ),
456
- target = parameter_name ,
457
- )
458
- )
459
466
460
- # add param to state_dict
461
- new_state_dict [parameter_name ] = original_program .state_dict [
462
- parameter_name
463
- ]
464
- elif node .name in old_signature .inputs_to_buffers and is_tagged :
465
- buffer_name = old_signature .inputs_to_buffers [node .name ]
466
- persistent = buffer_name not in non_persistent_buffers
467
- # add buffer to graph signature
467
+ if node .name not in input_tensor_node_to_sig :
468
+ assert tag is not None
468
469
input_specs .append (
469
470
InputSpec (
470
- kind = InputKind .BUFFER ,
471
+ kind = InputKind .USER_INPUT ,
471
472
arg = TensorArgument (name = node .name ),
472
- target = buffer_name ,
473
- persistent = persistent ,
473
+ target = None ,
474
474
)
475
475
)
476
+ continue
476
477
477
- # add param to new_state_dict
478
- if persistent :
479
- new_state_dict [buffer_name ] = original_program .state_dict [
480
- buffer_name
481
- ]
482
- else :
483
- new_constants [buffer_name ] = original_program .constants [buffer_name ]
484
- elif (
485
- node .name in old_signature .inputs_to_lifted_tensor_constants
486
- and is_tagged
487
- ):
488
- constant_name = old_signature .inputs_to_lifted_tensor_constants [
489
- node .name
490
- ]
491
- # add constant to graph signature
492
- input_specs .append (
493
- InputSpec (
494
- kind = InputKind .CONSTANT_TENSOR ,
495
- arg = TensorArgument (name = node .name ),
496
- target = constant_name ,
478
+ orig_input_spec = input_tensor_node_to_sig [node .name ]
479
+
480
+ if not isinstance (orig_input_spec .arg , TensorArgument ):
481
+ input_specs .append (orig_input_spec )
482
+
483
+ elif is_tagged :
484
+ input_specs .append (orig_input_spec )
485
+
486
+ if orig_input_spec .kind in (InputKind .PARAMETER , InputKind .BUFFER ):
487
+ new_state_dict [orig_input_spec .target ] = (
488
+ original_program .state_dict [orig_input_spec .target ]
497
489
)
498
- )
490
+ elif orig_input_spec .kind in (
491
+ InputKind .CONSTANT_TENSOR ,
492
+ InputKind .CUSTOM_OBJ ,
493
+ ):
494
+ new_constants [orig_input_spec .target ] = original_program .constants [
495
+ orig_input_spec .target
496
+ ]
499
497
500
- # add constant to new_constants
501
- new_constants [constant_name ] = original_program .constants [constant_name ]
502
498
else :
503
- # not param, buffer, or lifted_tensor_constant then user input
504
499
input_specs .append (
505
500
InputSpec (
506
501
kind = InputKind .USER_INPUT ,
507
502
arg = TensorArgument (name = node .name ),
508
503
target = None ,
509
504
)
510
505
)
506
+
511
507
if node .op == "output" :
512
- for output in pytree .tree_leaves ((node .args , node .kwargs )):
513
- if not isinstance (output , torch .fx .Node ):
514
- continue
515
- output_specs .append (
516
- OutputSpec (
517
- kind = OutputKind .USER_OUTPUT ,
518
- arg = TensorArgument (name = output .name ),
519
- target = None ,
520
- )
521
- )
508
+ output_nodes = pytree .tree_leaves ((node .args , node .kwargs ))
509
+
510
+ if tag is not None :
511
+ # We are constructing output_specs for the delegate outputs.
512
+ # These don't have any buffer mutations.
513
+
514
+ for output_node in output_nodes :
515
+ if not isinstance (output_node , torch .fx .Node ):
516
+ output_specs .append (
517
+ OutputSpec (
518
+ kind = OutputKind .USER_OUTPUT ,
519
+ arg = ConstantArgument (output_node ),
520
+ target = None ,
521
+ )
522
+ )
523
+ else :
524
+ output_specs .append (
525
+ OutputSpec (
526
+ kind = OutputKind .USER_OUTPUT ,
527
+ arg = TensorArgument (name = output_node .name ),
528
+ target = None ,
529
+ )
530
+ )
531
+
532
+ else :
533
+ # We are reconstructing the toplevel module which contains
534
+ # delegates. Delegation should not change the number of outputs
535
+ # in the toplevel module, and it does not touch the mutated buffers
536
+
537
+ assert len (old_signature .output_specs ) == len (output_nodes )
538
+ for prev_output_spec , output_node in zip (
539
+ old_signature .output_specs , output_nodes
540
+ ):
541
+ if not isinstance (output_node , torch .fx .Node ):
542
+ assert isinstance (prev_output_spec .arg , ConstantArgument )
543
+ output_specs .append (
544
+ OutputSpec (
545
+ kind = OutputKind .USER_OUTPUT ,
546
+ arg = ConstantArgument (output_node ),
547
+ target = None ,
548
+ )
549
+ )
550
+
551
+ else :
552
+ new_output_spec = copy .deepcopy (prev_output_spec )
553
+ new_output_spec .arg .name = output_node .name
554
+ output_specs .append (new_output_spec )
522
555
523
556
return new_signature , new_state_dict , new_constants
524
557
0 commit comments