Commit 766bee9

Safetensors (#1255)
* feat(granite): Add support for finding weight mapping files with other names
  Branch: GraniteCodeSupport
  Signed-off-by: Gabe Goodhart <[email protected]>

* feat(granite): Add support for loading state_dict from safetensors
  Branch: GraniteCodeSupport
  Signed-off-by: Gabe Goodhart <[email protected]>

* feat(safetensors): Use model_info to determine whether to download pth or safetensors
  The logic here will prefer pth over safetensors unless the model's config
  explicitly states a preference for safetensors over pth. If only one of the
  two is found, the download will use whichever is present.
  Branch: GraniteCodeSupport
  Signed-off-by: Gabe Goodhart <[email protected]>

---------
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: d8c0aaf · commit: 766bee9
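The download-preference rule described in the commit message can be sketched as a small illustrative helper (not code from this commit; the real logic is in torchchat/cli/download.py below):

    def pick_ignore_patterns(has_pth: bool, has_safetensors: bool, prefer_safetensors: bool):
        # Explicit opt-in to safetensors: skip pth files.
        # (The real code exits with an error if no safetensors files exist.)
        if prefer_safetensors:
            return "*.pth"
        # Both formats present, no preference stated: keep pth, skip safetensors.
        if has_pth and has_safetensors:
            return "*safetensors*"
        # Only one format present: download everything.
        return None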

File tree: 3 files changed, +61 -4 lines changed


torchchat/cli/convert_hf_checkpoint.py

Lines changed: 32 additions & 2 deletions
@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import glob
 import json
 import os
 import re
@@ -41,7 +42,12 @@ def convert_hf_checkpoint(
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    model_map_json = model_dir / "pytorch_model.bin.index.json"
+    model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
+    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
+    if len(model_map_json_matches):
+        model_map_json = model_map_json_matches[0]
+    else:
+        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
@@ -96,9 +102,33 @@ def permute(w, n_heads):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(
+
+        # The state_dict can be loaded from either a torch zip file or
+        # safetensors. We take our best guess from the name and try all
+        # possibilities
+        load_pt_mmap = lambda: torch.load(
             str(file), map_location="cpu", mmap=True, weights_only=True
         )
+        load_pt_no_mmap = lambda: torch.load(
+            str(file), map_location="cpu", mmap=False, weights_only=True
+        )
+        def load_safetensors():
+            import safetensors.torch
+            with open(file, "rb") as handle:
+                return safetensors.torch.load(handle.read())
+        if "safetensors" in str(file):
+            loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
+        else:
+            loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]
+
+        state_dict = None
+        for loader in loaders:
+            try:
+                state_dict = loader()
+                break
+            except Exception:
+                continue
+        assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
     final_result = {}
     for key, value in merged_result.items():
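As an aside, a minimal round-trip sketch (separate from this diff) of the byte-level safetensors API that load_safetensors above relies on; the tensor name and values are invented:

    import torch
    import safetensors.torch

    # Serialize a dict of tensors to raw safetensors bytes, then parse it back,
    # mirroring what load_safetensors does with the bytes it reads from disk.
    tensors = {"weight": torch.ones(2, 2)}
    blob = safetensors.torch.save(tensors)
    restored = safetensors.torch.load(blob)
    assert torch.equal(restored["weight"], tensors["weight"])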

torchchat/cli/download.py

Lines changed: 28 additions & 2 deletions
@@ -22,18 +22,44 @@
 def _download_hf_snapshot(
     model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
 ):
-    from huggingface_hub import snapshot_download
+    from huggingface_hub import model_info, snapshot_download
     from requests.exceptions import HTTPError
 
     # Download and store the HF model artifacts.
     print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
     try:
+        # Fetch the info about the model's repo
+        model_info = model_info(model_config.distribution_path, token=hf_token)
+        model_fnames = [f.rfilename for f in model_info.siblings]
+
+        # Check the model config for preference between safetensors and pth
+        has_pth = any(f.endswith(".pth") for f in model_fnames)
+        has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
+
+        # If told to prefer safetensors, ignore pth files
+        if model_config.prefer_safetensors:
+            if not has_safetensors:
+                print(
+                    f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",
+                    file=sys.stderr,
+                )
+                exit(1)
+            ignore_patterns = "*.pth"
+
+        # If the model has both, prefer pth files over safetensors
+        elif has_pth and has_safetensors:
+            ignore_patterns = "*safetensors*"
+
+        # Otherwise, download everything
+        else:
+            ignore_patterns = None
+
         snapshot_download(
             model_config.distribution_path,
             local_dir=artifact_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
-            ignore_patterns="*safetensors*",
+            ignore_patterns=ignore_patterns,
         )
     except HTTPError as e:
         if e.response.status_code == 401:  # Missing HuggingFace CLI login.
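For reference, a minimal sketch (outside this commit) of the huggingface_hub.model_info call the new code uses to list a repo's files; the repo id here is only an example:

    from huggingface_hub import model_info

    # Fetch repo metadata and check which weight formats the repo ships.
    info = model_info("ibm-granite/granite-3b-code-base")
    fnames = [f.rfilename for f in info.siblings]
    print(any(f.endswith(".safetensors") for f in fnames))
    print(any(f.endswith(".pth") for f in fnames))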

torchchat/model_config/model_config.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ class ModelConfig:
     checkpoint_file: str = field(default="model.pth")
     tokenizer_file: str = field(default="tokenizer.model")
     transformer_params_key: str = field(default=None)
+    prefer_safetensors: bool = field(default=False)
 
 
 # Keys are stored in lowercase.
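For reference, a standalone sketch (not the torchchat class) of how the new prefer_safetensors dataclass field behaves with its False default:

    from dataclasses import dataclass, field

    @dataclass
    class ConfigSketch:  # illustrative stand-in for ModelConfig
        checkpoint_file: str = field(default="model.pth")
        prefer_safetensors: bool = field(default=False)

    # Default keeps the pth-first behavior; a model config can opt in to safetensors.
    assert ConfigSketch().prefer_safetensors is False
    assert ConfigSketch(prefer_safetensors=True).prefer_safetensors is True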
