@@ -22,8 +22,7 @@ def find_multiple(n: int, k: int) -> int:
22
22
return n
23
23
return n + k - (n % k )
24
24
25
- config_dir = f"{ str (Path (__file__ ).parent )} /known_model_params"
26
- config_path = Path (config_dir )
25
+ config_path = Path (f"{ str (Path (__file__ ).parent )} /known_model_params" )
27
26
28
27
@dataclass
29
28
class ModelArgs :
@@ -74,25 +73,24 @@ def from_params(cls, params_path):
74
73
@classmethod
75
74
def from_table (cls , name : str ):
76
75
print (f"name { name } " )
77
- json_path = Path ( f"{ config_dir } / { name } .json" )
76
+ json_path = config_path / f"{ name } .json"
78
77
if json_path .is_file ():
79
78
return ModelArgs .from_params (json_path )
80
79
else :
81
- config_dir = f"{ __file__ } /known_model_params"
82
- known_model_params = [config .replace (".json" , "" ) for config in os .listdir (config_dir )]
80
+ known_model_params = [config .replace (".json" , "" ) for config in os .listdir (config_path )]
83
81
raise RuntimeError (f"unknown table index { name } for transformer config, must be from { known_model_params } " )
84
82
85
83
@classmethod
86
84
def from_name (cls , name : str ):
87
- print (f"Name { name } " )
88
- json_path = f" { config_dir } / { name } .json"
85
+ print (f"name { name } " )
86
+ json_path = config_path / f" { name } .json"
89
87
if Path (json_path ).is_file ():
90
88
return ModelArgs .from_params (json_path )
91
89
92
- known_model_params = [config .replace (".json" , "" ) for config in os .listdir (config_dir )]
90
+ known_model_params = [config .replace (".json" , "" ) for config in os .listdir (config_path )]
93
91
92
+ print (f"known configs: { known_model_params } " )
94
93
# Fuzzy search by name (e.g. "7B" and "Mistral-7B")
95
- print (f"Known configs: { known_model_params } " )
96
94
config = [
97
95
config
98
96
for config in known_model_params
@@ -111,7 +109,7 @@ def from_name(cls, name: str):
111
109
f"Unknown model directory name { name } . Must be one of { known_model_params } ."
112
110
)
113
111
114
- return ModelArgs .from_params (f" { config_dir } / { config [0 ]} .json" )
112
+ return ModelArgs .from_params (config_path / f" { config [0 ]} .json" )
115
113
116
114
117
115
0 commit comments