Skip to content

Commit 365cf56

Browse files
GregoryComermalfet
authored andcommitted
Refactor common download logic into cli.py (#407)
1 parent e18577f commit 365cf56

File tree

5 files changed

+16
-16
lines changed

5 files changed

+16
-16
lines changed

cli.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99

1010
from build.utils import allowable_dtype_names, allowable_params_table
11+
from download import download_and_convert, is_model_downloaded
1112

1213
import torch
1314

@@ -19,6 +20,18 @@ def check_args(args, name: str) -> None:
1920
pass
2021

2122

23+
# Handle CLI arguments that are common to a majority of subcommands.
24+
def handle_common_args(args) -> None:
25+
# Handle model download. Skip this for download, since it has slightly
26+
# different semantics.
27+
if (
28+
args.command != "download"
29+
and args.model
30+
and not is_model_downloaded(args.model, args.model_directory)
31+
):
32+
download_and_convert(args.model, args.model_directory, args.hf_token)
33+
34+
2235
def add_arguments_for_chat(parser):
2336
# Only chat specific options should be here
2437
_add_arguments_common(parser)

eval.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from build.model import Transformer
2222
from build.utils import set_precision
2323
from cli import add_arguments, add_arguments_for_eval, arg_init
24-
from download import download_and_convert, is_model_downloaded
2524
from generate import encode_tokens, model_forward
2625

2726
torch._dynamo.config.automatic_dynamic_shapes = True
@@ -221,10 +220,6 @@ def main(args) -> None:
221220
222221
"""
223222

224-
# If a named model was provided and not downloaded, download it.
225-
if args.model and not is_model_downloaded(args.model, args.model_directory):
226-
download_and_convert(args.model, args.model_directory, args.hf_token)
227-
228223
builder_args = BuilderArgs.from_args(args)
229224
tokenizer_args = TokenizerArgs.from_args(args)
230225
quantize = args.quantize

export.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from build.utils import set_backend, set_precision, use_aoti_backend, use_et_backend
2222
from cli import add_arguments, add_arguments_for_export, arg_init, check_args
23-
from download import download_and_convert, is_model_downloaded
2423
from export_aoti import export_model as export_model_aoti
2524

2625
try:
@@ -35,11 +34,6 @@
3534

3635

3736
def main(args):
38-
# THIS BELONGS INTO CLI
39-
# If a named model was provided and not downloaded, download it.
40-
# if args.model and not is_model_downloaded(args.model, args.model_directory):
41-
# download_and_convert(args.model, args.model_directory, args.hf_token)
42-
4337
builder_args = BuilderArgs.from_args(args)
4438
quantize = args.quantize
4539

generate.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from build.model import Transformer
2727
from build.utils import device_sync, set_precision
2828
from cli import add_arguments, add_arguments_for_generate, arg_init, check_args
29-
from download import download_and_convert, is_model_downloaded
3029

3130
logger = logging.getLogger(__name__)
3231

@@ -586,10 +585,6 @@ def callback(x):
586585

587586

588587
def main(args):
589-
# If a named model was provided and not downloaded, download it.
590-
if args.model and not is_model_downloaded(args.model, args.model_directory):
591-
download_and_convert(args.model, args.model_directory, args.hf_token)
592-
593588
builder_args = BuilderArgs.from_args(args)
594589
speculative_builder_args = BuilderArgs.from_speculative_args(args)
595590
tokenizer_args = TokenizerArgs.from_args(args)

torchchat.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
add_arguments_for_generate,
2020
arg_init,
2121
check_args,
22+
handle_common_args,
2223
)
2324

2425
default_device = "cpu"
@@ -97,6 +98,8 @@
9798
format="%(message)s", level=logging.DEBUG if args.verbose else logging.INFO
9899
)
99100

101+
handle_common_args(args)
102+
100103
if args.command == "chat":
101104
# enable "chat"
102105
args.chat = True

0 commit comments

Comments
 (0)