
Commit 80ffe93

GregoryComer authored and malfet committed

Implement download subcommand, optional positional model name argument (#234)

* Implement download option
* Add support for model aliases
* Support model name as a positional parameter
* Merge GenerateArgs changes
* Run lint
* Revert chat subcommand/arg changes
* Add mistral-7b-instruct alias, fix lints
* Add model config for known models
* Move known model config to config/models.json
* Make model names case-insensitive
* Move known model configuration from build/model.py to config/model_config.py
* Fix lints
* Fixing issues after rebasing
* Update README

1 parent f25cd19 commit 80ffe93

File tree

15 files changed: +327 −75 lines changed

.github/workflows/pull.yml

Lines changed: 5 additions & 0 deletions

@@ -361,6 +361,11 @@ jobs:
       cat ./output_eager2
       echo "Tests complete."

+      - name: Test download
+        run: |
+
+          python torchchat.py generate stories15M
+
   test-tinystories-eager:
     strategy:
       matrix:

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -5,3 +5,5 @@ __pycache__/

 # C extensions
 *.so
+
+.model-artifacts/

README.md

Lines changed: 13 additions & 9 deletions

@@ -32,14 +32,10 @@ python torchchat.py --help

 ```

-### Dowenload a Model and Tokenizer
+### Generating Text

 ```
-#download a model
-python torchchat.py download llama2
-
-#generate text using the model
-
+python torchchat.py generate stories15M
 ```
 That’s all there is to it!
 Read on to learn how to use the full power of torchchat.
@@ -48,7 +44,15 @@ Read on to learn how to use the full power of torchchat.
 For the full details on all commands and parameters run `python torchchat.py --help`

 ### Download
-TODO: Fill this out
+For supported models, torchchat can download model weights. Most models use HuggingFace as the distribution channel, so you will need to create a HuggingFace
+account and install `huggingface-cli`.
+
+To install `huggingface-cli`, run `pip install huggingface_hub`. After installing, create a user access token [as documented here](https://huggingface.co/docs/hub/en/security-tokens). Run `huggingface-cli login`, which will prompt for the newly created token. Once this is done, torchchat will be able to download model artifacts from
+HuggingFace.
+
+```
+python torchchat.py download llama2
+```

 ### Chat
 Designed for interactive and conversational use.
@@ -69,7 +73,7 @@ For more information run `python torchchat.py generate --help`

 **Examples**
 ```
-#Generate for Mac with some parameters
+python torchchat.py generate llama2 --device=cpu --dtype=fp16
 ```

 ### Export
@@ -80,7 +84,7 @@ For more information run `python torchchat.py export --help`
 **Examples**

 ```
-#Export Example
+python torchchat.py export stories15M --output-pte-path=stories15m.pte
 ```

 ### Browser
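
The Download section above covers the manual `huggingface-cli` setup. The implementation of the download subcommand is not visible in this excerpt, so here is a minimal sketch, assuming the public `huggingface_hub` API, of what a HuggingFaceSnapshot-style fetch can look like; the helper name is hypothetical, not torchchat's actual code:

```
# Hypothetical sketch, not torchchat's downloader: fetch a HuggingFace
# snapshot into the layout BuilderArgs expects, i.e.
# <model-directory>/<model name>/<checkpoint and tokenizer files>.
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download  # assumes a recent huggingface_hub


def fetch_snapshot(repo_id: str, models_dir: Path, hf_token: Optional[str] = None) -> Path:
    local_dir = models_dir / repo_id
    # `token` mirrors the --hf-token flag added in cli.py below.
    snapshot_download(repo_id, local_dir=local_dir, token=hf_token)
    return local_dir


if __name__ == "__main__":
    fetch_snapshot("meta-llama/Llama-2-7b-chat-hf", Path(".model-artifacts"))
```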

build/builder.py

Lines changed: 25 additions & 6 deletions

@@ -14,7 +14,7 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
-
+from config.model_config import resolve_model_config
 from quantize import name_to_dtype, quantize_model

 from sentencepiece import SentencePieceProcessor
@@ -42,7 +42,7 @@ class BuilderArgs:
     def __post_init__(self):
         if not (
             (self.checkpoint_path and self.checkpoint_path.is_file())
-            or (self.checkpoint_dir and self.checkpoint_path.is_dir())
+            or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
             or (self.gguf_path and self.gguf_path.is_file())
             or (self.dso_path and Path(self.dso_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
@@ -73,7 +73,17 @@ def from_args(cls, args): # -> BuilderArgs:
         # Handle disabled checkpoint_dir option
         checkpoint_dir = None
         if hasattr(args, "checkpoint_dir"):
-            checkpoint_dir = args.checkpoint_dir
+            checkpoint_dir = args.checkpoint_dir
+
+        checkpoint_path = args.checkpoint_path
+        if args.model:  # Using a named, well-known model
+            model_config = resolve_model_config(args.model)
+
+            checkpoint_path = (
+                Path(args.model_directory)
+                / model_config.name
+                / model_config.checkpoint_file
+            )

         is_chat_model = False
         if args.is_chat_model:
@@ -94,8 +104,8 @@ def from_args(cls, args): # -> BuilderArgs:
             is_chat_model = True

         return cls(
-            checkpoint_path=args.checkpoint_path,
             checkpoint_dir=checkpoint_dir,
+            checkpoint_path=checkpoint_path,
             params_path=args.params_path,
             params_table=args.params_table,
             gguf_path=args.gguf_path,
@@ -134,9 +144,12 @@ def from_args(cls, args): # -> TokenizerArgs:

         if args.tokenizer_path:
             tokenizer_path = args.tokenizer_path
+        elif args.model:  # Using a named, well-known model
+            model_config = resolve_model_config(args.model)
+            tokenizer_path = Path(args.model_directory) / model_config.name / "tokenizer.model"
         elif args.checkpoint_path:
             tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
-        elif args.checkpoint_dir:
+        elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir:
             tokenizer_path = args.checkpoint_dir / "tokenizer.model"
         else:
             raise RuntimeError("cannot find tokenizer model")
@@ -356,4 +369,10 @@ def validate_args(model: Transformer, tokenizer_args: TokenizerArgs):
     is_tiktoken = tokenizer_args.is_tiktoken
     if use_tiktoken != is_tiktoken:
         raise RuntimeError(f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}")
-
+
+def resolve_model_name(model: str) -> str:
+    # If the provided model name is an alias, retrieve the full path.
+    if model in model_aliases:
+        return model_aliases[model]
+    else:
+        return model
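
Taken together, the `from_args` hunks above mean a bare model name (or alias) now resolves to artifact paths under `--model-directory`. A condensed sketch of that resolution, assuming the cli.py defaults; `resolve_paths` is our illustration, not part of the diff:

```
# Condensed illustration of the lookups added to BuilderArgs.from_args and
# TokenizerArgs.from_args.
from pathlib import Path

from config.model_config import resolve_model_config


def resolve_paths(model: str, model_directory: Path = Path(".model-artifacts")):
    model_config = resolve_model_config(model)  # accepts aliases such as "llama2"
    model_dir = model_directory / model_config.name
    # e.g. .model-artifacts/meta-llama/Llama-2-7b-chat-hf/model.pth
    return model_dir / model_config.checkpoint_file, model_dir / "tokenizer.model"
```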

scripts/convert_hf_checkpoint.py renamed to build/convert_hf_checkpoint.py

Lines changed: 16 additions & 9 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import json
+import os
 import re
 import sys
 from pathlib import Path
@@ -22,19 +23,20 @@
 @torch.inference_mode()
 def convert_hf_checkpoint(
     *,
-    checkpoint_dir: Optional[Path] = None,
+    model_dir: Optional[Path] = None,
     model_name: Optional[str] = None,
+    remove_bin_files: bool = False,
 ) -> None:
-    if checkpoint_dir is None:
-        checkpoint_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
+    if model_dir is None:
+        model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
     if model_name is None:
-        model_name = checkpoint_dir.name
+        model_name = model_dir.name

     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")

     # Load the json file containing weight mapping
-    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_map_json = model_dir / "pytorch_model.bin.index.json"

     assert model_map_json.is_file()

@@ -56,7 +58,7 @@ def convert_hf_checkpoint(
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
     }
-    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}

     def permute(w, n_heads):
         dim = config.dim
@@ -97,8 +99,13 @@ def permute(w, n_heads):
         del final_result[key]
         del final_result[key.replace("wq", "wk")]
         del final_result[key.replace("wq", "wv")]
-    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
-    torch.save(final_result, checkpoint_dir / "model.pth")
+    print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
+    torch.save(final_result, model_dir / "model.pth")
+    print("Done.")
+
+    if remove_bin_files:
+        for file in bin_files:
+            os.remove(file)


 if __name__ == "__main__":
@@ -114,6 +121,6 @@ def permute(w, n_heads):

     args = parser.parse_args()
     convert_hf_checkpoint(
-        checkpoint_dir=args.checkpoint_dir,
+        model_dir=args.checkpoint_dir,
         model_name=args.model_name,
     )
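
For reference, a hypothetical post-download call using the renamed keyword arguments above; the caller in the download flow is not part of this excerpt, and the path assumes the default `--model-directory`:

```
from pathlib import Path

from build.convert_hf_checkpoint import convert_hf_checkpoint

# model_name is omitted, so it defaults to model_dir.name as shown above.
convert_hf_checkpoint(
    model_dir=Path(".model-artifacts/meta-llama/Llama-2-7b-chat-hf"),
    remove_bin_files=True,  # reclaim disk space once model.pth is written
)
```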

cli.py

Lines changed: 30 additions & 0 deletions

@@ -14,6 +14,11 @@
 def check_args(args, name: str) -> None:
     pass

+def add_arguments_for_download(parser):
+    # Only download specific options should be here
+    _add_arguments_common(parser)
+
+
 def add_arguments_for_generate(parser):
     # Only generate specific options should be here
     _add_arguments_common(parser)
@@ -39,6 +44,19 @@ def add_arguments_for_browser(parser):
     )

 def _add_arguments_common(parser):
+    # Model specification. TODO Simplify this.
+    # A model can be specified using a positional model name or HuggingFace
+    # path. Alternatively, the model can be specified via --gguf-path or via
+    # an explicit --checkpoint-dir, --checkpoint-path, or --tokenizer-path.
+
+    parser.add_argument(
+        "model",
+        type=str,
+        nargs="?",
+        default=None,
+        help="Model name for well-known models.",
+    )
+
     # TODO: Refactor this so that only common options are here
     # and subcommand-specific options are inside individual
     # add_arguments_for_generate, add_arguments_for_export etc.
@@ -168,6 +186,18 @@ def _add_arguments_common(parser):
         default=None,
         help="maximum length sequence to evaluate",
     )
+    parser.add_argument(
+        "--hf-token",
+        type=str,
+        default=None,
+        help="A HuggingFace API token to use when downloading model artifacts",
+    )
+    parser.add_argument(
+        "--model-directory",
+        type=Path,
+        default=".model-artifacts",
+        help="The directory to store downloaded model artifacts",
+    )


 def arg_init(args):
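
A minimal sketch of how `add_arguments_for_download` and the positional model argument could be wired into an entry point; the actual torchchat.py dispatch is among the changed files not shown here:

```
# Minimal wiring sketch; torchchat.py's real dispatch is not in this excerpt.
import argparse

from cli import add_arguments_for_download, add_arguments_for_generate

parser = argparse.ArgumentParser(prog="torchchat")
subparsers = parser.add_subparsers(dest="subcommand", required=True)

add_arguments_for_download(subparsers.add_parser("download"))
add_arguments_for_generate(subparsers.add_parser("generate"))

# Both subcommands share the optional positional name, so
#   python torchchat.py download llama2
#   python torchchat.py generate stories15M
# parse into args.model == "llama2" / "stories15M".
args = parser.parse_args()
```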

config/__init__.py

Whitespace-only changes.

config/data/models.json

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+{
+  "meta-llama/Llama-2-7b-chat-hf": {
+    "aliases": ["llama2", "llama2-7b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-2-7b-chat-hf"
+  },
+  "mistralai/Mistral-7B-Instruct-v0.2": {
+    "aliases": ["mistral-7b-instruct"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "mistralai/Mistral-7B-Instruct-v0.2"
+  },
+  "stories15M": {
+    "distribution_channel": "DirectDownload",
+    "distribution_path": [
+      "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt",
+      "https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
+    ],
+    "checkpoint_file": "stories15M.pt"
+  },
+  "stories110M": {
+    "distribution_channel": "DirectDownload",
+    "distribution_path": [
+      "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt",
+      "https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
+    ],
+    "checkpoint_file": "stories110M.pt"
+  }
+}

config/model_config.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import json
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Dict, Sequence, Union
+
+"""
+Known Model Configs:
+
+For models that are known to work with torchchat, we provide a config under
+config/data/models.json to support automatically downloading the model and
+converting to the expected format for use with torchchat.
+
+There are two supported distribution channels:
+
+1) HuggingFaceSnapshot: Download a model from HuggingFace.
+2) DirectDownload: Download a list of model artifacts from URLs. No conversion
+   is done.
+"""
+
+
+# Specifies the distribution channel to download model artifacts from. Enum
+# variants are specified as strings to simplify JSON (de)serialization.
+class ModelDistributionChannel(str, Enum):
+    # Download a full model snapshot from HuggingFace, such as
+    # meta-llama/Llama-2-7b-chat-hf and convert to torchchat format.
+    HuggingFaceSnapshot = "HuggingFaceSnapshot"
+
+    # Download one or more files over HTTP(S).
+    DirectDownload = "DirectDownload"
+
+
+@dataclass
+class ModelConfig:
+    name: str = field(default="")
+    aliases: Sequence[str] = field(default_factory=list)
+    distribution_path: Union[str, Sequence[str]] = field(default="")
+    distribution_channel: ModelDistributionChannel = field(
+        default=ModelDistributionChannel.HuggingFaceSnapshot
+    )
+    checkpoint_file: str = field(default="model.pth")
+
+
+# Keys are stored in lowercase.
+model_aliases: Dict[str, str] = None
+model_configs: Dict[str, ModelConfig] = None
+
+
+def resolve_model_config(model: str) -> ModelConfig:
+    global model_aliases
+    global model_configs
+
+    model = model.lower()
+
+    # Lazy load model config from JSON.
+    if not model_configs:
+        model_aliases = {}
+        model_configs = {}
+
+        with open(
+            Path(__file__).parent.parent / "config" / "data" / "models.json", "r"
+        ) as f:
+            model_config_dict = json.load(f)
+
+        for key, value in model_config_dict.items():
+            config = ModelConfig(**value)
+            config.name = key
+
+            key = key.lower()
+            model_configs[key] = config
+
+            for alias in config.aliases:
+                model_aliases[alias.lower()] = key
+
+    if model in model_aliases:
+        model = model_aliases[model]
+
+    if model not in model_configs:
+        raise ValueError(f"Unknown model '{model}'.")
+
+    return model_configs[model]
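
A quick usage sketch of the loader above. Lookups are case-insensitive, and because `ModelDistributionChannel` subclasses `str`, the raw JSON value compares equal to the enum member even though the dataclass does no coercion:

```
from config.model_config import ModelDistributionChannel, resolve_model_config

config = resolve_model_config("LLaMA2")  # alias, matched case-insensitively
assert config.name == "meta-llama/Llama-2-7b-chat-hf"
assert config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot
assert config.checkpoint_file == "model.pth"  # dataclass default

# resolve_model_config("not-a-model") raises ValueError("Unknown model 'not-a-model'.")
```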
