Introduce extension/llm/export_llm #11746

Merged · 1 commit · Jun 18, 2025
3 changes: 1 addition & 2 deletions examples/models/llama/TARGETS
@@ -85,6 +85,7 @@ runtime.python_binary(
":export_library",
"//caffe2:torch",
"//executorch/extension/pybindings:aten_lib",
"//executorch/extension/llm/export:export_llm_lib",
],
)

@@ -133,8 +134,6 @@ runtime.python_library(
name = "export_library",
srcs = [
"export_llama.py",
"export_llama_args.py",
"export_llama_hydra.py",
"export_llama_lib.py",
"model.py",
],
4 changes: 2 additions & 2 deletions examples/models/llama/config/llm_config.py
@@ -86,7 +86,7 @@ class BaseConfig:
checkpoint_dir: Optional[str] = None
tokenizer_path: Optional[str] = None
metadata: Optional[str] = None
- use_lora: int = int
+ use_lora: int = 0
fairseq2: bool = False
preq_mode: Optional[PreqMode] = None
preq_group_size: int = 32
@@ -214,7 +214,7 @@ class ExportConfig:

max_seq_length: int = 128
max_context_length: int = 128
- output_dir: Optional[str] = None
+ output_dir: str = "."
output_name: Optional[str] = None
so_library: Optional[str] = None
export_only: bool = False
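
Worth noting on the `use_lora` change above: the old default was the built-in type `int` itself, so the field's default value was a class object rather than a number. A minimal standalone sketch of the pitfall — the `Config` class below is illustrative, not from the repo:

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Buggy: the default is the *type* `int`, so the field holds a class object.
    buggy_use_lora: int = int  # type: ignore[assignment]
    # Fixed (as in this PR): an actual integer default.
    use_lora: int = 0


cfg = Config()
print(cfg.buggy_use_lora)  # <class 'int'> — almost certainly not what was intended
print(cfg.use_lora)        # 0
```
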
16 changes: 9 additions & 7 deletions examples/models/llama/export_llama.py
@@ -17,6 +17,11 @@

import torch

+ from executorch.examples.models.llama.export_llama_lib import (
+ build_args_parser,
+ export_llama,
+ )
+
sys.setrecursionlimit(4096)


@@ -39,15 +44,12 @@ def main() -> None:
sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
print(f"running with {sys.argv}")
runpy.run_module(
"executorch.examples.models.llama.export_llama_hydra", run_name="__main__"
"executorch.extension.llm.export.export_llm", run_name="__main__"
)
else:
# Use the legacy version of the export_llama script which uses argparse.
- from executorch.examples.models.llama.export_llama_args import (
- main as export_llama_args_main,
- )
-
- export_llama_args_main(remaining_args)
+ parser = build_args_parser()
+ remaining_args = parser.parse_args(remaining_args)
+ export_llama(remaining_args)


if __name__ == "__main__":
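
For reference, a condensed sketch of the dispatch pattern the revised `main()` implements: check for a `--hydra` flag, re-run the Hydra entry point as `__main__` via runpy, otherwise fall back to argparse. This is a simplified illustration, not the script itself — the real `main()` also handles other flags and argument plumbing:

```python
import runpy
import sys


def dispatch() -> None:
    if "--hydra" in sys.argv:
        # Strip the flag and re-run the Hydra entry point as __main__,
        # so @hydra.main sees the remaining dotted-key overrides.
        sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
        runpy.run_module(
            "executorch.extension.llm.export.export_llm", run_name="__main__"
        )
    else:
        # Legacy path: parse the flags with argparse and export directly.
        from executorch.examples.models.llama.export_llama_lib import (
            build_args_parser,
            export_llama,
        )

        args = build_args_parser().parse_args(sys.argv[1:])
        export_llama(args)
```
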
21 changes: 0 additions & 21 deletions examples/models/llama/export_llama_args.py

This file was deleted.

28 changes: 0 additions & 28 deletions examples/models/llama/export_llama_hydra.py

This file was deleted.

35 changes: 35 additions & 0 deletions extension/llm/export/TARGETS
@@ -47,6 +47,41 @@ runtime.python_library(
],
)

+ runtime.python_binary(
+ name = "export_llm",
+ srcs = [
+ "export_llm.py",
+ ],
+ main_function = "executorch.extension.llm.export.export_llm.main",
+ preload_deps = [
+ "//executorch/extension/llm/custom_ops:model_sharding_py",
+ "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+ "//executorch/kernels/quantized:aot_lib",
+ ],
+ deps = [
+ "fbsource//third-party/pypi/hydra-core:hydra-core",
+ "fbsource//third-party/pypi/omegaconf:omegaconf",
+ "//executorch/examples/models/llama:export_library",
+ "//executorch/extension/pybindings:aten_lib",
+ ],
+ )
+
+ runtime.python_library(
+ name = "export_llm_lib",
+ srcs = [
+ "export_llm.py",
+ ],
+ deps = [
+ "fbsource//third-party/pypi/hydra-core:hydra-core",
+ "fbsource//third-party/pypi/omegaconf:omegaconf",
+ "//executorch/examples/models/llama:export_library",
+ ],
+ visibility = [
+ "//executorch/examples/...",
+ "//executorch/extension/llm/...",
+ ],
+ )

runtime.python_test(
name = "export_passes_test",
srcs = [
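
Given the `main_function` wiring in the `export_llm` binary target above, the exporter should also be runnable through Buck in addition to `python -m`. The exact command line here is an assumption based on standard Buck usage, not something stated in this PR: `buck2 run //executorch/extension/llm/export:export_llm -- base.model_class="llama3"`.
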
45 changes: 45 additions & 0 deletions extension/llm/export/export_llm.py
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Export an LLM with ExecuTorch. Currently follows the following steps:
1. Instantiate our custom PyTorch transformer definition from examples/llama/models/llama_transformer.py.
> Review comment (Contributor): Don't mention implementation details in the docblock. If it's a public API, the docblock should contain a description of the contract.
>
> Reply (@jackzhxng, Contributor Author, Jun 18, 2025): I will move this information to a README.
2. Load weights into the model.
3. Apply source transformations/TorchAO quantization.
4. Export model to intermediate IRs.
5. Graph transformations/PT2E quantization.
6. Partition graph and delegate to backend(s).
7. Export to final ExecuTorch .pte format.

Example usage using full CLI arguments:
python -m extension.llm.export.export_llm \
base.model_class="llama3" \
model.use_sdpa_with_kv_cache=True \
model.use_kv_cache=True \
debug.verbose=True \
backend.xnnpack.enabled=True \
backend.xnnpack.extended_ops=True \
quantization.qmode="8da4w"
"""

import hydra

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import export_llama
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf

cs = ConfigStore.instance()
cs.store(name="llm_config", node=LlmConfig)


@hydra.main(version_base=None, config_path=None, config_name="llm_config")
def main(llm_config: LlmConfig) -> None:
export_llama(OmegaConf.to_object(llm_config))


if __name__ == "__main__":
main()
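
Because the config is registered in Hydra's `ConfigStore`, the export should also be drivable programmatically through Hydra's compose API rather than the CLI. A minimal sketch under that assumption — this is not part of the PR, and the overrides simply reuse the dotted keys from the docstring example:

```python
from hydra import compose, initialize
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import export_llama

# Mirror the registration done in export_llm.py.
cs = ConfigStore.instance()
cs.store(name="llm_config", node=LlmConfig)

# config_path=None: resolve configs only from the ConfigStore.
with initialize(version_base=None, config_path=None):
    cfg = compose(
        config_name="llm_config",
        overrides=[
            "base.model_class=llama3",
            "model.use_kv_cache=True",
            "backend.xnnpack.enabled=True",
        ],
    )

# Convert the OmegaConf node back into the typed LlmConfig dataclass.
llm_config: LlmConfig = OmegaConf.to_object(cfg)
export_llama(llm_config)
```
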