Skip to content

Commit 4b34158

Browse files
GregoryComermalfet
authored andcommitted
Update CLI arg handling (#488)
Fixes pytorch/torchchat#468 and pytorch/torchchat#466. Updates named arguments to be registered on subparsers, which allows removal of the arg re-ordering code.
1 parent 7c13f3b commit 4b34158

File tree

5 files changed

+6
-33
lines changed

5 files changed

+6
-33
lines changed

cli.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,6 @@ def _add_arguments_common(parser):
8080
help="Model name for well-known models",
8181
)
8282

83-
84-
def add_arguments(parser):
85-
# TODO: Refactor this so that only common options are here
86-
# and command-specific options are inside individual
87-
# add_arguments_for_generate, add_arguments_for_export etc.
88-
8983
parser.add_argument(
9084
"--chat",
9185
action="store_true",
@@ -301,10 +295,10 @@ def add_arguments(parser):
301295

302296

303297
def arg_init(args):
304-
if Path(args.quantize).is_file():
298+
if hasattr(args, 'quantize') and Path(args.quantize).is_file():
305299
with open(args.quantize, "r") as f:
306300
args.quantize = json.loads(f.read())
307301

308-
if args.seed:
302+
if hasattr(args, 'seed') and args.seed:
309303
torch.manual_seed(args.seed)
310304
return args

eval.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from build.model import Transformer
2222
from build.utils import set_precision
23-
from cli import add_arguments, add_arguments_for_eval, arg_init
23+
from cli import add_arguments_for_eval, arg_init
2424
from generate import encode_tokens, model_forward
2525

2626
torch._dynamo.config.automatic_dynamic_shapes = True
@@ -289,7 +289,6 @@ def main(args) -> None:
289289

290290
if __name__ == "__main__":
291291
parser = argparse.ArgumentParser(description="torchchat eval CLI")
292-
add_arguments(parser)
293292
add_arguments_for_eval(parser)
294293
args = parser.parse_args()
295294
args = arg_init(args)

export.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020

2121
from build.utils import set_backend, set_precision
22-
from cli import add_arguments, add_arguments_for_export, arg_init, check_args
22+
from cli import add_arguments_for_export, arg_init, check_args
2323
from export_aoti import export_model as export_model_aoti
2424

2525
try:
@@ -104,7 +104,6 @@ def main(args):
104104

105105
if __name__ == "__main__":
106106
parser = argparse.ArgumentParser(description="torchchat export CLI")
107-
add_arguments(parser)
108107
add_arguments_for_export(parser)
109108
args = parser.parse_args()
110109
check_args(args, "export")

generate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from build.model import Transformer
2727
from build.utils import device_sync, set_precision
28-
from cli import add_arguments, add_arguments_for_generate, arg_init, check_args
28+
from cli import add_arguments_for_generate, arg_init, check_args
2929

3030
logger = logging.getLogger(__name__)
3131

@@ -767,7 +767,6 @@ def main(args):
767767

768768
if __name__ == "__main__":
769769
parser = argparse.ArgumentParser(description="torchchat generate CLI")
770-
add_arguments(parser)
771770
add_arguments_for_generate(parser)
772771
args = parser.parse_args()
773772
check_args(args, "generate")

torchchat.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import sys
1111

1212
from cli import (
13-
add_arguments,
1413
add_arguments_for_browser,
1514
add_arguments_for_chat,
1615
add_arguments_for_download,
@@ -30,17 +29,14 @@
3029
# Initialize the top-level parser
3130
parser = argparse.ArgumentParser(
3231
prog="torchchat",
33-
description="Welcome to the torchchat CLI!",
3432
add_help=True,
3533
)
36-
# Default command is to print help
37-
parser.set_defaults(func=parser.print_help())
3834

39-
add_arguments(parser)
4035
subparsers = parser.add_subparsers(
4136
dest="command",
4237
help="The specific command to run",
4338
)
39+
subparsers.required = True
4440

4541
parser_chat = subparsers.add_parser(
4642
"chat",
@@ -90,20 +86,6 @@
9086
)
9187
add_arguments_for_remove(parser_remove)
9288

93-
# Move all flags to the front of sys.argv since we don't
94-
# want to use the subparser syntax
95-
flag_args = []
96-
positional_args = []
97-
i = 1
98-
while i < len(sys.argv):
99-
if sys.argv[i].startswith("-"):
100-
flag_args += sys.argv[i : i + 2]
101-
i += 2
102-
else:
103-
positional_args.append(sys.argv[i])
104-
i += 1
105-
sys.argv = sys.argv[:1] + flag_args + positional_args
106-
10789
# Now parse the arguments
10890
args = parser.parse_args()
10991
args = arg_init(args)

0 commit comments

Comments
 (0)