
Commit affe92d

Fix pre-autograd transforms not getting persisted during xnnpack export
1 parent cf8ce89 commit affe92d

File tree (1 file changed: +39, -10)

extension/llm/export/builder.py

Lines changed: 39 additions & 10 deletions
@@ -89,7 +89,9 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        self.exported_program: Optional[ExportedProgram] = None
+        # Self.exported_program's pre-autograd graph module, for running
+        # transform passes on the graph prior to torch.export().
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
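
For context on the two attributes this hunk touches, here is a minimal sketch (toy model, plain PyTorch; assumes a build that exposes torch.export.export_for_training, which builder.py relies on) of how they relate:

import torch
from torch.export import export_for_training

# Stand-in for the LLM that self.model would hold.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
example_inputs = (torch.randn(1, 4),)

# What self.exported_program ends up holding after export():
# a pre-autograd ExportedProgram.
exported_program = export_for_training(model, example_inputs)

# What self.pre_autograd_graph_module holds: an unlifted copy of that
# program's graph, which transform passes mutate before the full
# torch.export() happens in to_edge()/to_edge_transform_and_lower().
pre_autograd_graph_module = exported_program.module()
print(pre_autograd_graph_module.graph)  # an editable torch.fx graph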
@@ -184,7 +186,21 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config

-    def export(self) -> "LLMEdgeManager":
+    def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later on, during to_edge() or
+        to_edge_transform_and_lower().
+
+        The optional `module` argument is included so that the user can re-export
+        an already-exported module's ExportedProgram's graph module, to persist
+        the changes into a new ExportedProgram.
+
+        Args:
+            module (Optional[torch.nn.Module]): module to export.
+
+        """
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -201,25 +217,30 @@ def export(self) -> "LLMEdgeManager":
                 # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
                 # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
                 exported_module = torch.export.export(
-                    self.model,
+                    self.model if not module else module,
                     self.example_inputs,
                     self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                     strict=True,
                 )
             else:
-                logging.info("Exporting with:")
+                if module:
+                    logging.info("Re-exporting with:")
+                else:
+                    logging.info("Exporting with:")
                 logging.info(f"inputs: {self.example_inputs}")
                 logging.info(f"kwargs: {self.example_kwarg_inputs}")
                 logging.info(f"dynamic shapes: {dynamic_shape}")
                 exported_module = export_for_training(
-                    self.model,
+                    self.model if not module else module,
                     self.example_inputs,
                     kwargs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                 )
-            # `Module`.
-            self.pre_autograd_exported_program = exported_module
+            self.exported_program = exported_module
+            # Need to store the graph module to record transformation passes.
+            # Persisting those changes back to the ExportedProgram will require
+            # an additional export().
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
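
The `module` argument matters because edits made to the `.module()` copy do not flow back into the ExportedProgram it came from; re-exporting the edited module is what persists them. A minimal sketch of that behavior (toy model, plain PyTorch; the add-to-mul rewrite is an arbitrary stand-in for a real transform pass, not something builder.py does):

import torch
from torch.export import export_for_training

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return x + y

example_inputs = (torch.randn(2), torch.randn(2))
ep = export_for_training(Toy(), example_inputs)
gm = ep.module()

# A toy "transform pass": rewrite add -> mul on the unlifted graph module.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
        node.target = torch.ops.aten.mul.Tensor
gm.recompile()

def has_op(graph, op):
    return any(n.op == "call_function" and n.target == op for n in graph.nodes)

# The source-of-truth ExportedProgram still contains the old op...
print(has_op(ep.graph, torch.ops.aten.add.Tensor))  # True

# ...until the transformed module is re-exported, which is what
# export(module=self.pre_autograd_graph_module) now does.
ep = export_for_training(gm, example_inputs)
print(has_op(ep.graph, torch.ops.aten.add.Tensor))  # False
print(has_op(ep.graph, torch.ops.aten.mul.Tensor))  # True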
@@ -382,7 +403,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.pre_autograd_graph_module is None:
+            if self.exported_program is None:
                 # Run export() if it didn't run
                 self.export()

@@ -394,9 +415,12 @@ def export_to_edge(self) -> "LLMEdgeManager":
                     return_value=False,
                 )

+            # Prior to export, persist the changes to the pre autograd
+            # graph module back to the source-of-truth ExportedProgram.
+            self.export(self.pre_autograd_graph_module)
             with override_export_behaviour:
                 self.edge_manager = export_to_edge(
-                    self.pre_autograd_graph_module,  # pyre-fixme[6]
+                    self.exported_program.module(),  # pyre-fixme[6]
                     self.example_inputs,
                     example_kwarg_inputs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
@@ -441,9 +465,14 @@ def to_edge_transform_and_lower(
     ) -> "LLMEdgeManager":
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
+
+        # Prior to export, persist the changes to the pre autograd
+        # graph module back to the source-of-truth ExportedProgram.
+        self.export(self.pre_autograd_graph_module)
+
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            self.pre_autograd_exported_program,
+            self.exported_program,
            partitioner=partitioners,
            compile_config=edge_config,
            constant_methods=self.metadata,
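
Taken together, the export, transform, re-export, lower ordering that this commit enforces looks roughly like the following (toy model, plain PyTorch; the transform step is elided and the executorch lowering call is left as a comment, since its import is not part of this diff):

import torch
from torch.export import export_for_training

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))

example_inputs = (torch.randn(2, 8),)

# 1. Pre-autograd export (LLMEdgeManager.export()).
exported_program = export_for_training(Toy(), example_inputs)
pre_autograd_graph_module = exported_program.module()

# 2. Transform passes (quantization, source transforms, ...) mutate
#    pre_autograd_graph_module in place.

# 3. Re-export so the mutations are persisted into the source-of-truth
#    ExportedProgram; this is the self.export(self.pre_autograd_graph_module)
#    call that export_to_edge() and to_edge_transform_and_lower() now make.
exported_program = export_for_training(pre_autograd_graph_module, example_inputs)

# 4. Lowering then consumes the refreshed program, e.g. in builder.py:
#    to_edge_transform_and_lower(self.exported_program, partitioner=..., ...)
#    so an XNNPACK partitioner sees the transformed graph rather than the
#    stale original one, which is the bug this commit fixes.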

0 commit comments
