
Commit 57dee04

gabe-l-hart and leseb authored
Download fix (#1366)
* fix: allow multiple weight mapping files for mistral

  Downloading a Mistral model fails because it includes multiple weight
  mapping files. The regression was introduced in commit
  `766bee9f4a1fcb187fae543a525495d3ff482097`. I'm unclear on the original
  intent, but perhaps the exception was meant to apply only to Granite
  models. This isn't an ideal fix, but it does enable Mistral to be
  downloaded and used for chat.

  Signed-off-by: Sébastien Han <[email protected]>

* fix(download): Fix safetensors/bin/pth download logic

  The previous logic didn't handle .bin files, so if a model (like
  Mistral) has both .bin and .safetensors files, it would download both.

  Branch: download-fix

  Signed-off-by: Gabe Goodhart <[email protected]>

* fix(convert hf): Better logic to handle multiple weight mapping files

  This will not actually be needed for Mistral with the fix in download
  that handles .bin files, but it may be needed for other models, so it's
  worth having.

  Branch: download-fix

  Signed-off-by: Gabe Goodhart <[email protected]>

---------

Signed-off-by: Sébastien Han <[email protected]>
Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Sébastien Han <[email protected]>
1 parent a6a6e61 commit 57dee04
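
For context, the "weight mapping files" at issue are Hugging Face sharded-checkpoint indexes: each *.index.json carries a "weight_map" object that maps tensor names to the shard file holding them. A repo that publishes both weight formats ships one index per format (e.g. pytorch_model.bin.index.json and model.safetensors.index.json), which is how a single model directory can legitimately contain multiple index files. A trimmed, illustrative example (tensor names, shard names, and the size are made up, not taken from a real Mistral checkpoint):

{
  "metadata": {"total_size": 14483464192},
  "weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
    "lm_head.weight": "model-00003-of-00003.safetensors"
  }
}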

File tree

  torchchat/cli/convert_hf_checkpoint.py
  torchchat/cli/download.py

2 files changed: +29 -15 lines changed

torchchat/cli/convert_hf_checkpoint.py

Lines changed: 24 additions & 11 deletions
@@ -39,19 +39,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
     # Otherwise raise an error.
-    if not model_map_json.is_file():
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -68,11 +63,30 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
             )
 
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    #   1. The files could be mapped to different weight format files (e.g. .bin
+    #      vs .safetensors)
+    #   2. The files could be split subsets of the mappings that need to be
+    #      merged
+    #
+    # In either case, we can simply keep the mappings where the target file is
+    # valid in the model dir.
+    bin_index = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+            valid_mappings = {
+                k: model_dir / v
+                for (k, v) in weight_map.get("weight_map", {}).items()
+                if (model_dir / v).is_file()
+            }
+        bin_index.update(valid_mappings)
+    bin_files = set(bin_index.values())
 
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -96,7 +110,6 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
         return (
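
The hunk above can be read as a standalone routine. A minimal sketch of the same merge-and-filter approach, assuming only that each index file follows the "weight_map" format shown earlier (the checkpoint path in the usage comment is a placeholder):

import json
from pathlib import Path

def load_weight_index(model_dir: Path) -> dict:
    # Merge every *.index.json in model_dir, keeping only entries whose
    # target shard file actually exists on disk. If the repo shipped
    # indexes for both .bin and .safetensors but only one format was
    # downloaded, the other index's entries are silently dropped here.
    merged = {}
    for index_file in model_dir.glob("*.index.json"):
        with open(index_file, "r") as handle:
            index = json.load(handle)
        merged.update(
            {
                name: model_dir / shard
                for name, shard in index.get("weight_map", {}).items()
                if (model_dir / shard).is_file()
            }
        )
    return merged

# Usage (placeholder path):
# bin_index = load_weight_index(Path("checkpoints/mistralai/Mistral-7B-v0.1"))
# bin_files = set(bin_index.values())

Filtering on .is_file() is what collapses the two multi-index cases from the NOTE into one code path: entries that point at shards of a format that was never downloaded simply drop out of the merged map.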

torchchat/cli/download.py

Lines changed: 5 additions & 4 deletions
@@ -35,22 +35,23 @@ def _download_hf_snapshot(
     model_info = model_info(model_config.distribution_path, token=hf_token)
     model_fnames = [f.rfilename for f in model_info.siblings]
 
-    # Check the model config for preference between safetensors and pth
+    # Check the model config for preference between safetensors and pth/bin
     has_pth = any(f.endswith(".pth") for f in model_fnames)
+    has_bin = any(f.endswith(".bin") for f in model_fnames)
     has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
 
-    # If told to prefer safetensors, ignore pth files
+    # If told to prefer safetensors, ignore pth/bin files
     if model_config.prefer_safetensors:
         if not has_safetensors:
             print(
                 f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",
                 file=sys.stderr,
             )
             exit(1)
-        ignore_patterns = "*.pth"
+        ignore_patterns = ["*.pth", "*.bin"]
 
     # If the model has both, prefer pth files over safetensors
-    elif has_pth and has_safetensors:
+    elif (has_pth or has_bin) and has_safetensors:
         ignore_patterns = "*safetensors*"
 
     # Otherwise, download everything
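
The call site that consumes ignore_patterns is not part of this diff, but huggingface_hub's snapshot_download accepts either a single glob string or a list of patterns, which is why one branch can assign "*safetensors*" and another ["*.pth", "*.bin"]. A minimal sketch of how the value would be used (the repo id is illustrative, and the actual torchchat call passes more arguments):

from huggingface_hub import snapshot_download

# Skip both PyTorch weight formats when safetensors are preferred;
# snapshot_download accepts a single pattern or a list of patterns.
snapshot_download(
    "mistralai/Mistral-7B-Instruct-v0.2",
    ignore_patterns=["*.pth", "*.bin"],
)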
