
Commit 9978148

Graph module as SOT
1 parent affe92d commit 9978148

File tree

1 file changed: +25 -34 lines changed

extension/llm/export/builder.py

Lines changed: 25 additions & 34 deletions
@@ -89,9 +89,10 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.exported_program: Optional[ExportedProgram] = None
-        # Self.exported_program's pre-autograd graph module, for running
-        # transform passes on the graph prior to torch.export().
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
@@ -186,21 +187,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config
 
-    def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
-        """
-        Exports the model pre-autograd. This is not a full export, since it uses
-        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
-        The full torch.export() if called later on during to_edge() or
-        to_edge_transform_and_lower().
-
-        The optional `module` argument is included so that the user can re-export
-        an already-exported module's ExportedProgram's graph module, to persiste
-        the changes into a new ExportedProgram.
-
-        Args:
-            module (Optional[torch.nn.Module]): module to export.
-
-        """
+    def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -237,14 +224,22 @@ def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
                     kwargs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                 )
-            self.exported_program = exported_module
-            # Need to store the graph module to record transformation passes.
-            # Persisting those changes back to the ExportedProgram will require
-            # an additional export().
-            self.pre_autograd_graph_module = exported_module.module()
-            if hasattr(self.args, "export_only") and self.args.export_only:
-                torch.export.save(exported_module, self.args.output_name)
+        return exported_module
 
+    def export(self) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() if called later on during to_edge() or
+        to_edge_transform_and_lower().
+        """
+        exported_module = self._export()
+        # Need to store the graph module to record transformation passes.
+        # Persisting those changes back to an ExportedProgram will require
+        # an additional export().
+        self.pre_autograd_graph_module = exported_module.module()
+        if hasattr(self.args, "export_only") and self.args.export_only:
+            torch.export.save(exported_module, self.args.output_name)
         return self
 
     def run_canonical_optimizations(self):
@@ -403,7 +398,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.exported_program is None:
+            if self.pre_autograd_graph_module is None:
                 # Run export() if it didn't run
                 self.export()

@@ -415,12 +410,9 @@ def export_to_edge(self) -> "LLMEdgeManager":
                     return_value=False,
                 )
 
-            # Prior to export, persist the changes to the pre autograd
-            # graph module back to the source-of-truth ExportedProgram.
-            self.export(self.pre_autograd_graph_module)
             with override_export_behaviour:
                 self.edge_manager = export_to_edge(
-                    self.exported_program.module(),  # pyre-fixme[6]
+                    self.pre_autograd_graph_module,  # pyre-fixme[6]
                     self.example_inputs,
                     example_kwarg_inputs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
@@ -466,13 +458,12 @@ def to_edge_transform_and_lower(
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
 
-        # Prior to export, persist the changes to the pre autograd
-        # graph module back to the source-of-truth ExportedProgram.
-        self.export(self.pre_autograd_graph_module)
+        # Need to construct ExportedProgram with the new transformed graph module.
+        exported_module = self._export(self.pre_autograd_graph_module)
 
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            self.exported_program,
+            exported_module,
             partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,
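
For context, a minimal sketch (not part of this commit) of the torch.export round-trip the new comment describes: ExportedProgram.module() returns a graph module that transform passes can mutate, an existing ExportedProgram does not see those edits, and re-exporting the graph module is what persists them into a fresh ExportedProgram. The Toy module and example inputs below are made-up placeholders, and plain torch.export.export is used here to keep the sketch self-contained, whereas the builder uses export_for_training for its pre-autograd export.

# Hypothetical standalone example; `Toy` and its inputs are placeholders.
import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


example_inputs = (torch.randn(2, 3),)

# First export produces an ExportedProgram.
ep = torch.export.export(Toy(), example_inputs)

# The graph module: the kind of object LLMEdgeManager now treats as the
# source of truth and runs its transform passes over.
gm = ep.module()

# ... graph transformations would be applied to `gm` here ...

# `ep` does not reflect edits made to `gm`, so to obtain an ExportedProgram
# that carries the transformed graph (e.g. for to_edge_transform_and_lower),
# the graph module has to be re-exported, which is what _export() does above.
ep_after_transforms = torch.export.export(gm, example_inputs)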
