@@ -89,7 +89,10 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
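
The note above is the crux of this change: transformation passes mutate the stored graph module in place, and only a fresh export captures those edits in an ExportedProgram. A minimal standalone sketch of that pattern (the toy model and the relu-to-gelu "pass" are illustrative, not from this repo):

import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

example_inputs = (torch.randn(2, 4),)
# Keep the unlifted graph module as the working copy, mirroring
# pre_autograd_graph_module above.
gm = torch.export.export(TinyModel(), example_inputs).module()

# A toy pass: rewrite aten.relu -> aten.gelu directly on the graph module.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.relu.default:
        node.target = torch.ops.aten.gelu.default
gm.recompile()

# The edit lives only on gm; re-exporting persists it into an ExportedProgram.
reexported = torch.export.export(gm, example_inputs)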
@@ -184,7 +187,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config

-    def export(self) -> "LLMEdgeManager":
+    def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
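
_export() threads the result of _get_dynamic_shape() into both export paths below. For reference, a minimal sketch of the torch.export dynamic-shapes spec this relies on (the token model and the max length of 128 are illustrative, not taken from this diff):

import torch
from torch.export import Dim

class TinyModel(torch.nn.Module):
    def forward(self, tokens):
        return tokens * 2

# Mark dim 1 (sequence length) of the first positional input as dynamic,
# bounded by a maximum length, in the spirit of what _get_dynamic_shape()
# produces from max_seq_len.
seq = Dim("seq", max=128)
example_inputs = (torch.zeros(1, 8, dtype=torch.long),)
exported = torch.export.export(
    TinyModel(),
    example_inputs,
    dynamic_shapes=({1: seq},),
)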
@@ -201,29 +204,42 @@ def export(self) -> "LLMEdgeManager":
             # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
             # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
             exported_module = torch.export.export(
-                self.model,
+                self.model if not module else module,
                 self.example_inputs,
                 self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
                 strict=True,
             )
         else:
-            logging.info("Exporting with:")
+            if module:
+                logging.info("Re-exporting with:")
+            else:
+                logging.info("Exporting with:")
             logging.info(f"inputs: {self.example_inputs}")
             logging.info(f"kwargs: {self.example_kwarg_inputs}")
             logging.info(f"dynamic shapes: {dynamic_shape}")
             exported_module = export_for_training(
-                self.model,
+                self.model if not module else module,
                 self.example_inputs,
                 kwargs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
             )
-        # `Module`.
-        self.pre_autograd_exported_program = exported_module
-        self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        return exported_module

+    def export(self) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later, during to_edge() or
+        to_edge_transform_and_lower().
+        """
+        exported_module = self._export()
+        # Need to store the graph module to record transformation passes.
+        # Persisting those changes back to an ExportedProgram will require
+        # an additional export().
+        self.pre_autograd_graph_module = exported_module.module()
+        if hasattr(self.args, "export_only") and self.args.export_only:
+            torch.export.save(exported_module, self.args.output_name)
         return self

     def run_canonical_optimizations(self):
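
With this split, export() only stashes the mutable graph module, while _export() owns the actual tracing; the default pre-autograd path goes through export_for_training(). A minimal standalone sketch of what that path produces, including the export_only save (the toy model and the "model.pt2" filename are illustrative):

import torch
from torch.export import export_for_training

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(2, 4),)
# Pre-autograd export: ops stay at a level that quantization and other
# transform passes can still match, rather than being fully decomposed.
exported = export_for_training(TinyModel(), example_inputs)
graph_module = exported.module()  # what the manager keeps as the source of truth

# The export_only path saves the ExportedProgram itself, not the graph module.
torch.export.save(exported, "model.pt2")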
@@ -441,9 +457,13 @@ def to_edge_transform_and_lower(
    ) -> "LLMEdgeManager":
        if partitioners is None:
            logging.info("No partitioner provided, skipping backend lowering...")
+
+        # Need to construct ExportedProgram with the new transformed graph module.
+        exported_module = self._export(self.pre_autograd_graph_module)
+
        edge_config = self._get_edge_config()
        self.edge_manager = to_edge_transform_and_lower(
-            self.pre_autograd_exported_program,
+            exported_module,
            partitioner=partitioners,
            compile_config=edge_config,
            constant_methods=self.metadata,