Skip to content

Commit 625c12c

Browse files
committed
feat(safetensors): Use model_info to determine whether to download pth or safetensors
The logic here will prefer pth over safetensors unless the model's config explicitly states a preference for safetensors over pth. If only one of the two is found, the download will use whichever is present. Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 97b4be6 commit 625c12c

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

torchchat/cli/download.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,44 @@
2222
def _download_hf_snapshot(
2323
model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
2424
):
25-
from huggingface_hub import snapshot_download
25+
from huggingface_hub import model_info, snapshot_download
2626
from requests.exceptions import HTTPError
2727

2828
# Download and store the HF model artifacts.
2929
print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
3030
try:
31+
# Fetch the info about the model's repo
32+
model_info = model_info(model_config.distribution_path, token=hf_token)
33+
model_fnames = [f.rfilename for f in model_info.siblings]
34+
35+
# Check which checkpoint formats (pth and/or safetensors) the repo actually contains
36+
has_pth = any(f.endswith(".pth") for f in model_fnames)
37+
has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
38+
39+
# If told to prefer safetensors, ignore pth files
40+
if model_config.prefer_safetensors:
41+
if not has_safetensors:
42+
print(
43+
f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",
44+
file=sys.stderr,
45+
)
46+
exit(1)
47+
ignore_patterns = "*.pth"
48+
49+
# If the model has both, prefer pth files over safetensors
50+
elif has_pth and has_safetensors:
51+
ignore_patterns = "*safetensors*"
52+
53+
# Otherwise, download everything
54+
else:
55+
ignore_patterns = None
56+
3157
snapshot_download(
3258
model_config.distribution_path,
3359
local_dir=artifact_dir,
3460
local_dir_use_symlinks=False,
3561
token=hf_token,
36-
ignore_patterns="*safetensors*",
62+
ignore_patterns=ignore_patterns,
3763
)
3864
except HTTPError as e:
3965
if e.response.status_code == 401: # Missing HuggingFace CLI login.

torchchat/model_config/model_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class ModelConfig:
4646
checkpoint_file: str = field(default="model.pth")
4747
tokenizer_file: str = field(default="tokenizer.model")
4848
transformer_params_key: str = field(default=None)
49+
prefer_safetensors: bool = field(default=False)
4950

5051

5152
# Keys are stored in lowercase.

0 commit comments

Comments
 (0)