Skip to content

Commit 9f2e240

Browse files
committed
[BE] Replace copypasted with loops for arg parsing (#787)
* [BE] Replace copypasted with loops for arg parsing At least for defining subparsers * Fix loose reference
1 parent 3163ada commit 9f2e240

File tree

5 files changed

+30
-117
lines changed

5 files changed

+30
-117
lines changed

cli.py

Lines changed: 6 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -27,63 +27,21 @@
2727
).expanduser()
2828

2929

30+
KNOWN_VERBS = ["chat", "browser", "download", "generate", "eval", "export", "list", "remove", "where"]
31+
3032
# Handle CLI arguments that are common to a majority of subcommands.
31-
def check_args(args, name: str) -> None:
33+
def check_args(args, verb: str) -> None:
3234
# Handle model download. Skip this for download, since it has slightly
3335
# different semantics.
3436
if (
35-
name not in ["download", "list", "remove"]
37+
verb not in ["download", "list", "remove"]
3638
and args.model
3739
and not is_model_downloaded(args.model, args.model_directory)
3840
):
3941
download_and_convert(args.model, args.model_directory, args.hf_token)
4042

4143

42-
def add_arguments_for_chat(parser):
43-
# Only chat specific options should be here
44-
_add_arguments_common(parser)
45-
46-
47-
def add_arguments_for_browser(parser):
48-
# Only browser specific options should be here
49-
_add_arguments_common(parser)
50-
51-
52-
def add_arguments_for_download(parser):
53-
# Only download specific options should be here
54-
_add_arguments_common(parser)
55-
56-
57-
def add_arguments_for_generate(parser):
58-
# Only generate specific options should be here
59-
_add_arguments_common(parser)
60-
61-
62-
def add_arguments_for_eval(parser):
63-
# Only eval specific options should be here
64-
_add_arguments_common(parser)
65-
66-
67-
def add_arguments_for_export(parser):
68-
# Only export specific options should be here
69-
_add_arguments_common(parser)
70-
71-
72-
def add_arguments_for_list(parser):
73-
# Only list specific options should be here
74-
_add_arguments_common(parser)
75-
76-
77-
def add_arguments_for_remove(parser):
78-
# Only remove specific options should be here
79-
_add_arguments_common(parser)
80-
81-
def add_arguments_for_where(parser):
82-
# Only remove specific options should be here
83-
_add_arguments_common(parser)
84-
85-
86-
def _add_arguments_common(parser):
44+
def add_arguments_for_verb(parser, verb: str):
8745
# Model specification. TODO Simplify this.
8846
# A model can be specified using a positional model name or HuggingFace
8947
# path. Alternatively, the model can be specified via --gguf-path or via
@@ -316,7 +274,7 @@ def arg_init(args):
316274
)
317275

318276
if sys.version_info.major != 3 or sys.version_info.minor < 10:
319-
raise RuntimeError("Please use Python 3.10 or later.")
277+
raise RuntimeError("Please use Python 3.10 or later.")
320278

321279
if hasattr(args, "quantize") and Path(args.quantize).is_file():
322280
with open(args.quantize, "r") as f:

eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from build.model import Transformer
2222
from build.utils import set_precision
23-
from cli import add_arguments_for_eval, arg_init
23+
from cli import add_arguments_for_verb, arg_init
2424
from generate import encode_tokens, model_forward
2525

2626
torch._dynamo.config.automatic_dynamic_shapes = True
@@ -278,7 +278,7 @@ def main(args) -> None:
278278

279279
if __name__ == "__main__":
280280
parser = argparse.ArgumentParser(description="torchchat eval CLI")
281-
add_arguments_for_eval(parser)
281+
add_arguments_for_verb(parser, "eval")
282282
args = parser.parse_args()
283283
args = arg_init(args)
284284
main(args)

export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020

2121
from build.utils import set_backend, set_precision
22-
from cli import add_arguments_for_export, arg_init, check_args
22+
from cli import add_arguments_for_verb, arg_init, check_args
2323
from export_aoti import export_model as export_model_aoti
2424

2525
try:
@@ -112,7 +112,7 @@ def main(args):
112112

113113
if __name__ == "__main__":
114114
parser = argparse.ArgumentParser(description="torchchat export CLI")
115-
add_arguments_for_export(parser)
115+
add_arguments_for_verb(parser, "export")
116116
args = parser.parse_args()
117117
check_args(args, "export")
118118
args = arg_init(args)

generate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from build.model import Transformer
2626
from build.utils import device_sync, set_precision
27-
from cli import add_arguments_for_generate, arg_init, check_args, logger
27+
from cli import add_arguments_for_verb, arg_init, check_args, logger
2828

2929
B_INST, E_INST = "[INST]", "[/INST]"
3030
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
@@ -801,8 +801,9 @@ def main(args):
801801

802802
if __name__ == "__main__":
803803
parser = argparse.ArgumentParser(description="torchchat generate CLI")
804-
add_arguments_for_generate(parser)
804+
verb = "generate"
805+
add_arguments_for_verb(parser, verb)
805806
args = parser.parse_args()
806-
check_args(args, "generate")
807+
check_args(args, verb)
807808
args = arg_init(args)
808809
main(args)

torchchat.py

Lines changed: 16 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,8 @@
1010
import sys
1111

1212
from cli import (
13-
add_arguments_for_browser,
14-
add_arguments_for_chat,
15-
add_arguments_for_download,
16-
add_arguments_for_eval,
17-
add_arguments_for_export,
18-
add_arguments_for_generate,
19-
add_arguments_for_list,
20-
add_arguments_for_remove,
21-
add_arguments_for_where,
13+
add_arguments_for_verb,
14+
KNOWN_VERBS,
2215
arg_init,
2316
check_args,
2417
)
@@ -39,59 +32,20 @@
3932
)
4033
subparsers.required = True
4134

42-
parser_chat = subparsers.add_parser(
43-
"chat",
44-
help="Chat interactively with a model",
45-
)
46-
add_arguments_for_chat(parser_chat)
47-
48-
parser_browser = subparsers.add_parser(
49-
"browser",
50-
help="Chat interactively in a browser",
51-
)
52-
add_arguments_for_browser(parser_browser)
53-
54-
parser_download = subparsers.add_parser(
55-
"download",
56-
help="Download a model from Hugging Face or others",
57-
)
58-
add_arguments_for_download(parser_download)
59-
60-
parser_generate = subparsers.add_parser(
61-
"generate",
62-
help="Generate responses from a model given a prompt",
63-
)
64-
add_arguments_for_generate(parser_generate)
65-
66-
parser_eval = subparsers.add_parser(
67-
"eval",
68-
help="Evaluate a model given a prompt",
69-
)
70-
add_arguments_for_eval(parser_eval)
71-
72-
parser_export = subparsers.add_parser(
73-
"export",
74-
help="Export a model for AOT Inductor or ExecuTorch",
75-
)
76-
add_arguments_for_export(parser_export)
77-
78-
parser_list = subparsers.add_parser(
79-
"list",
80-
help="List supported models",
81-
)
82-
add_arguments_for_list(parser_list)
83-
84-
parser_remove = subparsers.add_parser(
85-
"remove",
86-
help="Remove downloaded model artifacts",
87-
)
88-
add_arguments_for_remove(parser_remove)
89-
90-
parser_where = subparsers.add_parser(
91-
"where",
92-
help="Return directory containing downloaded model artifacts",
93-
)
94-
add_arguments_for_where(parser_where)
35+
VERB_HELP = {
36+
"chat": "Chat interactively with a model",
37+
"browser": "Chat interactively in a browser",
38+
"download": "Download a model from Hugging Face or others",
39+
"generate": "Generate responses from a model given a prompt",
40+
"eval": "Evaluate a model given a prompt",
41+
"export": "Export a model for AOT Inductor or ExecuTorch",
42+
"list": "List supported models",
43+
"remove": "Remove downloaded model artifacts",
44+
"where": "Return directory containing downloaded model artifacts",
45+
}
46+
for verb in KNOWN_VERBS:
47+
subparser = subparsers.add_parser(verb, help=VERB_HELP[verb])
48+
add_arguments_for_verb(subparser, verb)
9549

9650
# Now parse the arguments
9751
args = parser.parse_args()

0 commit comments

Comments
 (0)