
Commit f963342

Introduce hydra framework with backwards compatibility
Pull Request resolved: #11029 (imported using ghimport)
Differential Revision: D75263989 (https://our.internmc.facebook.com/intern/diff/D75263989/)
ghstack-source-id: 287798912
1 parent: 97f345c

File tree (7 files changed: +160 -11 lines)

examples/models/llama/TARGETS
examples/models/llama/config/llm_config_utils.py
examples/models/llama/export_llama.py
examples/models/llama/export_llama_args.py
examples/models/llama/export_llama_hydra.py
examples/models/llama/export_llama_lib.py
examples/models/llama/install_requirements.sh

examples/models/llama/TARGETS

Lines changed: 38 additions & 0 deletions

@@ -82,6 +82,8 @@ runtime.python_binary(
     ],
     deps = [
         ":export_library",
+        ":export_llama_args",
+        ":export_llama_hydra",
         "//caffe2:torch",
         "//executorch/extension/pybindings:aten_lib",
     ],
@@ -148,6 +150,8 @@ runtime.python_library(
         ":source_transformation",
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
+        "//executorch/examples/models/llama/config:llm_config",
+        "//executorch/examples/models/llama/config:llm_config_utils",
         "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
@@ -231,6 +235,40 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "export_llama_args",
+    srcs = [
+        "export_llama_args.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama",
+    visibility = [
+        "//executorch/examples/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        ":export_library",
+    ],
+)
+
+runtime.python_library(
+    name = "export_llama_hydra",
+    srcs = [
+        "export_llama_hydra.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama",
+    visibility = [
+        "//executorch/examples/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        ":export_library",
+        "//executorch/examples/models/llama/config:llm_config",
+        "fbsource//third-party/pypi/hydra-core:hydra-core",
+    ],
+)
+
 runtime.python_test(
     name = "quantized_kv_cache_test",
     srcs = [
examples/models/llama/config/llm_config_utils.py

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+
+from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+
+def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
+    """
+    For backwards compatibility, convert CLI args from argparse to an
+    LlmConfig, which is used by the LLM export process.
+    """
+    llm_config = LlmConfig()
+
+    # TODO: conversion code.
+
+    return llm_config
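The conversion body is left as a TODO in this commit. Purely to illustrate the intended shape, a filled-in version might look like the sketch below; the LlmConfig field names are hypothetical stand-ins, since the real schema lives in config/llm_config.py and is not part of this diff, while args.model and args.checkpoint are flags from the legacy parser:

import argparse

from executorch.examples.models.llama.config.llm_config import LlmConfig


def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
    llm_config = LlmConfig()

    # Hypothetical field mappings; the actual LlmConfig schema is not
    # shown in this diff.
    llm_config.model_name = getattr(args, "model", None)
    llm_config.checkpoint = getattr(args, "checkpoint", None)

    return llm_config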

examples/models/llama/export_llama.py

Lines changed: 29 additions & 9 deletions

@@ -4,30 +4,50 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# Example script for exporting Llama2 to flatbuffer
-
-import logging
-
 # force=True to ensure logging while in debugger. Set up logger before any
 # other imports.
+import logging
+
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT, force=True)
 
+import argparse
+import runpy
 import sys
 
 import torch
 
-from .export_llama_lib import build_args_parser, export_llama
-
 sys.setrecursionlimit(4096)
 
 
+def parse_hydra_arg():
+    """First parse out the arg for whether to use Hydra or the old CLI."""
+    parser = argparse.ArgumentParser(add_help=True)
+    parser.add_argument("--hydra", action="store_true")
+    args, remaining = parser.parse_known_args()
+    return args.hydra, remaining
+
+
 def main() -> None:
     seed = 42
     torch.manual_seed(seed)
-    parser = build_args_parser()
-    args = parser.parse_args()
-    export_llama(args)
+
+    use_hydra, remaining_args = parse_hydra_arg()
+    if use_hydra:
+        # Run the main function of export_llama_hydra with the remaining
+        # args under the Hydra framework.
+        sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
+        print(f"running with {sys.argv}")
+        runpy.run_module(
+            "executorch.examples.models.llama.export_llama_hydra", run_name="__main__"
+        )
+    else:
+        # Use the legacy version of the export_llama script, which uses argparse.
+        from executorch.examples.models.llama.export_llama_args import (
+            main as export_llama_args_main,
+        )
+
+        export_llama_args_main(remaining_args)
 
 
 if __name__ == "__main__":
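The dispatch relies on argparse's parse_known_args, which consumes only the flags it recognizes and returns everything else untouched, so Hydra-style overrides pass through the wrapper unharmed. A minimal standalone demonstration (the base.checkpoint override is a made-up Hydra-style argument):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--hydra", action="store_true")

# Unlike parse_args(), parse_known_args() splits recognized flags from
# leftovers instead of erroring on unknown arguments.
args, remaining = parser.parse_known_args(
    ["--hydra", "base.checkpoint=model.pt"]  # made-up example argv
)
print(args.hydra)   # True
print(remaining)    # ['base.checkpoint=model.pt']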
examples/models/llama/export_llama_args.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run export_llama with the legacy argparse setup.
+"""
+
+from .export_llama_lib import build_args_parser, export_llama
+
+
+def main(args=None) -> None:
+    parser = build_args_parser()
+    args = parser.parse_args(args)
+    export_llama(args)
+
+
+if __name__ == "__main__":
+    main()
examples/models/llama/export_llama_hydra.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run export_llama using the new Hydra CLI.
+"""
+
+import hydra
+
+from executorch.examples.models.llama.config.llm_config import LlmConfig
+from executorch.examples.models.llama.export_llama_lib import export_llama
+from hydra.core.config_store import ConfigStore
+
+cs = ConfigStore.instance()
+cs.store(name="llm_config", node=LlmConfig)
+
+
+@hydra.main(version_base=None, config_name="llm_config")
+def main(llm_config: LlmConfig) -> None:
+    export_llama(llm_config)
+
+
+if __name__ == "__main__":
+    main()
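Registering the dataclass in Hydra's ConfigStore is what lets callers override any LlmConfig field on the command line with dotted key=value syntax. A self-contained sketch of the same pattern, with MyConfig and its fields standing in for the real LlmConfig schema:

from dataclasses import dataclass

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig


@dataclass
class MyConfig:
    # Stand-in fields; the real schema is LlmConfig.
    checkpoint: str = ""
    seed: int = 42


cs = ConfigStore.instance()
cs.store(name="my_config", node=MyConfig)


@hydra.main(version_base=None, config_name="my_config")
def main(cfg: DictConfig) -> None:
    # Hydra validates overrides against the dataclass schema and passes
    # the merged result in as a DictConfig with attribute access.
    print(cfg.checkpoint, cfg.seed)


if __name__ == "__main__":
    main()  # e.g. `python demo.py checkpoint=model.pt seed=7`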

examples/models/llama/export_llama_lib.py

Lines changed: 22 additions & 1 deletion

@@ -28,6 +28,10 @@
 from executorch.devtools.backend_debug import print_delegation_info
 
 from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
+
+from executorch.examples.models.llama.config.llm_config_utils import (
+    convert_args_to_llm_config,
+)
 from executorch.examples.models.llama.hf_download import (
     download_and_convert_hf_checkpoint,
 )
@@ -51,6 +55,7 @@
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
+from omegaconf.dictconfig import DictConfig
 
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -568,7 +573,23 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(args) -> str:
+def export_llama(
+    export_options: Union[argparse.Namespace, DictConfig],
+) -> str:
+    if isinstance(export_options, argparse.Namespace):
+        # Legacy CLI.
+        args = export_options
+        llm_config = convert_args_to_llm_config(export_options)  # noqa: F841
+    elif isinstance(export_options, DictConfig):
+        # Hydra CLI.
+        llm_config = export_options  # noqa: F841
+    else:
+        raise ValueError(
+            "Input to export_llama must be of type argparse.Namespace or DictConfig"
+        )
+
+    # TODO: refactor the rest of export_llama to use llm_config instead of args.
+
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
     if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
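The isinstance dispatch works because Hydra hands a @hydra.main-decorated function a DictConfig even when the registered node is a dataclass, while the legacy path still produces a plain argparse.Namespace. A small sketch of the distinction (DemoConfig is a stand-in):

import argparse
from dataclasses import dataclass

from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig


@dataclass
class DemoConfig:
    checkpoint: str = ""


# OmegaConf wraps a dataclass into a DictConfig, which is also the type
# Hydra passes to @hydra.main entry points.
cfg = OmegaConf.structured(DemoConfig)
assert isinstance(cfg, DictConfig)

# The legacy CLI produces a plain Namespace, so the two entry points
# are distinguishable by type alone.
ns = argparse.Namespace(checkpoint="model.pt")
assert isinstance(ns, argparse.Namespace)
assert not isinstance(ns, DictConfig)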

examples/models/llama/install_requirements.sh

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 # Install tokenizers for hf .json tokenizer.
 # Install snakeviz for cProfile flamegraph
 # Install lm-eval for Model Evaluation with lm-evalution-harness.
-pip install huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
+pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
 
 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py
