@@ -89,9 +89,10 @@ def __init__(
        dynamic_shapes: Optional[Any] = None,
    ):
        self.model = model
-        self.exported_program: Optional[ExportedProgram] = None
-        # Self.exported_program's pre-autograd graph module, for running
-        # transform passes on the graph prior to torch.export().
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
        self.modelname = modelname
        self.max_seq_len = max_seq_len
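The comment added above makes pre_autograd_graph_module the working copy of the export. As a rough illustration of the re-export pattern it describes, here is a minimal sketch with an invented toy model (not code from this PR):

    import torch
    from torch.export import export_for_training

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x + 1)

    example_inputs = (torch.randn(2, 4),)

    # First export: keep the pre-autograd graph module as the working copy.
    graph_module = export_for_training(Toy(), example_inputs).module()

    # ... run transform passes that mutate graph_module in place ...

    # To obtain an ExportedProgram that reflects those edits, export the graph
    # module again rather than reusing the original ExportedProgram.
    exported_program = export_for_training(graph_module, example_inputs)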
@@ -186,21 +187,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
        )
        return edge_config

-    def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
-        """
-        Exports the model pre-autograd. This is not a full export, since it uses
-        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
-        The full torch.export() if called later on during to_edge() or
-        to_edge_transform_and_lower().
-
-        The optional `module` argument is included so that the user can re-export
-        an already-exported module's ExportedProgram's graph module, to persiste
-        the changes into a new ExportedProgram.
-
-        Args:
-            module (Optional[torch.nn.Module]): module to export.
-
-        """
+    def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
        dynamic_shape = self._get_dynamic_shape()
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -237,14 +224,22 @@ def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
                kwargs=self.example_kwarg_inputs,
                dynamic_shapes=dynamic_shape,
            )
-        self.exported_program = exported_module
-        # Need to store the graph module to record transformation passes.
-        # Persisting those changes back to the ExportedProgram will require
-        # an additional export().
-        self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        return exported_module

+    def export(self) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later on during to_edge() or
+        to_edge_transform_and_lower().
+        """
+        exported_module = self._export()
+        # Need to store the graph module to record transformation passes.
+        # Persisting those changes back to an ExportedProgram will require
+        # an additional export().
+        self.pre_autograd_graph_module = exported_module.module()
+        if hasattr(self.args, "export_only") and self.args.export_only:
+            torch.export.save(exported_module, self.args.output_name)
        return self

    def run_canonical_optimizations(self):
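To summarize the split introduced in this hunk: the private _export() now only produces an ExportedProgram, while the public export() caches its pre-autograd graph module and returns the manager for chaining. A hedged sketch of that shape, where the stand-in class and its fields are invented for illustration and only the two-method split mirrors the diff:

    from typing import Optional

    import torch
    from torch.export import ExportedProgram, export_for_training

    class TinyManager:
        def __init__(self, model: torch.nn.Module, example_inputs):
            self.model = model
            self.example_inputs = example_inputs
            self.pre_autograd_graph_module: Optional[torch.nn.Module] = None

        def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
            # Export either the original model or an already-transformed
            # graph module that needs to be re-exported.
            target = module if module is not None else self.model
            return export_for_training(target, self.example_inputs)

        def export(self) -> "TinyManager":
            # The public entry point only caches the pre-autograd graph module;
            # a fresh ExportedProgram is produced on demand via _export().
            self.pre_autograd_graph_module = self._export().module()
            return self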
@@ -403,7 +398,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.exported_program is None:
+            if self.pre_autograd_graph_module is None:
                # Run export() if it didn't run
                self.export()

@@ -415,12 +410,9 @@ def export_to_edge(self) -> "LLMEdgeManager":
                    return_value=False,
                )

-            # Prior to export, persist the changes to the pre autograd
-            # graph module back to the source-of-truth ExportedProgram.
-            self.export(self.pre_autograd_graph_module)
            with override_export_behaviour:
                self.edge_manager = export_to_edge(
-                    self.exported_program.module(),  # pyre-fixme[6]
+                    self.pre_autograd_graph_module,  # pyre-fixme[6]
                    self.example_inputs,
                    example_kwarg_inputs=self.example_kwarg_inputs,
                    dynamic_shapes=dynamic_shape,
@@ -466,13 +458,12 @@ def to_edge_transform_and_lower(
        if partitioners is None:
            logging.info("No partitioner provided, skipping backend lowering...")

-        # Prior to export, persist the changes to the pre autograd
-        # graph module back to the source-of-truth ExportedProgram.
-        self.export(self.pre_autograd_graph_module)
+        # Need to construct ExportedProgram with the new transformed graph module.
+        exported_module = self._export(self.pre_autograd_graph_module)

        edge_config = self._get_edge_config()
        self.edge_manager = to_edge_transform_and_lower(
-            self.exported_program,
+            exported_module,
            partitioner=partitioners,
            compile_config=edge_config,
            constant_methods=self.metadata,
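The same re-export-then-lower sequence, pulled out of the class for clarity: a minimal end-to-end sketch with a toy model, where the executorch import path and the default EdgeCompileConfig are assumptions standing in for the manager's _get_edge_config() and surroundings, and partitioner=None exercises the "skip backend lowering" path noted in the hunk above:

    import torch
    from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
    from torch.export import export_for_training

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x * 2

    example_inputs = (torch.randn(2, 4),)

    # Pre-autograd export; graph passes would edit this module in place.
    graph_module = export_for_training(Toy(), example_inputs).module()

    # Re-export the (possibly transformed) graph module, then lower it.
    exported_program = export_for_training(graph_module, example_inputs)
    edge_manager = to_edge_transform_and_lower(
        exported_program,
        partitioner=None,  # no backend lowering in this sketch
        compile_config=EdgeCompileConfig(),
    )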