@@ -488,8 +488,16 @@ def _get_new_signature( # noqa: C901
         else {}
     )
 
+    toplevel_output_node_to_sig: Dict[str, OutputSpec] = (
+        {
+            output_spec.arg.name: output_spec
+            for output_spec in old_signature.output_specs
+        }
+        if not is_submodule
+        else {}
+    )
+
     for node in gm.graph.nodes:
-        is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
         if node.op == "placeholder":
 
             if node.name not in input_node_to_sig:
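
For context on where these output specs come from: `toplevel_output_node_to_sig` maps each top-level output node name to its `OutputSpec`, which is how the pass later recognizes buffer mutations. The sketch below is not part of this change; `Accumulator` is a made-up toy module, and it assumes a recent PyTorch with `torch.export`. It shows an exported program whose signature contains a `BUFFER_MUTATION` spec alongside the user output:

import torch
from torch.export import export

class Accumulator(torch.nn.Module):  # hypothetical toy module, not from this PR
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(4))

    def forward(self, x):
        self.state.add_(x)  # in-place buffer mutation, lifted to an output by export
        return self.state * 2

ep = export(Accumulator(), (torch.ones(4),))
for spec in ep.graph_signature.output_specs:
    # Expect a BUFFER_MUTATION spec targeting "state" plus a USER_OUTPUT.
    print(spec.kind, spec.arg.name, spec.target)
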
@@ -507,7 +515,7 @@ def _get_new_signature( # noqa: C901
             if not isinstance(orig_input_spec.arg, TensorArgument):
                 input_specs.append(orig_input_spec)
 
-            elif is_tagged:
+            elif node.meta.get("delegation_tag", None) == tag:
                 input_specs.append(orig_input_spec)
 
                 if orig_input_spec.kind == InputKind.USER_INPUT:
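
The rewritten condition keys off the `delegation_tag` that a partitioner writes into `node.meta` for every node it wants delegated, rather than the removed `is_tagged` shortcut. A minimal sketch of that tagging convention (the module `M` and the tag string are illustrative, not ExecuTorch API):

import torch

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.add:
        # Mark this node for delegation; the signature pass keeps only
        # specs whose nodes carry the matching tag.
        node.meta["delegation_tag"] = "backend_partition_0"

print([(n.name, n.meta.get("delegation_tag", None)) for n in gm.graph.nodes])
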
@@ -551,11 +559,55 @@ def _get_new_signature( # noqa: C901
                     )
 
         if node.op == "output":
-            output_nodes = pytree.tree_leaves((node.args, node.kwargs))
+            buffer_mutation_idxs: Dict[int, OutputSpec] = {}
+            for user in call_module_node.users.keys():
+                if user.name in toplevel_output_node_to_sig:
+                    assert (
+                        user.op == "call_function" and user.target == operator.getitem
+                    ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
+                    getitem_idx = user.args[1]
+                    assert isinstance(
+                        getitem_idx, int
+                    ), f"Invalid getitem type: {type(getitem_idx)}"
+                    buffer_mutation_idxs[getitem_idx] = toplevel_output_node_to_sig[
+                        user.name
+                    ]
+
+            for i, output_node in enumerate(node.args[0]):
+                if i in buffer_mutation_idxs:
+                    assert isinstance(output_node, torch.fx.Node)
+                    orig_output_spec = buffer_mutation_idxs[i]
+
+                    if (
+                        orig_output_spec.kind == OutputKind.BUFFER_MUTATION
+                        and orig_output_spec.target in new_state_dict
+                    ):
+                        # If the delegate wants to consume the buffer, then
+                        # the delegate should also consume the buffer
+                        # mutation (output spec would be a BUFFER_MUTATION).
+                        # Otherwise the delegate will just return the result
+                        # of the mutation as a USER_OUTPUT.
+                        output_specs.append(
+                            OutputSpec(
+                                kind=OutputKind.BUFFER_MUTATION,
+                                arg=TensorArgument(name=output_node.name),
+                                target=orig_output_spec.target,
+                            )
+                        )
+                        output_specs_to_delete[orig_output_spec.arg.name] = (
+                            orig_output_spec
+                        )
 
-            for output_node in output_nodes:
+                    else:
+                        output_specs.append(
+                            OutputSpec(
+                                kind=OutputKind.USER_OUTPUT,
+                                arg=TensorArgument(name=output_node.name),
+                                target=None,
+                            )
+                        )
 
-                if not isinstance(output_node, torch.fx.Node):
+                elif not isinstance(output_node, torch.fx.Node):
                     output_specs.append(
                         OutputSpec(
                             kind=OutputKind.USER_OUTPUT,
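
The branch above keeps a mutation as a `BUFFER_MUTATION` spec only when the delegate also consumes the buffer itself (its target is in `new_state_dict`); otherwise the mutated value is surfaced as a plain `USER_OUTPUT` of the submodule. A standalone sketch of that decision, with a hypothetical helper name and a `consumed_buffers` set standing in for `new_state_dict`:

from typing import Set

from torch.export.graph_signature import OutputKind, OutputSpec, TensorArgument

def submodule_output_spec(  # hypothetical helper, mirrors the branch in the hunk above
    orig_spec: OutputSpec, node_name: str, consumed_buffers: Set[str]
) -> OutputSpec:
    if (
        orig_spec.kind == OutputKind.BUFFER_MUTATION
        and orig_spec.target in consumed_buffers
    ):
        # Delegate owns the buffer, so it also owns the mutation.
        return OutputSpec(
            kind=OutputKind.BUFFER_MUTATION,
            arg=TensorArgument(name=node_name),
            target=orig_spec.target,
        )
    # Otherwise the mutated value is simply handed back to the caller.
    return OutputSpec(
        kind=OutputKind.USER_OUTPUT,
        arg=TensorArgument(name=node_name),
        target=None,
    )
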
@@ -774,7 +826,7 @@ def get_lowered_backend_modules(
     return lowered_programs
 
 
-def _unsafe_adjust_original_program(
+def _unsafe_adjust_original_program(  # noqa: C901
     original_program: ExportedProgram,
     call_delegate_node: torch.fx.Node,
     input_specs_to_delete: Dict[str, InputSpec],
@@ -830,3 +882,50 @@ def _unsafe_adjust_original_program(
                 del original_program._constants[input_spec.target]
             else:
                 raise RuntimeError(f"Invalid input spec {input_spec} received")
+
+    # Delete buffer mutations from the output which were consumed by the delegate
+    toplevel_output_node = None
+    for node in reversed(original_program.graph.nodes):
+        if node.op == "output":
+            toplevel_output_node = node
+            break
+
+    assert toplevel_output_node is not None
+    assert (
+        len(toplevel_output_node.args) == 1
+    ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"
+
+    new_output_args = [
+        arg
+        for arg in toplevel_output_node.args[0]
+        if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
+    ]
+    toplevel_output_node.args = (tuple(new_output_args),)
+
+    # Delete the buffer mutation getitem nodes
+    getitem_idxs: List[int] = []
+    user_nodes = list(call_delegate_node.users.keys())
+    for user in user_nodes:
+        if user.name in output_specs_to_delete:
+            assert (
+                user.op == "call_function" and user.target == operator.getitem
+            ), f"Invalid user {user}, user.op is {user.op} and user.target is {user.target}"
+            user_idx = user.args[1]
+            assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
+            getitem_idxs.append(user_idx)
+            original_program.graph.erase_node(user)
+
+    getitem_idxs.sort(reverse=True)
+
+    # Adjust all the getitem indices after the deleted getitems
+    user_nodes = list(call_delegate_node.users.keys())
+    for user in user_nodes:
+        assert user.op == "call_function" and user.target == operator.getitem
+        user_idx = user.args[1]
+        assert isinstance(user_idx, int)
+        for i, idx in enumerate(getitem_idxs):
+            if user_idx > idx:
+                user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
+                break
+
+    original_program._validate()
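
The final loop renumbers the surviving `getitem` users: each index drops by the count of deleted (smaller) indices in front of it, which is what `user_idx - (len(getitem_idxs) - i)` computes over the descending-sorted list. A plain-Python illustration of the same arithmetic (no FX involved; the function name is made up):

from typing import List

def adjust_indices(surviving: List[int], deleted: List[int]) -> List[int]:
    # Each surviving index shifts down by the number of deleted indices below it.
    deleted_sorted = sorted(deleted)
    adjusted = []
    for idx in surviving:
        shift = sum(1 for d in deleted_sorted if d < idx)
        adjusted.append(idx - shift)
    return adjusted

# Outputs 0..4 where positions 1 and 3 were buffer mutations consumed by the
# delegate: the remaining getitems 0, 2, 4 become 0, 1, 2.
print(adjust_indices([0, 2, 4], [1, 3]))  # [0, 1, 2]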