Skip to content

Commit f2e95fa

Browse files
committed
Introduce hydra framework with backwards compatibility
Pull Request resolved: #11029 @imported-using-ghimport Differential Revision: [D75263989](https://our.internmc.facebook.com/intern/diff/D75263989/) ghstack-source-id: 289208258
1 parent a011764 commit f2e95fa

File tree

7 files changed

+115
-11
lines changed

7 files changed

+115
-11
lines changed

examples/models/llama/TARGETS

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ runtime.python_library(
132132
name = "export_library",
133133
srcs = [
134134
"export_llama.py",
135+
"export_llama_args.py",
136+
"export_llama_hydra.py",
135137
"export_llama_lib.py",
136138
"model.py",
137139
],
@@ -148,6 +150,8 @@ runtime.python_library(
148150
":source_transformation",
149151
"//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
150152
"//caffe2:torch",
153+
"//executorch/examples/models/llama/config:llm_config",
154+
"//executorch/backends/vulkan/_passes:vulkan_passes",
151155
"//executorch/exir/passes:init_mutable_pass",
152156
"//executorch/examples/models:model_base",
153157
"//executorch/examples/models:models",

examples/models/llama/config/llm_config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Uses dataclasses, which integrate with OmegaConf and Hydra.
1313
"""
1414

15+
import argparse
1516
import ast
1617
import re
1718
from dataclasses import dataclass, field
@@ -468,6 +469,18 @@ class LlmConfig:
468469
quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
469470
backend: BackendConfig = field(default_factory=BackendConfig)
470471

472+
@classmethod
def from_args(cls, args: argparse.Namespace) -> "LlmConfig":
    """
    To support legacy purposes, this function converts CLI args from
    argparse to an LlmConfig, which is used by the LLM export process.

    Args:
        args: the namespace produced by the legacy argparse CLI
            (build_args_parser in export_llama_lib).

    Returns:
        An LlmConfig populated from the legacy CLI arguments.
    """
    # Bug fix: the original signature omitted ``cls``, so calling
    # ``LlmConfig.from_args(args)`` bound the class object to ``args``
    # and dropped the real namespace. Use ``cls()`` so subclasses also
    # get an instance of their own type.
    llm_config = cls()

    # TODO: conversion code.

    return llm_config
483+
471484
def __post_init__(self):
472485
self._validate_low_bit()
473486

examples/models/llama/export_llama.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,50 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# Example script for exporting Llama2 to flatbuffer
8-
9-
import logging
10-
117
# force=True to ensure logging while in debugger. Set up logger before any
128
# other imports.
9+
import logging
10+
1311
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
1412
logging.basicConfig(level=logging.INFO, format=FORMAT, force=True)
1513

14+
import argparse
15+
import runpy
1616
import sys
1717

1818
import torch
1919

20-
from .export_llama_lib import build_args_parser, export_llama
21-
2220
sys.setrecursionlimit(4096)
2321

2422

23+
def parse_hydra_arg():
    """Peel off the --hydra flag before full CLI parsing.

    Returns a (use_hydra, remaining) pair: whether --hydra was passed,
    plus every other CLI token untouched for the downstream parser.
    """
    pre_parser = argparse.ArgumentParser(add_help=True)
    pre_parser.add_argument("--hydra", action="store_true")
    known, leftover = pre_parser.parse_known_args()
    return known.hydra, leftover
29+
30+
2531
def main() -> None:
    """Entry point: dispatch to the Hydra CLI or the legacy argparse CLI.

    Seeds torch for reproducibility, then routes based on the presence of
    the --hydra flag.
    """
    seed = 42
    torch.manual_seed(seed)

    use_hydra, remaining_args = parse_hydra_arg()
    if use_hydra:
        # The import runs the main function of export_llama_hydra with the
        # remaining args under the Hydra framework. Hydra reads sys.argv
        # directly, so strip the --hydra flag it does not understand.
        sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
        # Use the module logger configured at import time instead of a
        # stray debug print; lazy %-args per logging best practice.
        logging.info("Running export_llama via Hydra with args: %s", sys.argv)
        runpy.run_module(
            "executorch.examples.models.llama.export_llama_hydra", run_name="__main__"
        )
    else:
        # Use the legacy version of the export_llama script which uses argparse.
        from executorch.examples.models.llama.export_llama_args import (
            main as export_llama_args_main,
        )

        export_llama_args_main(remaining_args)
3151

3252

3353
if __name__ == "__main__":
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Run export_llama with the legacy argparse setup.
"""

from .export_llama_lib import build_args_parser, export_llama


def main(args=None) -> None:
    """Parse legacy CLI args and run the export.

    Args:
        args: optional list of CLI tokens. When None (e.g. when this module
            is executed directly), argparse falls back to sys.argv[1:].
            The default fixes a TypeError: the __main__ guard below calls
            main() with no arguments, but the parameter was required.
    """
    parser = build_args_parser()
    args = parser.parse_args(args)
    export_llama(args)


if __name__ == "__main__":
    main()
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Run export_llama using the new Hydra CLI.
"""

import hydra

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import export_llama
from hydra.core.config_store import ConfigStore

# Register the structured config so Hydra can resolve config_name="llm_config"
# against the LlmConfig dataclass schema.
config_store = ConfigStore.instance()
config_store.store(name="llm_config", node=LlmConfig)


@hydra.main(version_base=None, config_name="llm_config")
def main(llm_config: LlmConfig) -> None:
    """Export the model described by the Hydra-composed LlmConfig."""
    export_llama(llm_config)


if __name__ == "__main__":
    main()

examples/models/llama/export_llama_lib.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from executorch.devtools.backend_debug import print_delegation_info
2828

2929
from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
30+
31+
from executorch.examples.models.llama.config.llm_config import LlmConfig
3032
from executorch.examples.models.llama.hf_download import (
3133
download_and_convert_hf_checkpoint,
3234
)
@@ -50,6 +52,7 @@
5052
get_vulkan_quantizer,
5153
)
5254
from executorch.util.activation_memory_profiler import generate_memory_trace
55+
from omegaconf.dictconfig import DictConfig
5356

5457
from ..model_factory import EagerModelFactory
5558
from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -567,7 +570,23 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
567570
return return_val
568571

569572

570-
def export_llama(args) -> str:
573+
def export_llama(
574+
export_options: Union[argparse.Namespace, DictConfig],
575+
) -> str:
576+
if isinstance(export_options, argparse.Namespace):
577+
# Legacy CLI.
578+
args = export_options
579+
llm_config = LlmConfig.from_args(export_options) # noqa: F841
580+
elif isinstance(export_options, DictConfig):
581+
# Hydra CLI.
582+
llm_config = export_options # noqa: F841
583+
else:
584+
raise ValueError(
585+
"Input to export_llama must be either of type argparse.Namespace or LlmConfig"
586+
)
587+
588+
# TODO: refactor rest of export_llama to use llm_config instead of args.
589+
571590
# If a checkpoint isn't provided for an HF OSS model, download and convert the
572591
# weights first.
573592
if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:

examples/models/llama/install_requirements.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# Install tokenizers for hf .json tokenizer.
1111
# Install snakeviz for cProfile flamegraph
1212
# Install lm-eval for Model Evaluation with lm-evaluation-harness.
13-
pip install huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
13+
pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
1414

1515
# Call the install helper for further setup
1616
python examples/models/llama/install_requirement_helper.py

0 commit comments

Comments
 (0)