Skip to content

Commit 2466096

Browse files
orionrmalfet
authored andcommitted
[CLI] Cleaned up torchchat.py cli to include all of our options (#290)
Cleans up our CLI interface through `torchchat.py`, but also allows direct access to the command files. Updates help messages as well.
1 parent 21834b9 commit 2466096

File tree

7 files changed

+233
-102
lines changed

7 files changed

+233
-102
lines changed

cli.py

Lines changed: 124 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,29 @@
99

1010
import torch
1111

12-
default_device = "cpu"
12+
# CPU is always available and also exportable to ExecuTorch
13+
default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'
1314

1415
def check_args(args, name: str) -> None:
1516
pass
1617

18+
def add_arguments_for_chat(parser):
19+
# Only chat specific options should be here
20+
_add_arguments_common(parser)
21+
22+
23+
def add_arguments_for_browser(parser):
24+
# Only browser specific options should be here
25+
_add_arguments_common(parser)
26+
parser.add_argument(
27+
"--port",
28+
type=int,
29+
default=5000,
30+
help="Port for the web server in browser mode"
31+
)
32+
_add_arguments_common(parser)
33+
34+
1735
def add_arguments_for_download(parser):
1836
# Only download specific options should be here
1937
_add_arguments_common(parser)
@@ -33,158 +51,204 @@ def add_arguments_for_export(parser):
3351
# Only export specific options should be here
3452
_add_arguments_common(parser)
3553

36-
def add_arguments_for_browser(parser):
37-
# Only browser specific options should be here
38-
_add_arguments_common(parser)
39-
parser.add_argument(
40-
"--port",
41-
type=int,
42-
default=5000,
43-
help="Port for the web server for browser mode."
44-
)
4554

4655
def _add_arguments_common(parser):
4756
# Model specification. TODO Simplify this.
4857
# A model can be specified using a positional model name or HuggingFace
4958
# path. Alternatively, the model can be specified via --gguf-path or via
5059
# an explicit --checkpoint-dir, --checkpoint-path, or --tokenizer-path.
51-
5260
parser.add_argument(
5361
"model",
5462
type=str,
5563
nargs="?",
5664
default=None,
57-
help="Model name for well-known models.",
65+
help="Model name for well-known models",
5866
)
5967

68+
69+
def add_arguments(parser):
6070
# TODO: Refactor this so that only common options are here
61-
# and subcommand-specific options are inside individual
71+
# and command-specific options are inside individual
6272
# add_arguments_for_generate, add_arguments_for_export etc.
73+
6374
parser.add_argument(
64-
"--seed",
65-
type=int,
66-
default=1234, # set None for release
67-
help="Initialize torch seed",
68-
)
69-
parser.add_argument(
70-
"--prompt", type=str, default="Hello, my name is", help="Input prompt."
75+
"--chat",
76+
action="store_true",
77+
help="Whether to start an interactive chat session",
7178
)
7279
parser.add_argument(
73-
"--tiktoken",
80+
"--gui",
7481
action="store_true",
75-
help="Whether to use tiktoken tokenizer.",
82+
help="Whether to use a web UI for an interactive chat session",
7683
)
7784
parser.add_argument(
78-
"--chat",
79-
action="store_true",
80-
help="Use torchchat for an interactive chat session.",
85+
"--prompt",
86+
type=str,
87+
default="Hello, my name is",
88+
help="Input prompt",
8189
)
8290
parser.add_argument(
8391
"--is-chat-model",
8492
action="store_true",
85-
help="Indicate that the model was trained to support chat functionality.",
93+
help="Indicate that the model was trained to support chat functionality",
8694
)
8795
parser.add_argument(
88-
"--gui",
96+
"--seed",
97+
type=int,
98+
default=None,
99+
help="Initialize torch seed",
100+
)
101+
parser.add_argument(
102+
"--tiktoken",
89103
action="store_true",
90-
help="Use torchchat to for an interactive gui-chat session.",
104+
help="Whether to use tiktoken tokenizer",
105+
)
106+
parser.add_argument(
107+
"--num-samples",
108+
type=int,
109+
default=1,
110+
help="Number of samples",
111+
)
112+
parser.add_argument(
113+
"--max-new-tokens",
114+
type=int,
115+
default=200,
116+
help="Maximum number of new tokens",
91117
)
92-
parser.add_argument("--num-samples", type=int, default=1, help="Number of samples.")
93118
parser.add_argument(
94-
"--max-new-tokens", type=int, default=200, help="Maximum number of new tokens."
119+
"--top-k",
120+
type=int,
121+
default=200,
122+
help="Top-k for sampling",
95123
)
96-
parser.add_argument("--top-k", type=int, default=200, help="Top-k for sampling.")
97124
parser.add_argument(
98-
"--temperature", type=float, default=0.8, help="Temperature for sampling."
125+
"--temperature",
126+
type=float,
127+
default=0.8,
128+
help="Temperature for sampling"
99129
)
100130
parser.add_argument(
101-
"--compile", action="store_true", help="Whether to compile the model."
131+
"--compile",
132+
action="store_true",
133+
help="Whether to compile the model with torch.compile",
102134
)
103135
parser.add_argument(
104136
"--compile-prefill",
105137
action="store_true",
106-
help="Whether to compile the prefill (improves prefill perf, but higher compile times)",
138+
help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
139+
)
140+
parser.add_argument(
141+
"--profile",
142+
type=Path,
143+
default=None,
144+
help="Profile path.",
107145
)
108-
parser.add_argument("--profile", type=Path, default=None, help="Profile path.")
109146
parser.add_argument(
110-
"--speculate-k", type=int, default=5, help="Speculative execution depth."
147+
"--speculate-k",
148+
type=int,
149+
default=5,
150+
help="Speculative execution depth",
111151
)
112152
parser.add_argument(
113153
"--draft-checkpoint-path",
114154
type=Path,
115155
default=None,
116-
help="Draft checkpoint path.",
156+
help="Use the specified draft checkpoint path",
117157
)
118158
parser.add_argument(
119159
"--checkpoint-path",
120160
type=Path,
121161
default="not_specified",
122-
help="Model checkpoint path.",
162+
help="Use the specified model checkpoint path",
123163
)
124-
# parser.add_argument(
125-
# "--checkpoint-dir",
126-
# type=Path,
127-
# default=None,
128-
# help="Model checkpoint directory.",
129-
# )
130164
parser.add_argument(
131165
"--params-path",
132166
type=Path,
133167
default=None,
134-
help="Parameter file path.",
168+
help="Use the specified parameter file",
135169
)
136170
parser.add_argument(
137171
"--gguf-path",
138172
type=Path,
139173
default=None,
140-
help="GGUF file path.",
174+
help="Use the specified GGUF model file",
141175
)
142176
parser.add_argument(
143177
"--tokenizer-path",
144178
type=Path,
145179
default=None,
146-
help="Model checkpoint path.",
180+
help="Use the specified model tokenizer file",
181+
)
182+
parser.add_argument(
183+
"--output-pte-path",
184+
type=str,
185+
default=None,
186+
help="Output to the specified ExecuTorch .pte model file",
187+
)
188+
parser.add_argument(
189+
"--output-dso-path",
190+
type=str,
191+
default=None,
192+
help="Output to the specified AOT Inductor .dso model file",
147193
)
148-
parser.add_argument("--output-pte-path", type=str, default=None, help="Filename")
149-
parser.add_argument("--output-dso-path", type=str, default=None, help="Filename")
150194
parser.add_argument(
151-
"--dso-path", type=Path, default=None, help="Use the specified AOTI DSO model."
195+
"--dso-path",
196+
type=Path,
197+
default=None,
198+
help="Use the specified AOT Inductor .dso model file",
152199
)
153200
parser.add_argument(
154201
"--pte-path",
155202
type=Path,
156203
default=None,
157-
help="Use the specified Executorch PTE model.",
204+
help="Use the specified ExecuTorch .pte model file",
158205
)
159206
parser.add_argument(
160-
"-d",
161-
"--dtype",
207+
"-d", "--dtype",
162208
default="float32",
163209
help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
164210
)
165-
parser.add_argument("-v", "--verbose", action="store_true")
166211
parser.add_argument(
167-
"--quantize", type=str, default="{ }", help="Quantization options."
212+
"-v", "--verbose",
213+
action="store_true",
214+
help="Verbose output",
215+
)
216+
parser.add_argument(
217+
"--quantize",
218+
type=str,
219+
default="{ }",
220+
help="Quantization options",
221+
)
222+
parser.add_argument(
223+
"--params-table",
224+
type=str,
225+
default=None,
226+
help="Parameter table to use",
168227
)
169-
parser.add_argument("--params-table", type=str, default=None, help="Device to use")
170228
parser.add_argument(
171-
"--device", type=str, default=default_device, help="Device to use"
229+
"--device",
230+
type=str,
231+
default=default_device,
232+
help="Hardware device to use. Options: cpu, gpu, mps",
172233
)
173234
parser.add_argument(
174235
"--tasks",
175236
nargs="+",
176237
type=str,
177238
default=["hellaswag"],
178-
help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2",
239+
help="List of lm-eluther tasks to evaluate. Usage: --tasks task1 task2",
179240
)
180241
parser.add_argument(
181-
"--limit", type=int, default=None, help="number of samples to evaluate"
242+
"--limit",
243+
type=int,
244+
default=None,
245+
help="Number of samples to evaluate",
182246
)
183247
parser.add_argument(
184248
"--max-seq-length",
185249
type=int,
186250
default=None,
187-
help="maximum length sequence to evaluate",
251+
help="Maximum length sequence to evaluate",
188252
)
189253
parser.add_argument(
190254
"--hf-token",
@@ -201,7 +265,6 @@ def _add_arguments_common(parser):
201265

202266

203267
def arg_init(args):
204-
205268
if Path(args.quantize).is_file():
206269
with open(args.quantize, "r") as f:
207270
args.quantize = json.loads(f.read())

eval.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
)
2121

2222
from build.model import Transformer
23-
from cli import add_arguments_for_eval, arg_init
23+
from cli import (
24+
add_arguments,
25+
add_arguments_for_eval,
26+
arg_init,
27+
)
2428
from download import download_and_convert, is_model_downloaded
2529
from generate import encode_tokens, model_forward
2630

@@ -281,7 +285,8 @@ def main(args) -> None:
281285

282286

283287
if __name__ == "__main__":
284-
parser = argparse.ArgumentParser(description="Export specific CLI.")
288+
parser = argparse.ArgumentParser(description="torchchat eval CLI")
289+
add_arguments(parser)
285290
add_arguments_for_eval(parser)
286291
args = parser.parse_args()
287292
args = arg_init(args)

export.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
_unset_gguf_kwargs,
1616
BuilderArgs,
1717
)
18-
from cli import add_arguments_for_export, arg_init, check_args
18+
from cli import (
19+
add_arguments,
20+
add_arguments_for_export,
21+
arg_init,
22+
check_args,
23+
)
1924
from download import download_and_convert, is_model_downloaded
2025
from export_aoti import export_model as export_model_aoti
2126

@@ -106,7 +111,8 @@ def main(args):
106111

107112

108113
if __name__ == "__main__":
109-
parser = argparse.ArgumentParser(description="Export specific CLI.")
114+
parser = argparse.ArgumentParser(description="torchchat export CLI")
115+
add_arguments(parser)
110116
add_arguments_for_export(parser)
111117
args = parser.parse_args()
112118
check_args(args, "export")

export_aoti.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from torch.export import Dim
1212

13+
# CPU is always available and also exportable to ExecuTorch
1314
default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'
1415

1516

export_et.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from torch._export import capture_pre_autograd_graph
3131

3232

33+
# CPU is always available and also exportable to ExecuTorch
3334
default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'
3435

3536

generate.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
TokenizerArgs,
2727
)
2828
from build.model import Transformer
29-
from cli import add_arguments_for_generate, arg_init, check_args
29+
from cli import (
30+
add_arguments,
31+
add_arguments_for_generate,
32+
arg_init,
33+
check_args,
34+
)
3035
from download import download_and_convert, is_model_downloaded
3136
from quantize import set_precision
3237

@@ -568,7 +573,8 @@ def main(args):
568573

569574

570575
if __name__ == "__main__":
571-
parser = argparse.ArgumentParser(description="Generate specific CLI.")
576+
parser = argparse.ArgumentParser(description="torchchat generate CLI")
577+
add_arguments(parser)
572578
add_arguments_for_generate(parser)
573579
args = parser.parse_args()
574580
check_args(args, "generate")

0 commit comments

Comments
 (0)