Skip to content

Commit 7be5645

Browse files
mikekgfbmalfet
authored andcommitted
cli (#398)
* cli * typos
1 parent 9c98d75 commit 7be5645

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

build/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from __future__ import annotations
8-
8+
from typing import List
9+
from pathlib import Path
10+
import os
911
import logging
1012

1113
import torch
@@ -33,6 +35,8 @@ def name_to_dtype(name):
3335
else:
3436
raise RuntimeError(f"unsupported dtype name {name} specified")
3537

38+
def allowable_dtype_names() -> List[str]:
39+
return name_to_dtype_dict.keys()
3640

3741
name_to_dtype_dict = {
3842
"fp32": torch.float,
@@ -45,6 +49,18 @@ def name_to_dtype(name):
4549
"bfloat16": torch.bfloat16,
4650
}
4751

52+
53+
#########################################################################
54+
### general model build utility functions for CLI ###
55+
56+
def allowable_params_table() -> List[dtr]:
57+
config_path = Path(f"{str(Path(__file__).parent)}/known_model_params")
58+
known_model_params = [
59+
config.replace(".json", "") for config in os.listdir(config_path)
60+
]
61+
return known_model_params
62+
63+
4864
#########################################################################
4965
### general model build utility functions ###
5066

cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import json
88
from pathlib import Path
99

10+
from build.utils import allowable_dtype_names, allowable_params_table
11+
1012
import torch
1113

1214
# CPU is always available and also exportable to ExecuTorch
@@ -208,6 +210,7 @@ def add_arguments(parser):
208210
"-d",
209211
"--dtype",
210212
default="float32",
213+
choices = allowable_dtype_names(),
211214
help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
212215
)
213216
parser.add_argument(
@@ -239,12 +242,14 @@ def add_arguments(parser):
239242
"--params-table",
240243
type=str,
241244
default=None,
245+
choices=allowable_params_table(),
242246
help="Parameter table to use",
243247
)
244248
parser.add_argument(
245249
"--device",
246250
type=str,
247251
default=default_device,
252+
choices=["cpu", "cuda", "mps"],
248253
help="Hardware device to use. Options: cpu, cuda, mps",
249254
)
250255
parser.add_argument(

0 commit comments

Comments
 (0)