Skip to content

Commit 9209436

Browse files
mikekgfbmalfet
authored andcommitted
Replace uses of str with Path, as per Nikita's suggestion (#326)
* move transformer configs into JSON files * fixes * replace str with path for path manipulation * code merge
1 parent 3bc5fe3 commit 9209436

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

build/model.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ def find_multiple(n: int, k: int) -> int:
2222
return n
2323
return n + k - (n % k)
2424

25-
config_dir = f"{str(Path(__file__).parent)}/known_model_params"
26-
config_path = Path(config_dir)
25+
config_path = Path(f"{str(Path(__file__).parent)}/known_model_params")
2726

2827
@dataclass
2928
class ModelArgs:
@@ -74,25 +73,24 @@ def from_params(cls, params_path):
7473
@classmethod
7574
def from_table(cls, name: str):
7675
print(f"name {name}")
77-
json_path = Path(f"{config_dir}/{name}.json")
76+
json_path = config_path / f"{name}.json"
7877
if json_path.is_file():
7978
return ModelArgs.from_params(json_path)
8079
else:
81-
config_dir = f"{__file__}/known_model_params"
82-
known_model_params = [config.replace(".json", "") for config in os.listdir(config_dir)]
80+
known_model_params = [config.replace(".json", "") for config in os.listdir(config_path)]
8381
raise RuntimeError(f"unknown table index {name} for transformer config, must be from {known_model_params}")
8482

8583
@classmethod
8684
def from_name(cls, name: str):
87-
print(f"Name {name}")
88-
json_path=f"{config_dir}/{name}.json"
85+
print(f"name {name}")
86+
json_path=config_path / f"{name}.json"
8987
if Path(json_path).is_file():
9088
return ModelArgs.from_params(json_path)
9189

92-
known_model_params = [config.replace(".json", "") for config in os.listdir(config_dir)]
90+
known_model_params = [config.replace(".json", "") for config in os.listdir(config_path)]
9391

92+
print(f"known configs: {known_model_params}")
9493
# Fuzzy search by name (e.g. "7B" and "Mistral-7B")
95-
print(f"Known configs: {known_model_params}")
9694
config = [
9795
config
9896
for config in known_model_params
@@ -111,7 +109,7 @@ def from_name(cls, name: str):
111109
f"Unknown model directory name {name}. Must be one of {known_model_params}."
112110
)
113111

114-
return ModelArgs.from_params(f"{config_dir}/{config[0]}.json")
112+
return ModelArgs.from_params(config_path / f"{config[0]}.json")
115113

116114

117115

0 commit comments

Comments
 (0)