
Commit d8deda2

Export llama uses to_edge_lower_and_transform
1 parent 2600cc8 commit d8deda2

File tree

2 files changed (+67, -47 lines)

examples/models/llama/export_llama_lib.py

Lines changed: 13 additions & 6 deletions

@@ -659,11 +659,12 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     if args.export_only:
         exit()
 
-    builder_exported_to_edge = builder_exported.pt2e_quantize(
-        quantizers
-    ).export_to_edge()
+    # builder_exported_to_edge = builder_exported.pt2e_quantize(
+    #     quantizers
+    # ).export_to_edge()
 
-    modelname = builder_exported_to_edge.modelname
+    # modelname = builder_exported_to_edge.modelname
+    modelname = builder_exported.modelname
 
     # to_backend
     partitioners = []
@@ -768,6 +769,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
+    breakpoint()
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -793,14 +795,19 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         )
         logging.info("Generated etrecord.bin")
     else:
-        builder = builder_exported_to_edge.to_backend(partitioners)
+        builder_lowered = builder_exported.pt2e_quantize(
+            quantizers
+        ).to_edge_transform_and_lower(
+            partitioners
+        )
+        # builder = builder_exported_to_edge.to_backend(partitioners)
         if args.num_sharding > 0 and args.qnn:
            from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
             # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
-        builder = builder.to_executorch()
+        builder = builder_lowered.to_executorch()
 
         if args.profile_memory:
             generate_memory_trace(builder.export_program, "memory_profile.json")
extension/llm/export/builder.py

Lines changed: 54 additions & 41 deletions

@@ -21,7 +21,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager
+from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner
 
 from executorch.exir.backend.utils import format_delegated_graph
@@ -216,6 +216,7 @@ def export(self) -> "LLMEdgeManager":
             )
             # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
             #  `Module`.
+            self.pre_autograd_exported_program = exported_module
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
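The newly stored attribute matters because the two downstream consumers want different views of the same export: pt2e quantization operates on a GraphModule, while the to_edge_transform_and_lower() method added further down consumes the ExportedProgram itself. A hedged sketch of that relationship on a toy module, using plain torch.export rather than the builder's own export path:

    import torch

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    # torch.export produces an ExportedProgram; .module() gives the GraphModule view.
    exported_program = torch.export.export(Tiny(), (torch.randn(2, 4),))
    graph_module = exported_program.module()
    # The builder now keeps both: pre_autograd_exported_program (the program, used by
    # to_edge_transform_and_lower) and pre_autograd_graph_module (the module, used by prepare_pt2e).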
@@ -305,51 +306,51 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         ), "export_to_edge is already called, please call pt2e_quantize before export_to_edge"
         logging.info(f"Using pt2e {quantizers} to quantizing the model...")
 
+        if not quantizers:
+            logging.info("No quantizer provided, passing...")
+            return self
+
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
-        if quantizers:
-            with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-                if self.verbose:
-                    logging.info(f"Applied quantizers: {quantizers}")
-                composed_quantizer = ComposableQuantizer(quantizers)
-                assert (
-                    self.pre_autograd_graph_module is not None
-                ), "Please run export() first"
-                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+            if self.verbose:
+                logging.info(f"Applied quantizers: {quantizers}")
+            composed_quantizer = ComposableQuantizer(quantizers)
+            assert (
+                self.pre_autograd_graph_module is not None
+            ), "Please run export() first"
+            m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+            logging.info(
+                f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+            )
+            # Calibrate
+            if (
+                self.calibration_tasks is not None
+                and self.calibration_limit is not None
+                and self.calibration_seq_length is not None
+                and self.calibration_data is not None
+                and self.tokenizer_path is not None
+            ):
                 logging.info(
                     f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
                 )
-                # Calibrate
-                if (
-                    self.calibration_tasks is not None
-                    and self.calibration_limit is not None
-                    and self.calibration_seq_length is not None
-                    and self.calibration_data is not None
-                    and self.tokenizer_path is not None
-                ):
-                    logging.info(
-                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
-                    )
-                    self.pt2e_calibrate(
-                        prepared_module=m,
-                        calibration_tasks=self.calibration_tasks,
-                        calibration_limit=self.calibration_limit,
-                        calibration_seq_length=self.calibration_seq_length,
-                        calibration_data=self.calibration_data,
-                        tokenizer_path=self.tokenizer_path,
-                    )
-                else:
-                    logging.info(
-                        "No calibration provided, using dummy input to calibrate..."
-                    )
-                    m(*self.example_inputs)
-                m = convert_pt2e(m)
-                DuplicateDynamicQuantChainPass()(m)
-                self.pre_autograd_graph_module = m
-                return self
-        else:
-            logging.info("No quantizer provided, passing...")
-            return self
+                self.pt2e_calibrate(
+                    prepared_module=m,
+                    calibration_tasks=self.calibration_tasks,
+                    calibration_limit=self.calibration_limit,
+                    calibration_seq_length=self.calibration_seq_length,
+                    calibration_data=self.calibration_data,
+                    tokenizer_path=self.tokenizer_path,
+                )
+            else:
+                logging.info(
+                    "No calibration provided, using dummy input to calibrate..."
+                )
+                m(*self.example_inputs, **self.example_kwarg_inputs)
+            m = convert_pt2e(m)
+            DuplicateDynamicQuantChainPass()(m)
+            self.pre_autograd_graph_module = m
+            return self
 
     def export_to_edge(self) -> "LLMEdgeManager":
         """
@@ -415,6 +416,18 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
 
         return self
 
+    def to_edge_transform_and_lower(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
+        if partitioners is None:
+            logging.info("No partitioner provided, skipping backend lowering...")
+        breakpoint()
+        edge_config = self._get_edge_config()
+        self.edge_manager = to_edge_transform_and_lower(
+            self.pre_autograd_exported_program,
+            partitioner=partitioners,
+            compile_config=edge_config,
+        )
+        return self
+
     def to_executorch(self) -> "LLMEdgeManager":
         """
         Lower the model to executorch and get an ExecutorchProgram.
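The new method is essentially a thin wrapper around executorch.exir.to_edge_transform_and_lower(), applied to the ExportedProgram saved in export(). A hedged, standalone sketch of that underlying call (XnnpackPartitioner, the toy program, and the EdgeCompileConfig flag are illustrative stand-ins; the builder passes its own edge config and whatever partitioners _export_llama collected):

    import torch
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
    from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    ep = torch.export.export(Tiny(), (torch.randn(2, 4),))
    edge_manager = to_edge_transform_and_lower(
        ep,
        partitioner=[XnnpackPartitioner()],      # a list, matching the partitioner= kwarg in the diff
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    executorch_program = edge_manager.to_executorch()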
