Commit 1b2c60c

Fix pre-autograd transforms not getting persisted during xnnpack export (#9118)
1 parent: 70d4427
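
The underlying issue: `ExportedProgram.module()` returns a self-contained copy of the graph, so transformation passes applied to that copy never flow back into the `ExportedProgram` that was previously handed to lowering. A minimal sketch of the failure mode, using a toy module rather than the actual ExecuTorch code path:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + x

ep = torch.export.export(Toy(), (torch.randn(2),))
gm = ep.module()  # self-contained copy; edits here do not touch `ep`

# Stand-in for a pre-autograd transformation pass: rewrite add -> mul.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
        node.target = torch.ops.aten.mul.Tensor
gm.recompile()

x = torch.full((2,), 3.0)
print(gm(x))           # tensor([9., 9.]) -- the transform took effect on the copy
print(ep.module()(x))  # tensor([6., 6.]) -- the ExportedProgram is unchanged

# Persisting the transform requires re-exporting the mutated graph module,
# which is what this commit makes the builder do before lowering.
fresh_ep = torch.export.export(gm, (x,))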

File tree

1 file changed: 31 additions, 11 deletions

extension/llm/export/builder.py

Lines changed: 31 additions & 11 deletions
@@ -89,7 +89,10 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
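
A sketch of the pattern the note above prescribes, with `graph_module` and `example_inputs` as illustrative stand-ins for the builder's `pre_autograd_graph_module` and `example_inputs` attributes:

# Re-export the (possibly transformed) graph module so the changes are
# captured in a fresh ExportedProgram; the names here are stand-ins.
fresh_ep = torch.export.export(graph_module, example_inputs)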
@@ -184,7 +187,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config

-    def export(self) -> "LLMEdgeManager":
+    def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -201,29 +204,42 @@ def export(self) -> "LLMEdgeManager":
             # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
             # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
             exported_module = torch.export.export(
-                self.model,
+                self.model if not module else module,
                 self.example_inputs,
                 self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
                 strict=True,
             )
         else:
-            logging.info("Exporting with:")
+            if module:
+                logging.info("Re-exporting with:")
+            else:
+                logging.info("Exporting with:")
             logging.info(f"inputs: {self.example_inputs}")
             logging.info(f"kwargs: {self.example_kwarg_inputs}")
             logging.info(f"dynamic shapes: {dynamic_shape}")
             exported_module = export_for_training(
-                self.model,
+                self.model if not module else module,
                 self.example_inputs,
                 kwargs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
             )
-        # `Module`.
-        self.pre_autograd_exported_program = exported_module
-        self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        return exported_module

+    def export(self) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later on, during to_edge() or
+        to_edge_transform_and_lower().
+        """
+        exported_module = self._export()
+        # Need to store the graph module to record transformation passes.
+        # Persisting those changes back to an ExportedProgram will require
+        # an additional export().
+        self.pre_autograd_graph_module = exported_module.module()
+        if hasattr(self.args, "export_only") and self.args.export_only:
+            torch.export.save(exported_module, self.args.output_name)
         return self

     def run_canonical_optimizations(self):
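
The two-stage flow the new docstring describes, sketched with a toy model (the real builder threads `self.example_inputs`, kwargs, and dynamic shapes through `_export()`):

import torch
from torch.export import export_for_training

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(0.1))
inputs = (torch.randn(1, 8),)

# Stage 1: pre-autograd export. Training IR keeps ops undecomposed so
# source transforms (e.g. quantization passes) can still pattern-match.
pre_autograd_gm = export_for_training(model, inputs).module()

# ... apply transformation passes to pre_autograd_gm here ...

# Stage 2: the full export happens later, in to_edge() /
# to_edge_transform_and_lower(), by re-exporting the transformed module.
final_ep = torch.export.export(pre_autograd_gm, inputs)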
@@ -441,9 +457,13 @@ def to_edge_transform_and_lower(
     ) -> "LLMEdgeManager":
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
+
+        # Need to construct ExportedProgram with the new transformed graph module.
+        exported_module = self._export(self.pre_autograd_graph_module)
+
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            self.pre_autograd_exported_program,
+            exported_module,
             partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,
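
Sketched end to end, this is roughly what the fixed lowering path now does; the XNNPACK partitioner is shown as one example, and `pre_autograd_graph_module` / `example_inputs` are illustrative names rather than the exact builder state:

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

# Re-export so passes recorded on the graph module are persisted ...
exported_module = torch.export.export(pre_autograd_graph_module, example_inputs)

# ... then lower the fresh ExportedProgram instead of a stale one.
edge_manager = to_edge_transform_and_lower(
    exported_module,
    partitioner=[XnnpackPartitioner()],
)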
