-
Notifications
You must be signed in to change notification settings - Fork 607
Introduce hydra framework with backwards compatibility #11029
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
673e39f
c48b67e
ae14d3e
f8fc412
44fa6dc
6c6cf65
c501c37
6967137
8069737
b7de48b
048fa21
e53c704
6a4aa22
d9c2e46
8a75fb0
1c75e94
4dd719d
5fe7814
f98618d
ea61786
10c3aba
bfd2dec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,30 +4,50 @@ | |
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# Example script for exporting Llama2 to flatbuffer | ||
|
||
import logging | ||
|
||
# force=True to ensure logging while in debugger. Set up logger before any | ||
# other imports. | ||
import logging | ||
|
||
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" | ||
logging.basicConfig(level=logging.INFO, format=FORMAT, force=True) | ||
|
||
import argparse | ||
import runpy | ||
import sys | ||
|
||
import torch | ||
|
||
from .export_llama_lib import build_args_parser, export_llama | ||
|
||
sys.setrecursionlimit(4096) | ||
|
||
|
||
def parse_hydra_arg(): | ||
"""First parse out the arg for whether to use Hydra or the old CLI.""" | ||
parser = argparse.ArgumentParser(add_help=True) | ||
parser.add_argument("--hydra", action="store_true") | ||
args, remaining = parser.parse_known_args() | ||
return args.hydra, remaining | ||
|
||
|
||
def main() -> None: | ||
seed = 42 | ||
torch.manual_seed(seed) | ||
parser = build_args_parser() | ||
args = parser.parse_args() | ||
export_llama(args) | ||
|
||
use_hydra, remaining_args = parse_hydra_arg() | ||
if use_hydra: | ||
# The import runs the main function of export_llama_hydra with the remaining args | ||
# under the Hydra framework. | ||
sys.argv = [arg for arg in sys.argv if arg != "--hydra"] | ||
print(f"running with {sys.argv}") | ||
runpy.run_module( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not just use importlib? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh nevermind. you are running this module not importing it |
||
"executorch.examples.models.llama.export_llama_hydra", run_name="__main__" | ||
) | ||
else: | ||
# Use the legacy version of the export_llama script which uses argsparse. | ||
from executorch.examples.models.llama.export_llama_args import ( | ||
main as export_llama_args_main, | ||
) | ||
|
||
export_llama_args_main(remaining_args) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
jackzhxng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
""" | ||
Run export_llama with the legacy argparse setup. | ||
""" | ||
|
||
from .export_llama_lib import build_args_parser, export_llama | ||
|
||
|
||
def main(args) -> None: | ||
parser = build_args_parser() | ||
args = parser.parse_args(args) | ||
export_llama(args) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
""" | ||
Run export_llama using the new Hydra CLI. | ||
""" | ||
|
||
import hydra | ||
|
||
from executorch.examples.models.llama.config.llm_config import LlmConfig | ||
from executorch.examples.models.llama.export_llama_lib import export_llama | ||
from hydra.core.config_store import ConfigStore | ||
|
||
cs = ConfigStore.instance() | ||
cs.store(name="llm_config", node=LlmConfig) | ||
|
||
|
||
@hydra.main(version_base=None, config_name="llm_config") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i dont think i know enough about hydra but is there a test that showcases the use? |
||
def main(llm_config: LlmConfig) -> None: | ||
export_llama(llm_config) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did you mean to print this?