Skip to content

Commit 2f6d64f

Browse files
authored
torch.export()-only export Llama arg (#6695)
1 parent 49756f6 commit 2f6d64f

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,13 @@ def build_args_parser() -> argparse.ArgumentParser:
451451
default=None,
452452
help="path to the input pruning token mapping file (token_map.json)",
453453
)
454+
455+
parser.add_argument(
456+
"--export_only",
457+
default=False,
458+
action="store_true",
459+
help="If true, stops right after torch.export() and saves the exported model.",
460+
)
454461
return parser
455462

456463

@@ -625,12 +632,14 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
625632
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
626633

627634
# export_to_edge
628-
builder_exported_to_edge = (
629-
_prepare_for_llama_export(modelname, args)
630-
.export()
631-
.pt2e_quantize(quantizers)
632-
.export_to_edge()
633-
)
635+
builder_exported = _prepare_for_llama_export(modelname, args).export()
636+
637+
if args.export_only:
638+
exit()
639+
640+
builder_exported_to_edge = builder_exported.pt2e_quantize(
641+
quantizers
642+
).export_to_edge()
634643

635644
modelname = builder_exported_to_edge.modelname
636645

extension/llm/export/builder.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,22 +186,25 @@ def export(self) -> "LLMEdgeManager":
186186
# functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
187187
# pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
188188
# `Module`.
189-
self.pre_autograd_graph_module = torch.export.export(
189+
exported_module = torch.export.export(
190190
self.model,
191191
self.example_inputs,
192192
self.example_kwarg_inputs,
193193
dynamic_shapes=dynamic_shape,
194194
strict=True,
195-
).module()
195+
)
196196
else:
197197
# pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
198198
# `Module`.
199-
self.pre_autograd_graph_module = export_for_training(
199+
exported_module = export_for_training(
200200
self.model,
201201
self.example_inputs,
202202
kwargs=self.example_kwarg_inputs,
203203
dynamic_shapes=dynamic_shape,
204-
).module()
204+
)
205+
self.pre_autograd_graph_module = exported_module.module()
206+
if hasattr(self.args, "export_only") and self.args.export_only:
207+
torch.export.save(exported_module, self.args.output_name)
205208

206209
return self
207210

0 commit comments

Comments
 (0)