36
36
)
37
37
from torch ._export import ExportedProgram
38
38
from torch ._export .passes import ReplaceViewOpsWithViewCopyOpsPass
39
- from torch .export .exported_program import InputKind , InputSpec , TensorArgument
39
+ from torch .export .exported_program import (
40
+ _get_updated_range_constraints ,
41
+ ConstantArgument ,
42
+ ExportGraphSignature ,
43
+ InputKind ,
44
+ InputSpec ,
45
+ OutputSpec ,
46
+ TensorArgument ,
47
+ )
40
48
from torch .fx import _pytree as fx_pytree
41
49
from torch .fx ._compatibility import compatibility
50
+ from torch .fx .passes .infra .pass_manager import PassManager
42
51
from torch .utils import _pytree as pytree
43
52
44
53
Val = Any
45
54
46
55
56
+ def _get_updated_graph_signature (
57
+ old_signature : ExportGraphSignature ,
58
+ new_gm : torch .fx .GraphModule ,
59
+ ) -> ExportGraphSignature :
60
+ """
61
+ Update the graph signature's user_input/user_outputs.
62
+ """
63
+ new_input_specs = []
64
+ for i , node in enumerate (new_gm .graph .nodes ):
65
+ if node .op != "placeholder" :
66
+ break
67
+
68
+ assert i < len (
69
+ old_signature .input_specs
70
+ ), "Number of inputs changed after transformation"
71
+ old_input_spec = old_signature .input_specs [i ]
72
+ arg = (
73
+ old_input_spec .arg
74
+ if isinstance (old_input_spec .arg , ConstantArgument )
75
+ else type (old_input_spec .arg )(node .name )
76
+ )
77
+ new_input_specs .append (
78
+ InputSpec (old_input_spec .kind , arg , old_input_spec .target )
79
+ )
80
+
81
+ output_node = list (new_gm .graph .nodes )[- 1 ]
82
+ assert output_node .op == "output"
83
+
84
+ new_output_specs = []
85
+ for i , node in enumerate (output_node .args [0 ]):
86
+ assert i < len (
87
+ old_signature .output_specs
88
+ ), "Number of outputs changed after transformation"
89
+ old_output_spec = old_signature .output_specs [i ]
90
+ arg = (
91
+ old_output_spec .arg
92
+ if isinstance (old_output_spec .arg , ConstantArgument )
93
+ else type (old_output_spec .arg )(node .name )
94
+ )
95
+ new_output_specs .append (
96
+ OutputSpec (old_output_spec .kind , arg , old_output_spec .target )
97
+ )
98
+
99
+ new_signature = ExportGraphSignature (
100
+ input_specs = new_input_specs , output_specs = new_output_specs
101
+ )
102
+ return new_signature
103
+
104
+
105
+ def _transform (self , * passes : PassType ) -> "ExportedProgram" :
106
+ pm = PassManager (list (passes ))
107
+ res = pm (self .graph_module )
108
+ transformed_gm = res .graph_module if res is not None else self .graph_module
109
+ assert transformed_gm is not None
110
+
111
+ if transformed_gm is self .graph_module and not res .modified :
112
+ return self
113
+
114
+ transformed_ep = ExportedProgram (
115
+ transformed_gm ,
116
+ transformed_gm .graph ,
117
+ _get_updated_graph_signature (self .graph_signature , transformed_gm ),
118
+ self .state_dict ,
119
+ _get_updated_range_constraints (transformed_gm ),
120
+ copy .deepcopy (self .equality_constraints ),
121
+ copy .deepcopy (self ._module_call_graph ),
122
+ self .example_inputs ,
123
+ self .verifier ,
124
+ self .tensor_constants ,
125
+ )
126
+ transformed_ep .graph_module .meta .update (self .graph_module .meta )
127
+ transformed_ep .graph_module .meta .update (res .graph_module .meta )
128
+ return transformed_ep
129
+
130
+
47
131
def _copy_module (new_prog , new_gm ):
48
132
new_prog .meta .update (new_gm .meta )
49
133
new_prog .graph = new_gm .graph
@@ -231,7 +315,7 @@ def __init__(
231
315
self .after_to_edge_passes = after_to_edge_passes
232
316
233
317
def transform (self , * passes : PassType ) -> "ExirExportedProgram" :
234
- self .exported_program = self .exported_program . _transform ( * passes )
318
+ self .exported_program = _transform ( self .exported_program , * passes )
235
319
return self
236
320
237
321
def __call__ (self , * args : Any ) -> Any :
@@ -419,7 +503,7 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
419
503
new_ep .exported_program = ExportedProgram (
420
504
new_gm ,
421
505
new_gm .graph ,
422
- new_ep .exported_program .graph_signature ,
506
+ _get_updated_graph_signature ( new_ep .exported_program .graph_signature , new_gm ) ,
423
507
new_ep .exported_program .state_dict ,
424
508
new_ep .exported_program .range_constraints ,
425
509
new_ep .exported_program .equality_constraints ,
@@ -755,7 +839,9 @@ def to_edge(
755
839
edge_program = ExportedProgram (
756
840
root = gm ,
757
841
graph = gm .graph ,
758
- graph_signature = edge_program .graph_signature ,
842
+ graph_signature = _get_updated_graph_signature (
843
+ edge_program .graph_signature , gm
844
+ ),
759
845
state_dict = edge_program .state_dict ,
760
846
range_constraints = edge_program .range_constraints ,
761
847
equality_constraints = edge_program .equality_constraints ,
@@ -770,7 +856,7 @@ def to_edge(
770
856
)
771
857
passes = []
772
858
passes .extend (aten_to_edge_passes .passes [- 2 :])
773
- edge_program = edge_program . _transform (* passes )
859
+ edge_program = _transform (edge_program , * passes )
774
860
edge_programs [name ] = edge_program
775
861
return EdgeProgramManager (edge_programs , constant_methods , config )
776
862
@@ -856,7 +942,7 @@ def transform(
856
942
if isinstance (passes , dict ):
857
943
for name , program in self ._edge_programs .items ():
858
944
if name in passes .keys ():
859
- new_programs [name ] = program . _transform (* passes [name ])
945
+ new_programs [name ] = _transform (program , * passes [name ])
860
946
EXIREdgeDialectVerifier (enable = check_ir_validity )(
861
947
new_programs [name ].graph_module
862
948
)
@@ -865,7 +951,7 @@ def transform(
865
951
866
952
else : # apply passes to every method
867
953
for name , program in self ._edge_programs .items ():
868
- new_programs [name ] = program . _transform (* passes )
954
+ new_programs [name ] = _transform (program , * passes )
869
955
EXIREdgeDialectVerifier (enable = check_ir_validity )(
870
956
new_programs [name ].graph_module
871
957
)
0 commit comments