@@ -89,7 +89,9 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        self.exported_program: Optional[ExportedProgram] = None
+        # self.exported_program's pre-autograd graph module, for running
+        # transform passes on the graph prior to torch.export().
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
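For context, a minimal sketch (not part of this PR; the toy module and names are assumptions) of how the two attributes above relate: `exported_program` is intended to hold the source-of-truth torch.export artifact, while `pre_autograd_graph_module` is its unlifted module view that transform passes mutate.

```python
import torch
from torch.export import export_for_training

class TinyModel(torch.nn.Module):  # hypothetical stand-in for self.model
    def forward(self, x):
        return torch.relu(x)

# Roughly what export() stores into the two attributes:
exported_program = export_for_training(TinyModel(), (torch.randn(2, 4),))
pre_autograd_graph_module = exported_program.module()  # mutable view used for passes
```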
@@ -184,7 +186,21 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config

-    def export(self) -> "LLMEdgeManager":
+    def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later, during to_edge() or
+        to_edge_transform_and_lower().
+
+        The optional `module` argument lets the user re-export an already-exported
+        module's ExportedProgram's graph module, to persist the changes into a new
+        ExportedProgram.
+
+        Args:
+            module (Optional[torch.nn.Module]): module to export.
+
+        """
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
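A hedged usage sketch of the re-export pattern the new `module` parameter enables (the toy model and the in-place pass are assumptions, not executorch code): edits made to the graph module returned by `ExportedProgram.module()` do not flow back into the ExportedProgram, so a second export is needed to persist them.

```python
import torch
from torch.export import export_for_training

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
inputs = (torch.randn(1, 8),)

ep = export_for_training(model, inputs)  # initial pre-autograd export
gm = ep.module()                         # graph module that transform passes mutate
# ... graph passes would edit `gm` in place here ...
ep = export_for_training(gm, inputs)     # re-export to persist the edits
```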
@@ -201,25 +217,30 @@ def export(self) -> "LLMEdgeManager":
                 # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
                 # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
                 exported_module = torch.export.export(
-                    self.model,
+                    self.model if not module else module,
                     self.example_inputs,
                     self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                     strict=True,
                 )
             else:
-                logging.info("Exporting with:")
+                if module:
+                    logging.info("Re-exporting with:")
+                else:
+                    logging.info("Exporting with:")
                 logging.info(f"inputs: {self.example_inputs}")
                 logging.info(f"kwargs: {self.example_kwarg_inputs}")
                 logging.info(f"dynamic shapes: {dynamic_shape}")
                 exported_module = export_for_training(
-                    self.model,
+                    self.model if not module else module,
                     self.example_inputs,
                     kwargs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                 )
-            # `Module`.
-            self.pre_autograd_exported_program = exported_module
+            self.exported_program = exported_module
+            # Need to store the graph module to record transformation passes.
+            # Persisting those changes back to the ExportedProgram will require
+            # an additional export().
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
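For reference, a minimal sketch of the two export paths taken in the hunk above (the toy model is an assumption): the qnn branch uses a strict torch.export.export() to get a functional graph, while the default branch uses export_for_training(); either result can be written out with torch.export.save(), as the export_only branch does.

```python
import torch
from torch.export import export, export_for_training, save

model = torch.nn.Linear(4, 4)
example_inputs = (torch.randn(1, 4),)

ep_strict = export(model, example_inputs, strict=True)    # qnn path: strict, functional graph
ep_training = export_for_training(model, example_inputs)  # default path: pre-autograd IR
save(ep_training, "model.pt2")                            # analogous to the export_only branch
```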
@@ -382,7 +403,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.pre_autograd_graph_module is None:
+            if self.exported_program is None:
                 # Run export() if it didn't run
                 self.export()

@@ -394,9 +415,12 @@ def export_to_edge(self) -> "LLMEdgeManager":
                     return_value=False,
                 )

+            # Prior to export, persist the changes to the pre-autograd
+            # graph module back to the source-of-truth ExportedProgram.
+            self.export(self.pre_autograd_graph_module)
             with override_export_behaviour:
                 self.edge_manager = export_to_edge(
-                    self.pre_autograd_graph_module,  # pyre-fixme[6]
+                    self.exported_program.module(),  # pyre-fixme[6]
                     self.example_inputs,
                     example_kwarg_inputs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
@@ -441,9 +465,14 @@ def to_edge_transform_and_lower(
     ) -> "LLMEdgeManager":
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
+
+        # Prior to export, persist the changes to the pre-autograd
+        # graph module back to the source-of-truth ExportedProgram.
+        self.export(self.pre_autograd_graph_module)
+
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            self.pre_autograd_exported_program,
+            self.exported_program,
             partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,
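An end-to-end sketch of the flow these last hunks set up, assuming a toy model and no partitioner (the model, inputs, and omission of passes are assumptions): to_edge_transform_and_lower is now fed the re-exported ExportedProgram rather than the raw graph module.

```python
import torch
from torch.export import export_for_training
from executorch.exir import to_edge_transform_and_lower

model = torch.nn.Linear(16, 16)
inputs = (torch.randn(1, 16),)

ep = export_for_training(model, inputs)  # pre-autograd export, as in export()
gm = ep.module()                         # graph module that passes would mutate
ep = export_for_training(gm, inputs)     # persist edits back into an ExportedProgram
edge = to_edge_transform_and_lower(ep)   # lower from the source-of-truth ExportedProgram
```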