from collections import namedtuple
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+ from unittest.mock import patch

import sympy

import torch
+ import torch._export
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.emit import emit_program, EmitterOutput
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.schema import Program
from executorch.exir.serialize import serialize_to_flatbuffer
from executorch.exir.tracer import (
+     _default_decomposition_table,
    dispatch_trace,
    dynamo_trace,
    ExirDynamoConfig,
from torch._dynamo.eval_frame import Constraint
from torch._export import CallSpec, export, ExportGraphSignature
from torch._export.exported_program import ExportedProgram
+ from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
    InputDim,
    RangeConstraint,
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
+ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.utils import _pytree as pytree


Val = Any


+ def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
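+     # Mutates gm in place: converts placeholders that were lifted from
+     # params/buffers back into get_attr nodes, rebuilds the pytree
+     # input/output spec, and recurses into control-flow submodules
+     # (cond / map) that capture the same state.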
+     count = 0
+     # Step 1: turn the lifted params/buffers back into get_attr nodes
+     for node in gm.graph.nodes:
+         if node.op == "placeholder":
+             if count in inp_pos_to_param_buffer_name:
+                 with gm.graph.inserting_after(node):
+                     getattr_node = gm.graph.get_attr(
+                         inp_pos_to_param_buffer_name[count]
+                     )
+                     node.replace_all_uses_with(getattr_node)
+                     metadata = node.meta
+                     gm.graph.erase_node(node)
+                     getattr_node.meta = metadata
+             count += 1
+
+     # Step 2: Fix the input/output of the graph now that we deleted
+     # some args.
+     gm.graph.lint()
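+     # _PyTreeCodeGen / _PyTreeInfo are private FX helpers: they regenerate
+     # the module's forward() so it flattens/unflattens arguments according
+     # to the given in_spec/out_spec.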
+     names = [f"arg_{i}" for i in range(len(in_spec.children_specs))]
+     gm.graph._codegen = _PyTreeCodeGen(
+         _PyTreeInfo(
+             names,
+             in_spec,
+             out_spec,
+         )
+     )
+     gm.recompile()
+
+     # Step 3: Find state references in HigherOrderOps and recursively
+     # fix them.
+     for node in gm.graph.nodes:
+         if node.op == "call_function" and node.target == torch.ops.cond:
+             pred, true_graph, false_graph, operands = node.args
+             true_gm = getattr(gm, true_graph.name)
+             false_gm = getattr(gm, false_graph.name)
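+             # Operands that are lifted params/buffers get re-registered on
+             # both branch submodules and dropped from the call's operand list.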
+             inp_pos_to_param_buffer_name_for_submod = {}
+             real_operands = []
+             for ix, operand in enumerate(operands):
+                 if operand.target in inp_pos_to_param_buffer_name.values():
+                     inp_pos_to_param_buffer_name_for_submod[ix] = operand.target
+                     true_gm.register_buffer(operand.target, state_dict[operand.target])
+                     false_gm.register_buffer(operand.target, state_dict[operand.target])
+                 else:
+                     real_operands.append(operand)
+             node.args = (pred, true_graph, false_graph, real_operands)
+
+             _, in_spec = pytree.tree_flatten(real_operands)
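+             # Recurse into each branch with the pruned operand spec; out_spec
+             # stays None, so the submodule outputs are left as-is.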
+
+             _unlift(
+                 true_gm,
+                 inp_pos_to_param_buffer_name_for_submod,
+                 in_spec,
+                 None,
+                 state_dict,
+             )
+             _unlift(
+                 false_gm,
+                 inp_pos_to_param_buffer_name_for_submod,
+                 in_spec,
+                 None,
+                 state_dict,
+             )
+         if node.op == "call_function" and node.target.__name__ == "map_impl":
+             body_graph, num_mapped, *operands = node.args
+             body_gm = getattr(gm, body_graph.name)
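+             # Same fix-up as for cond: re-register captured state on the body
+             # submodule and strip it from the operand list.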
+             inp_pos_to_buffer_name_for_submod = {}
+             real_operands = []
+             for ix, operand in enumerate(operands):
+                 if operand.target in inp_pos_to_param_buffer_name.values():
+                     inp_pos_to_buffer_name_for_submod[ix] = operand.target
+                     body_gm.register_buffer(operand.target, state_dict[operand.target])
+                 else:
+                     real_operands.append(operand)
+             node.args = (body_graph, num_mapped, *real_operands)
+
+             _, in_spec = pytree.tree_flatten(real_operands)
+
+             _unlift(
+                 body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict
+             )
+     gm.graph.lint()
+     gm.graph.eliminate_dead_code()
+     gm.recompile()
+     return gm
+
+
+ def unlift_exported_program_lifted_states(
+     ep: torch._export.exported_program.ExportedProgram,
+ ):
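+     # Entry point: copies the ExportedProgram's graph module, registers its
+     # state_dict entries as real params/buffers ("." in names replaced by
+     # "_"), then hands off to _unlift().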
+     new_gm = copy.deepcopy(ep.graph_module)
+
+     # TODO: fix the periods in param/buffer names properly later, maybe with
+     # a pass that replaces the graph signature with the corrected names
+     param_buffer_name_to_corrected_name = {}
+
+     for name, stuff in ep.state_dict.items():
+         if name in ep.graph_signature.buffers:
+             if "." in name:
+                 new_gm.register_buffer(name.replace(".", "_"), stuff)
+                 param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
+             else:
+                 new_gm.register_buffer(name, stuff)
+         elif name in ep.graph_signature.parameters:
+             if "." in name:
+                 new_gm.register_parameter(name.replace(".", "_"), stuff)
+                 param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
+             else:
+                 new_gm.register_parameter(name, stuff)
+         else:
+             raise AssertionError("encountered an unregistered param/buffer")
+
+     count = 0
+     inp_pos_to_param_buffer_name = {}
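+     # Map each placeholder's position to the (corrected) name of the
+     # param/buffer it was lifted from.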
+     for node in new_gm.graph.nodes:
+         if node.op == "placeholder":
+             if node.name in ep.graph_signature.inputs_to_buffers:
+                 buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
+                 if buffer_name in param_buffer_name_to_corrected_name:
+                     inp_pos_to_param_buffer_name[
+                         count
+                     ] = param_buffer_name_to_corrected_name[buffer_name]
+                 else:
+                     inp_pos_to_param_buffer_name[count] = buffer_name
+             if node.name in ep.graph_signature.inputs_to_parameters:
+                 param_name = ep.graph_signature.inputs_to_parameters[node.name]
+                 if param_name in param_buffer_name_to_corrected_name:
+                     inp_pos_to_param_buffer_name[
+                         count
+                     ] = param_buffer_name_to_corrected_name[param_name]
+                 else:
+                     inp_pos_to_param_buffer_name[count] = param_name
+             count += 1
+     new_gm = _unlift(
+         new_gm,
+         inp_pos_to_param_buffer_name,
+         ep.call_spec.in_spec,
+         ep.call_spec.out_spec,
+         ep.state_dict,
+     )
+     return new_gm
+
+
@compatibility(is_backward_compatible=False)
@dataclass
class CaptureConfig:
@@ -63,6 +211,7 @@ class CaptureConfig:
    enable_dynamic_shape: bool = False
    enable_aot: bool = False
    _dynamo_config: "ExirDynamoConfig" = ExirDynamoConfig()
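+     # When True, capture() re-embeds the lifted params/buffers into the
+     # returned graph module instead of returning the lifted ExportedProgram.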
+     _unlift: bool = False


@compatibility(is_backward_compatible=False)
@@ -400,8 +549,15 @@ def capture(
                "Functionalization is required for enable_aot.",
            )

-         ep = export(f, args, _add_runtime_assertions=False, constraints=constraints)
-         return ep  # pyre-ignore
+         # TODO remove this later
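+         # Patching torch._export.DECOMP_TABLE makes export() use exir's
+         # default decomposition table for the duration of this call.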
+         with patch("torch._export.DECOMP_TABLE", _default_decomposition_table()):
+             ep = export(
+                 f, args, _add_runtime_assertions=False, constraints=constraints
+             )
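+         # Rewrite aliasing view ops into their out-of-place view_copy
+         # variants (the ExecuTorch runtime expects non-aliasing ops).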
+         ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass())
+         if not config._unlift:
+             return ep  # pyre-ignore
+         graph_module = unlift_exported_program_lifted_states(ep)

    elif config.enable_dynamic_shape:
        if not config._dynamo_config.dynamic_shapes: