11
11
from executorch .exir .dynamic_shape import DynamicMemoryPlanningMode
12
12
from executorch .exir .emit import emit_program , EmitterOutput
13
13
from executorch .exir .error import ExportError , ExportErrorType , InternalError
14
- from executorch .exir .graph_module import (
15
- attach_export_graph_metadata ,
16
- EXIR_METADATA ,
17
- get_exir_meta ,
18
- make_export_graph_module ,
19
- reduce_graph_module ,
20
- )
21
14
from executorch .exir .pass_manager import PassManager , PassType
22
15
from executorch .exir .passes import (
23
16
aten_to_edge_passes ,
@@ -119,35 +112,17 @@ class ExecutorchBackendConfig:
119
112
120
113
121
114
# TODO(ycao): set up "__all__" to limit symbol exposure
122
- def _to_edge (expo_prog , config : EdgeCompileConfig ) -> "ExirExportedProgram" :
123
- meta = get_exir_meta (expo_prog .graph_module )
115
+ def _to_edge (ep , config : EdgeCompileConfig ) -> "ExirExportedProgram" :
124
116
if config ._check_ir_validity :
125
117
try :
126
- EXIRATenDialectVerifier ()(expo_prog .graph_module )
118
+ EXIRATenDialectVerifier ()(ep .graph_module )
127
119
except ExportError :
128
120
logging .info (
129
121
"If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, "
130
122
"like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))."
131
123
)
132
124
raise
133
- mutation = meta .mutation
134
- num_mutated = len (mutation ) if mutation is not None else 0
135
- output_node = next (iter (reversed (expo_prog .graph_module .graph .nodes )))
136
- assert output_node .op == "output"
137
- output_node .args = (output_node .args [0 ][num_mutated :],)
138
- expo_prog .graph_module .graph .eliminate_dead_code ()
139
125
140
- ep = ExirExportedProgram (
141
- expo_prog .graph_module ,
142
- expo_prog .graph_module .graph ,
143
- ExportGraphSignature ([], [], [], [], {}, {}, {}, None ),
144
- CallSpec (meta .in_spec , meta .out_spec ),
145
- {},
146
- {},
147
- [],
148
- False ,
149
- )
150
- attach_export_graph_metadata (ep .graph_module , meta )
151
126
op_replace_pass = [OpReplacePass ()] if config ._use_edge_ops else []
152
127
passes = aten_to_edge_passes .passes + op_replace_pass + config .passes
153
128
new_ep = ep .transform (* passes )
@@ -198,7 +173,6 @@ def transform(self, *passes: PassType) -> "ExirExportedProgram":
198
173
ep .equality_constraints ,
199
174
self .after_to_edge_passes ,
200
175
)
201
-
202
176
transformed_ep .graph_module .meta .update (ep .graph_module .meta )
203
177
transformed_ep .graph_module .meta .update (self .graph_module .meta )
204
178
return transformed_ep
@@ -227,24 +201,6 @@ def _to_server(
227
201
# return graph_module now.
228
202
return res .graph_module
229
203
230
- @property
231
- def meta (self ):
232
- return self .graph_module .meta
233
-
234
- @property
235
- def in_spec (self ):
236
- meta = get_exir_meta (self .graph_module )
237
- return meta .in_spec
238
-
239
- @property
240
- def out_spec (self ):
241
- meta = get_exir_meta (self .graph_module )
242
- return meta .out_spec
243
-
244
- @property
245
- def graph (self ):
246
- return self .graph_module .graph
247
-
248
204
@property
249
205
def code (self ):
250
206
return self .graph_module .code
@@ -257,7 +213,6 @@ def to_executorch(
257
213
raise RuntimeError ("Must run to_edge before to_executorch." )
258
214
config = config or ExecutorchBackendConfig ()
259
215
new_prog = self .transform (* edge_to_executorch_passes (config ))
260
- meta = get_exir_meta (new_prog .graph_module )
261
216
executorch_prog = ExecutorchProgram (
262
217
new_prog ,
263
218
emit_stacktrace = config .emit_stacktrace ,
@@ -266,20 +221,10 @@ def to_executorch(
266
221
constant_tensor_alignment = config .constant_tensor_alignment ,
267
222
delegate_alignment = config .delegate_alignment ,
268
223
)
269
- attach_export_graph_metadata (executorch_prog .graph_module , meta )
270
- # We only need to update the meta of the root graph module since it is
271
- # reconstructed in ExecutorchProgram. The submodules are exactly the
272
- # original submodules in new_prog.
273
224
executorch_prog .graph_module .meta .update (new_prog .graph_module .meta )
274
225
executorch_prog .graph_module .meta .update (self .graph_module .meta )
275
226
return executorch_prog
276
227
277
- def __reduce__ (
278
- self ,
279
- ) -> Tuple [Callable [..., "ExirExportedProgram" ], Tuple [bytes ]]:
280
- _ , (pickled_states ,) = self .graph_module .__reduce__ ()
281
- return (edge_gm_deserializer , (pickled_states ,))
282
-
283
228
def __deepcopy__ (self , memo : Optional [Dict [int , Any ]] = None ) -> "ExportedProgram" :
284
229
gm = self .graph_module .__deepcopy__ (memo )
285
230
new_ep = ExirExportedProgram (
@@ -292,9 +237,6 @@ def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "ExportedProgra
292
237
copy .deepcopy (self .equality_constraints ),
293
238
self .after_to_edge_passes ,
294
239
)
295
- attach_export_graph_metadata (
296
- new_ep .graph_module , get_exir_meta (self .graph_module )
297
- )
298
240
new_ep .graph_module .meta .update (self .graph_module .meta )
299
241
return new_ep
300
242
@@ -418,27 +360,6 @@ def get_multi_method_graph_module(self) -> "MultiMethodExirExportedProgram":
418
360
return self ._executorch_dialect_ir_program
419
361
420
362
421
- def edge_gm_deserializer (pickled_states : bytes ) -> "ExirExportedProgram" :
422
- loaded_gm = reduce_graph_module (pickled_states )
423
- # restore node.meta["val"], which is deleted before pickling
424
- annotated_gm = aten_to_edge_passes (loaded_gm ).graph_module
425
- meta = get_exir_meta (annotated_gm )
426
-
427
- ep = ExirExportedProgram (
428
- annotated_gm .module ,
429
- annotated_gm .graph ,
430
- ExportGraphSignature ([], [], [], [], {}, {}, {}, None ),
431
- CallSpec (meta .in_spec , meta .out_spec ),
432
- {},
433
- {},
434
- [],
435
- after_to_edge_passes = True ,
436
- )
437
- attach_export_graph_metadata (ep .graph_module , meta )
438
- ep .graph_module .meta .update (annotated_gm .graph_module .meta )
439
- return ep
440
-
441
-
442
363
# This is to bootstrap the missing meta["val"] when 1. ph consists of scalar
443
364
# 2. meta["val"] is not properly set in dispatch_trace.
444
365
def _instantiate_missing_placeholder_val_with_real_inputs (gm , args ):
@@ -468,7 +389,6 @@ def capture(
468
389
)
469
390
470
391
config = config or CaptureConfig ()
471
- mutation = None
472
392
out_spec = None
473
393
# TODO (zhxchen17) Always functionalize in a second pass no matter which path is taken.
474
394
flat_args = tuple (pytree .tree_flatten (args )[0 ])
@@ -578,15 +498,8 @@ def convert_to_fake(x):
578
498
tracing_mode = tracing_mode ,
579
499
_allow_non_fake_inputs = True ,
580
500
)(* args )
581
- module = make_export_graph_module (
582
- graph_module ,
583
- graph_module .graph ,
584
- )
585
501
586
- flatten_output (module )
587
- meta = get_exir_meta (module )
588
- meta .in_spec = in_spec
589
- meta .out_spec = out_spec
502
+ flatten_output (graph_module )
590
503
591
504
else :
592
505
warnings .warn (
@@ -604,24 +517,22 @@ def convert_to_fake(x):
604
517
raise InternalError (
605
518
"Using AOT mode is not supported for leagacy capture mode, please use pt2_mode=True instead."
606
519
)
607
- module = dispatch_trace (f , args )
608
- # TODO(shunting) move this into ExportModuleState
609
- meta = module .meta [EXIR_METADATA ]
610
- _instantiate_missing_placeholder_val_with_real_inputs (module , flat_args )
611
- module ._apply (torch .Tensor .contiguous )
612
- meta = get_exir_meta (module )
613
- meta .mutation = mutation # pyre-ignore
520
+ graph_module = dispatch_trace (f , args )
521
+ in_spec , out_spec = graph_module .in_spec , graph_module .out_spec
522
+
523
+ _instantiate_missing_placeholder_val_with_real_inputs (graph_module , flat_args )
524
+ graph_module ._apply (torch .Tensor .contiguous )
525
+
614
526
ep = ExirExportedProgram (
615
- module ,
616
- module .graph ,
527
+ graph_module ,
528
+ graph_module .graph ,
617
529
ExportGraphSignature ([], [], [], [], {}, {}, {}, None ),
618
- CallSpec (meta . in_spec , meta . out_spec ),
530
+ CallSpec (in_spec , out_spec ),
619
531
{},
620
532
{},
621
533
[],
622
534
False ,
623
535
)
624
- attach_export_graph_metadata (ep .graph_module , meta )
625
536
return ep
626
537
627
538
@@ -720,18 +631,6 @@ def access_property_of_default_method(self, property_name: str):
720
631
"to access property: { property_name } ."""
721
632
return getattr (default_program .graph_module , property_name )
722
633
723
- @property
724
- def meta (self ):
725
- return self .access_property_of_default_method ("meta" )
726
-
727
- @property
728
- def in_spec (self ):
729
- return self .meta [EXIR_METADATA ].in_spec
730
-
731
- @property
732
- def out_spec (self ):
733
- return self .meta [EXIR_METADATA ].out_spec
734
-
735
634
@property
736
635
def graph (self ):
737
636
return self .access_property_of_default_method ("graph" )
0 commit comments