File tree Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Original file line number Diff line number Diff line change 5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
from __future__ import annotations
8
-
8
+ from typing import List
9
+ from pathlib import Path
10
+ import os
9
11
import logging
10
12
11
13
import torch
@@ -33,6 +35,8 @@ def name_to_dtype(name):
33
35
else :
34
36
raise RuntimeError (f"unsupported dtype name { name } specified" )
35
37
38
+ def allowable_dtype_names () -> List [str ]:
39
+ return name_to_dtype_dict .keys ()
36
40
37
41
name_to_dtype_dict = {
38
42
"fp32" : torch .float ,
@@ -45,6 +49,18 @@ def name_to_dtype(name):
45
49
"bfloat16" : torch .bfloat16 ,
46
50
}
47
51
52
+
53
+ #########################################################################
54
+ ### general model build utility functions for CLI ###
55
+
56
+ def allowable_params_table () -> List [dtr ]:
57
+ config_path = Path (f"{ str (Path (__file__ ).parent )} /known_model_params" )
58
+ known_model_params = [
59
+ config .replace (".json" , "" ) for config in os .listdir (config_path )
60
+ ]
61
+ return known_model_params
62
+
63
+
48
64
#########################################################################
49
65
### general model build utility functions ###
50
66
Original file line number Diff line number Diff line change 7
7
import json
8
8
from pathlib import Path
9
9
10
+ from build .utils import allowable_dtype_names , allowable_params_table
11
+
10
12
import torch
11
13
12
14
# CPU is always available and also exportable to ExecuTorch
@@ -208,6 +210,7 @@ def add_arguments(parser):
208
210
"-d" ,
209
211
"--dtype" ,
210
212
default = "float32" ,
213
+ choices = allowable_dtype_names (),
211
214
help = "Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32" ,
212
215
)
213
216
parser .add_argument (
@@ -239,12 +242,14 @@ def add_arguments(parser):
239
242
"--params-table" ,
240
243
type = str ,
241
244
default = None ,
245
+ choices = allowable_params_table (),
242
246
help = "Parameter table to use" ,
243
247
)
244
248
parser .add_argument (
245
249
"--device" ,
246
250
type = str ,
247
251
default = default_device ,
252
+ choices = ["cpu" , "cuda" , "mps" ],
248
253
help = "Hardware device to use. Options: cpu, cuda, mps" ,
249
254
)
250
255
parser .add_argument (
You can’t perform that action at this time.
0 commit comments