8
8
9
9
import copy
10
10
import operator
11
+ from collections import defaultdict
11
12
from typing import Any , Dict , List , Optional , Set , Tuple , Union
12
13
13
14
import torch
@@ -488,8 +489,12 @@ def _get_new_signature( # noqa: C901
488
489
else {}
489
490
)
490
491
492
+ toplevel_output_node_to_sig : Dict [str , List [OutputSpec ]] = defaultdict (list )
493
+ if not is_submodule :
494
+ for output_spec in old_signature .output_specs :
495
+ toplevel_output_node_to_sig [output_spec .arg .name ].append (output_spec )
496
+
491
497
for node in gm .graph .nodes :
492
- is_tagged = tag is None or node .meta .get ("delegation_tag" , None ) == tag
493
498
if node .op == "placeholder" :
494
499
495
500
if node .name not in input_node_to_sig :
@@ -507,7 +512,7 @@ def _get_new_signature( # noqa: C901
507
512
if not isinstance (orig_input_spec .arg , TensorArgument ):
508
513
input_specs .append (orig_input_spec )
509
514
510
- elif is_tagged :
515
+ elif node . meta . get ( "delegation_tag" , None ) == tag :
511
516
input_specs .append (orig_input_spec )
512
517
513
518
if orig_input_spec .kind == InputKind .USER_INPUT :
@@ -551,11 +556,72 @@ def _get_new_signature( # noqa: C901
551
556
)
552
557
553
558
if node .op == "output" :
554
- output_nodes = pytree .tree_leaves ((node .args , node .kwargs ))
559
+ buffer_mutation_idxs : Dict [int , List [OutputSpec ]] = defaultdict (list )
560
+ for user in call_module_node .users .keys ():
561
+ if user .name in toplevel_output_node_to_sig :
562
+ assert (
563
+ user .op == "call_function" and user .target == operator .getitem
564
+ ), f"Invalid user { user } , node.op is { user .op } and node.target is { user .target } "
565
+ getitem_idx = user .args [1 ]
566
+ assert isinstance (
567
+ getitem_idx , int
568
+ ), f"Invalid getitem type: { type (getitem_idx )} "
569
+ buffer_mutation_idxs [getitem_idx ].extend (
570
+ toplevel_output_node_to_sig [user .name ]
571
+ )
555
572
556
- for output_node in output_nodes :
573
+ for i , output_node in enumerate (node .args [0 ]):
574
+ if i in buffer_mutation_idxs :
575
+ assert isinstance (output_node , torch .fx .Node )
576
+ orig_output_specs = buffer_mutation_idxs [i ]
577
+
578
+ if any (
579
+ orig_output_spec .kind == OutputKind .BUFFER_MUTATION
580
+ and orig_output_spec .target in new_state_dict
581
+ for orig_output_spec in orig_output_specs
582
+ ):
583
+ # If the delegate wants to consume the buffer, then the
584
+ # delegate should also consume the buffer mutation
585
+ # (output spec would be a BUFFER_MUTATION). Otherwise
586
+ # the delegate will just return the result of the
587
+ # mutation as a USER_OUTPUT.
588
+
589
+ orig_output_spec = [
590
+ orig_output_spec
591
+ for orig_output_spec in orig_output_specs
592
+ if orig_output_spec .kind == OutputKind .BUFFER_MUTATION
593
+ and orig_output_spec .target in new_state_dict
594
+ ][0 ]
595
+
596
+ assert len (orig_output_specs ) == 1 , (
597
+ f"Constant { orig_output_spec .target } was tagged to be "
598
+ "consumed by the buffer, and was found to also contain "
599
+ "a buffer mutation. However this buffer mutation node "
600
+ "was found to also be used as other types of outputs "
601
+ "which is currently not supported. Please file an "
602
+ "issue on Github. \n \n "
603
+ f"The toplevel program: { original_program } \n "
604
+ )
605
+ output_specs .append (
606
+ OutputSpec (
607
+ kind = OutputKind .BUFFER_MUTATION ,
608
+ arg = TensorArgument (name = output_node .name ),
609
+ target = orig_output_spec .target ,
610
+ )
611
+ )
612
+ output_specs_to_delete [orig_output_spec .arg .name ] = (
613
+ orig_output_spec
614
+ )
615
+ else :
616
+ output_specs .append (
617
+ OutputSpec (
618
+ kind = OutputKind .USER_OUTPUT ,
619
+ arg = TensorArgument (name = output_node .name ),
620
+ target = None ,
621
+ )
622
+ )
557
623
558
- if not isinstance (output_node , torch .fx .Node ):
624
+ elif not isinstance (output_node , torch .fx .Node ):
559
625
output_specs .append (
560
626
OutputSpec (
561
627
kind = OutputKind .USER_OUTPUT ,
@@ -630,6 +696,9 @@ def create_exported_program_from_submodule(
630
696
in_spec = pytree .tree_flatten ((tuple (subgraph_signature .user_inputs ), {}))[1 ]
631
697
out_spec = pytree .tree_flatten (subgraph_signature .user_outputs )[1 ]
632
698
699
+ print (submodule .graph )
700
+ print (subgraph_signature )
701
+
633
702
return (
634
703
ExportedProgram (
635
704
root = submodule ,
@@ -774,7 +843,7 @@ def get_lowered_backend_modules(
774
843
return lowered_programs
775
844
776
845
777
- def _unsafe_adjust_original_program (
846
+ def _unsafe_adjust_original_program ( # noqa: C901
778
847
original_program : ExportedProgram ,
779
848
call_delegate_node : torch .fx .Node ,
780
849
input_specs_to_delete : Dict [str , InputSpec ],
@@ -830,3 +899,50 @@ def _unsafe_adjust_original_program(
830
899
del original_program ._constants [input_spec .target ]
831
900
else :
832
901
raise RuntimeError (f"Invalid input spec { input_spec } received" )
902
+
903
+ # Delete buffer mutations from the output which were consumed by the delegate
904
+ toplevel_output_node = None
905
+ for node in reversed (original_program .graph .nodes ):
906
+ if node .op == "output" :
907
+ toplevel_output_node = node
908
+ break
909
+
910
+ assert toplevel_output_node is not None
911
+ assert (
912
+ len (toplevel_output_node .args ) == 1
913
+ ), f"Invalid output node: { toplevel_output_node } with args { toplevel_output_node .args } "
914
+
915
+ new_output_args = [
916
+ arg
917
+ for arg in toplevel_output_node .args [0 ]
918
+ if not isinstance (arg , torch .fx .Node ) or arg .name not in output_specs_to_delete
919
+ ]
920
+ toplevel_output_node .args = (tuple (new_output_args ),)
921
+
922
+ # Delete the buffer mutation getitem nodes
923
+ getitem_idxs : List [int ] = []
924
+ user_nodes = list (call_delegate_node .users .keys ())
925
+ for user in user_nodes :
926
+ if user .name in output_specs_to_delete :
927
+ assert (
928
+ user .op == "call_function" and user .target == operator .getitem
929
+ ), f"Invalid user { user } , node.op is { node .op } and node.target is { node .target } "
930
+ user_idx = user .args [1 ]
931
+ assert isinstance (user_idx , int ), f"Invalid getitem type: { type (user_idx )} "
932
+ getitem_idxs .append (user_idx )
933
+ original_program .graph .erase_node (user )
934
+
935
+ getitem_idxs .sort (reverse = True )
936
+
937
+ # Adjust all the getitem indices after the deleted getitems
938
+ user_nodes = list (call_delegate_node .users .keys ())
939
+ for user in user_nodes :
940
+ assert user .op == "call_function" and user .target == operator .getitem
941
+ user_idx = user .args [1 ]
942
+ assert isinstance (user_idx , int )
943
+ for i , idx in enumerate (getitem_idxs ):
944
+ if user_idx > idx :
945
+ user .args = (user .args [0 ], user_idx - (len (getitem_idxs ) - i ))
946
+ break
947
+
948
+ original_program ._validate ()
0 commit comments