
Commit 692e834

Add additional supported models to config/data/models.json (#329)
1 parent 21dbd1e · commit 692e834

File tree

6 files changed: +76 additions, −25 deletions

build/builder.py

Lines changed: 9 additions & 5 deletions
@@ -14,6 +14,7 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
+
 from config.model_config import resolve_model_config
 from quantize import name_to_dtype, quantize_model
 
@@ -76,6 +77,7 @@ def from_args(cls, args): # -> BuilderArgs:
         checkpoint_dir = args.checkpoint_dir
 
         checkpoint_path = args.checkpoint_path
+        params_table = args.params_table
         if args.model: # Using a named, well-known model
             model_config = resolve_model_config(args.model)
 
@@ -84,6 +86,9 @@ def from_args(cls, args): # -> BuilderArgs:
                 / model_config.name
                 / model_config.checkpoint_file
             )
+            # The transformers config is keyed on the last section
+            # of the name/path.
+            params_table = model_config.transformer_params_key or model_config.name.split("/")[-1]
 
         is_chat_model = False
         if args.is_chat_model:
@@ -108,7 +113,7 @@ def from_args(cls, args): # -> BuilderArgs:
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
             params_path=args.params_path,
-            params_table=args.params_table,
+            params_table=params_table,
             gguf_path=args.gguf_path,
             gguf_kwargs=None,
             dso_path=args.dso_path,
@@ -147,9 +152,8 @@ def from_args(cls, args): # -> TokenizerArgs:
             tokenizer_path = args.tokenizer_path
         elif args.model: # Using a named, well-known model
             model_config = resolve_model_config(args.model)
-            tokenizer_path = (
-                Path(args.model_directory) / model_config.name / "tokenizer.model"
-            )
+            tokenizer_path = Path(args.model_directory) / model_config.name / model_config.tokenizer_file
+
         elif args.checkpoint_path:
             tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
         elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir:
@@ -234,7 +238,7 @@ def _load_model_default(builder_args):
     if builder_args.params_path:
         model = Transformer.from_params(builder_args.params_path)
     elif builder_args.params_table:
-        model = Transformer.from_table(builder_args.params_path)
+        model = Transformer.from_table(builder_args.params_table)
     else:
         model = Transformer.from_name(builder_args.checkpoint_path.parent.name)
 
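Net effect in builder.py: for a named model, the params-table key is now derived from the model config rather than taken verbatim from the CLI, and the from_table call is fixed to pass params_table instead of params_path. A minimal standalone sketch of the new key resolution, with names mirroring the diff (the helper itself is illustrative, not part of the commit):

def resolve_params_table(cli_params_table, model_config=None):
    # Start from the CLI-provided value (may be None).
    params_table = cli_params_table
    if model_config is not None:  # a named, well-known model was requested
        # The transformers config is keyed on the last section of the
        # name/path, unless the config supplies transformer_params_key.
        params_table = (
            model_config.transformer_params_key
            or model_config.name.split("/")[-1]
        )
    return params_table

For example, "meta-llama/Llama-2-13b-chat-hf" resolves to "13B" via its transformer_params_key, while "meta-llama/Meta-Llama-3-8B", which has no key, falls back to the last path segment, "Meta-Llama-3-8B".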

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"n_layers": 8, "n_heads": 8, "dim": 512, "hidden_dim": 1376}

build/model.py

Lines changed: 0 additions & 1 deletion
@@ -112,7 +112,6 @@ def from_name(cls, name: str):
         return ModelArgs.from_params(config_path / f"{config[0]}.json")
 
 
-
 class KVCache(nn.Module):
     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=None):
         super().__init__()

config/data/models.json

Lines changed: 46 additions & 6 deletions
@@ -1,13 +1,38 @@
 {
-  "meta-llama/Meta-Llama-3-8B-Instruct": {
-    "aliases": ["llama3", "llama3-8b"],
+  "meta-llama/Llama-2-7b-hf": {
+    "aliases": ["llama2", "llama2-7b"],
     "distribution_channel": "HuggingFaceSnapshot",
-    "distribution_path": "meta-llama/Meta-Llama-3-8B-Instruct"
+    "distribution_path": "meta-llama/Llama-2-7b-hf",
+    "transformer_params_key": "7B"
   },
   "meta-llama/Llama-2-7b-chat-hf": {
-    "aliases": ["llama2", "llama2-7b"],
+    "aliases": ["llama2-chat", "llama2-7b-chat"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-2-7b-chat-hf",
+    "transformer_params_key": "7B"
+  },
+  "meta-llama/Llama-2-13b-chat-hf": {
+    "aliases": ["llama2-13b-chat"],
     "distribution_channel": "HuggingFaceSnapshot",
-    "distribution_path": "meta-llama/Llama-2-7b-chat-hf"
+    "distribution_path": "meta-llama/Llama-2-13b-chat-hf",
+    "transformer_params_key": "13B"
+  },
+  "meta-llama/Llama-2-70b-chat-hf": {
+    "aliases": ["llama2-70b-chat"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-2-70b-chat-hf",
+    "transformer_params_key": "70B"
+  },
+  "meta-llama/Meta-Llama-3-8B": {
+    "aliases": ["llama3"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Meta-Llama-3-8B"
+  },
+  "meta-llama/Meta-Llama-3-8B-Instruct": {
+    "aliases": ["llama3-chat", "llama3-instruct"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Meta-Llama-3-8B-Instruct",
+    "transformer_params_key": "Meta-Llama-3-8B"
   },
   "meta-llama/CodeLlama-7b-Python-hf": {
     "aliases": ["codellama", "codellama-7b"],
@@ -17,7 +42,14 @@
   "mistralai/Mistral-7B-Instruct-v0.2": {
     "aliases": ["mistral-7b", "mistral-7b-instruct"],
     "distribution_channel": "HuggingFaceSnapshot",
-    "distribution_path": "mistralai/Mistral-7B-Instruct-v0.2"
+    "distribution_path": "mistralai/Mistral-7B-Instruct-v0.2",
+    "transformer_params_key": "Mistral-7B"
+  },
+  "openlm-research/open_llama_7b": {
+    "aliases": ["open-llama", "open-llama-7b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "openlm-research/open_llama_7b",
+    "transformer_params_key": "7B"
   },
   "stories15M": {
     "distribution_channel": "DirectDownload",
@@ -27,6 +59,14 @@
     ],
     "checkpoint_file": "stories15M.pt"
   },
+  "stories42M": {
+    "distribution_channel": "DirectDownload",
+    "distribution_path": [
+      "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.pt",
+      "https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
+    ],
+    "checkpoint_file": "stories42M.pt"
+  },
   "stories110M": {
     "distribution_channel": "DirectDownload",
     "distribution_path": [

config/model_config.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@ class ModelConfig:
         default=ModelDistributionChannel.HuggingFaceSnapshot
     )
     checkpoint_file: str = field(default="model.pth")
+    tokenizer_file: str = field(default="tokenizer.model")
+    transformer_params_key: str = field(default=None)
 
 
 # Keys are stored in lowercase.
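For context, the dataclass now carries both artifact names and the params-table key. A self-contained sketch of its shape after this change (only the last three fields appear in this diff; the others are inferred from their use elsewhere in the commit, and the enum is stubbed in):

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Sequence, Union

class ModelDistributionChannel(Enum):
    # Stub for illustration; the real enum lives in config/model_config.py.
    HuggingFaceSnapshot = "HuggingFaceSnapshot"
    DirectDownload = "DirectDownload"

@dataclass
class ModelConfig:
    name: str = ""  # e.g. "meta-llama/Llama-2-7b-hf"
    aliases: Sequence[str] = field(default_factory=list)
    distribution_path: Union[str, Sequence[str]] = ""
    distribution_channel: ModelDistributionChannel = field(
        default=ModelDistributionChannel.HuggingFaceSnapshot
    )
    checkpoint_file: str = field(default="model.pth")
    tokenizer_file: str = field(default="tokenizer.model")  # new in this commit
    transformer_params_key: Optional[str] = field(default=None)  # new in this commit

The diff declares transformer_params_key as str with a None default; the sketch uses Optional[str], since that is the effective type.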

download.py

Lines changed: 18 additions & 13 deletions
@@ -9,49 +9,54 @@
 from typing import Optional, Sequence
 
 from build.convert_hf_checkpoint import convert_hf_checkpoint
-from config.model_config import ModelDistributionChannel, resolve_model_config
+from config.model_config import (
+    ModelConfig,
+    ModelDistributionChannel,
+    resolve_model_config,
+)
 
 from requests.exceptions import HTTPError
 
 
-def _download_and_convert_hf_snapshot(
-    model: str, models_dir: Path, hf_token: Optional[str]
+def _download_hf_snapshot(
+    model_config: ModelConfig, models_dir: Path, hf_token: Optional[str]
 ):
-    model_dir = models_dir / model
+    model_dir = models_dir / model_config.name
     os.makedirs(model_dir, exist_ok=True)
 
     from huggingface_hub import snapshot_download
 
     # Download and store the HF model artifacts.
-    print(f"Downloading {model} from Hugging Face...")
+    print(f"Downloading {model_config.name} from HuggingFace...")
     try:
         snapshot_download(
-            model,
+            model_config.distribution_path,
             local_dir=model_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
            ignore_patterns="*safetensors*",
         )
     except HTTPError as e:
         if e.response.status_code == 401:
+            os.rmdir(model_dir)
             raise RuntimeError(
                 "Access denied. Run huggingface-cli login to authenticate."
             )
-            os.rmdir(model_dir)
         else:
             raise e
 
+
     # Convert the model to the torchchat format.
-    print(f"Converting {model} to torchchat format...")
-    convert_hf_checkpoint(model_dir=model_dir, model_name=model, remove_bin_files=True)
+    print(f"Converting {model_config.name} to torchchat format...")
+    convert_hf_checkpoint(model_dir=model_dir, model_name=model_config.name, remove_bin_files=True)
 
 
 def _download_direct(
-    model: str,
+    model_config: ModelConfig,
     urls: Sequence[str],
     models_dir: Path,
 ):
-    model_dir = models_dir / model
+    model_dir = models_dir / model_config.name
     os.makedirs(model_dir, exist_ok=True)
 
     for url in urls:
@@ -70,9 +75,9 @@ def download_and_convert(
         model_config.distribution_channel
         == ModelDistributionChannel.HuggingFaceSnapshot
     ):
-        _download_and_convert_hf_snapshot(model_config.name, models_dir, hf_token)
+        _download_hf_snapshot(model_config, models_dir, hf_token)
     elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
-        _download_direct(model_config.name, model_config.distribution_path, models_dir)
+        _download_direct(model_config, model_config.distribution_path, models_dir)
     else:
         raise RuntimeError(
             f"Unknown distribution channel {model_config.distribution_channel}."
