Skip to content

Commit 6db321c

Browse files
Jack-Khuu authored and malfet committed
Fix help generate (#891)
* Fixing the help mode of the download subcommand
* Initial Addition of subparsers for generation
* Move compile out of generation exclusive
* typo
* Fix test by removing temperature, which is a field eval doesn't use or expect
* Typo Generater => Generator
1 parent 0f3bbec commit 6db321c

File tree

2 files changed

+104
-81
lines changed

2 files changed

+104
-81
lines changed

.ci/scripts/validate.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ function eval_model_sanity_check() {
284284
echo "*************************************************"
285285
if [ "$DTYPE" != "float16" ]; then
286286
python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
287-
python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
287+
python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
288288
cat "$MODEL_DIR/output_eval_aoti"
289289
fi;
290290

cli.py

Lines changed: 103 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -61,27 +61,14 @@ def add_arguments_for_verb(parser, verb: str) -> None:
6161
help="Model name for well-known models",
6262
)
6363

64-
parser.add_argument(
65-
"--chat",
66-
action="store_true",
67-
help="Whether to start an interactive chat session",
68-
)
64+
if verb in ["browser", "chat", "generate"]:
65+
_add_generation_args(parser)
66+
6967
parser.add_argument(
7068
"--distributed",
7169
action="store_true",
7270
help="Whether to enable distributed inference",
7371
)
74-
parser.add_argument(
75-
"--gui",
76-
action="store_true",
77-
help="Whether to use a web UI for an interactive chat session",
78-
)
79-
parser.add_argument(
80-
"--prompt",
81-
type=str,
82-
default="Hello, my name is",
83-
help="Input prompt",
84-
)
8572
parser.add_argument(
8673
"--is-chat-model",
8774
action="store_true",
@@ -93,54 +80,17 @@ def add_arguments_for_verb(parser, verb: str) -> None:
9380
default=None,
9481
help="Initialize torch seed",
9582
)
96-
parser.add_argument(
97-
"--num-samples",
98-
type=int,
99-
default=1,
100-
help="Number of samples",
101-
)
102-
parser.add_argument(
103-
"--max-new-tokens",
104-
type=int,
105-
default=200,
106-
help="Maximum number of new tokens",
107-
)
108-
parser.add_argument(
109-
"--top-k",
110-
type=int,
111-
default=200,
112-
help="Top-k for sampling",
113-
)
114-
parser.add_argument(
115-
"--temperature", type=float, default=0.8, help="Temperature for sampling"
116-
)
11783
parser.add_argument(
11884
"--compile",
11985
action="store_true",
12086
help="Whether to compile the model with torch.compile",
12187
)
122-
parser.add_argument(
123-
"--compile-prefill",
124-
action="store_true",
125-
help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
126-
)
127-
parser.add_argument(
128-
"--sequential-prefill",
129-
action="store_true",
130-
help="Whether to perform prefill sequentially. Only used for model debug.",
131-
)
13288
parser.add_argument(
13389
"--profile",
13490
type=Path,
13591
default=None,
13692
help="Profile path.",
13793
)
138-
parser.add_argument(
139-
"--speculate-k",
140-
type=int,
141-
default=5,
142-
help="Speculative execution depth",
143-
)
14494
parser.add_argument(
14595
"--draft-checkpoint-path",
14696
type=Path,
@@ -171,30 +121,10 @@ def add_arguments_for_verb(parser, verb: str) -> None:
171121
default=None,
172122
help="Use the specified model tokenizer file",
173123
)
174-
parser.add_argument(
175-
"--output-pte-path",
176-
type=str,
177-
default=None,
178-
help="Output to the specified ExecuTorch .pte model file",
179-
)
180-
parser.add_argument(
181-
"--output-dso-path",
182-
type=str,
183-
default=None,
184-
help="Output to the specified AOT Inductor .dso model file",
185-
)
186-
parser.add_argument(
187-
"--dso-path",
188-
type=Path,
189-
default=None,
190-
help="Use the specified AOT Inductor .dso model file",
191-
)
192-
parser.add_argument(
193-
"--pte-path",
194-
type=Path,
195-
default=None,
196-
help="Use the specified ExecuTorch .pte model file",
197-
)
124+
125+
_add_exported_model_input_args(parser)
126+
_add_export_output_path_args(parser)
127+
198128
parser.add_argument(
199129
"--dtype",
200130
default="fast",
@@ -259,6 +189,40 @@ def add_arguments_for_verb(parser, verb: str) -> None:
259189
_add_cli_metadata_args(parser)
260190

261191

192+
# Add CLI Args representing user provided exported model files
193+
def _add_export_output_path_args(parser) -> None:
194+
output_path_parser = parser.add_argument_group("Export Output Path Args", "Specify the output path for the exported model files")
195+
output_path_parser.add_argument(
196+
"--output-pte-path",
197+
type=str,
198+
default=None,
199+
help="Output to the specified ExecuTorch .pte model file",
200+
)
201+
output_path_parser.add_argument(
202+
"--output-dso-path",
203+
type=str,
204+
default=None,
205+
help="Output to the specified AOT Inductor .dso model file",
206+
)
207+
208+
209+
# Add CLI Args representing user provided exported model files
210+
def _add_exported_model_input_args(parser) -> None:
211+
exported_model_path_parser = parser.add_argument_group("Exported Model Path Args", "Specify the path of the exported model files to ingest")
212+
exported_model_path_parser.add_argument(
213+
"--dso-path",
214+
type=Path,
215+
default=None,
216+
help="Use the specified AOT Inductor .dso model file",
217+
)
218+
exported_model_path_parser.add_argument(
219+
"--pte-path",
220+
type=Path,
221+
default=None,
222+
help="Use the specified ExecuTorch .pte model file",
223+
)
224+
225+
262226
# Add CLI Args that are relevant to any subcommand execution
263227
def _add_cli_metadata_args(parser) -> None:
264228
parser.add_argument(
@@ -297,22 +261,81 @@ def _configure_artifact_inventory_args(parser, verb: str) -> None:
297261
)
298262

299263

264+
# Add CLI Args specific to user prompted generation
265+
def _add_generation_args(parser) -> None:
266+
generator_parser = parser.add_argument_group("Generation Args", "Configs for generating output based on provided prompt")
267+
generator_parser.add_argument(
268+
"--prompt",
269+
type=str,
270+
default="Hello, my name is",
271+
help="Input prompt for manual output generation",
272+
)
273+
generator_parser.add_argument(
274+
"--chat",
275+
action="store_true",
276+
help="Whether to start an interactive chat session",
277+
)
278+
generator_parser.add_argument(
279+
"--gui",
280+
action="store_true",
281+
help="Whether to use a web UI for an interactive chat session",
282+
)
283+
generator_parser.add_argument(
284+
"--num-samples",
285+
type=int,
286+
default=1,
287+
help="Number of samples",
288+
)
289+
generator_parser.add_argument(
290+
"--max-new-tokens",
291+
type=int,
292+
default=200,
293+
help="Maximum number of new tokens",
294+
)
295+
generator_parser.add_argument(
296+
"--top-k",
297+
type=int,
298+
default=200,
299+
help="Top-k for sampling",
300+
)
301+
generator_parser.add_argument(
302+
"--temperature", type=float, default=0.8, help="Temperature for sampling"
303+
)
304+
generator_parser.add_argument(
305+
"--compile-prefill",
306+
action="store_true",
307+
help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
308+
)
309+
generator_parser.add_argument(
310+
"--sequential-prefill",
311+
action="store_true",
312+
help="Whether to perform prefill sequentially. Only used for model debug.",
313+
)
314+
generator_parser.add_argument(
315+
"--speculate-k",
316+
type=int,
317+
default=5,
318+
help="Speculative execution depth",
319+
)
320+
321+
300322
# Add CLI Args specific to Model Evaluation
301323
def _add_evaluation_args(parser) -> None:
302-
parser.add_argument(
324+
eval_parser = parser.add_argument_group("Evaluation Args", "Configs for evaluating model performance")
325+
eval_parser.add_argument(
303326
"--tasks",
304327
nargs="+",
305328
type=str,
306329
default=["wikitext"],
307330
help="List of lm-eluther tasks to evaluate. Usage: --tasks task1 task2",
308331
)
309-
parser.add_argument(
332+
eval_parser.add_argument(
310333
"--limit",
311334
type=int,
312335
default=None,
313336
help="Number of samples to evaluate",
314337
)
315-
parser.add_argument(
338+
eval_parser.add_argument(
316339
"--max-seq-length",
317340
type=int,
318341
default=None,

0 commit comments

Comments (0)