Skip to content

Commit a74ddc9

Browse files
committed
fix(convert hf): Better logic to handle multiple weight mapping files
This will not actually be needed for Mistral once the download fix handles .bin files, but it may be needed for other models, so it is worth having. Branch: download-fix Signed-off-by: Gabe Goodhart <[email protected]>
1 parent bf5140a commit a74ddc9

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

torchchat/cli/convert_hf_checkpoint.py

Lines changed: 23 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -39,20 +39,14 @@ def convert_hf_checkpoint(
3939
config = TransformerArgs.from_params(config_args)
4040
print(f"Model config {config.__dict__}")
4141

42-
# Load the json file containing weight mapping
42+
# Find all candidate weight mapping index files
4343
model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
44-
if "mistral" not in model_name:
45-
assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
46-
if len(model_map_json_matches):
47-
model_map_json = model_map_json_matches[0]
48-
else:
49-
model_map_json = model_dir / "pytorch_model.bin.index.json"
5044

5145
# If there is no weight mapping, check for a consolidated model and
5246
# tokenizer we can move. Llama 2 and Mistral have weight mappings, while
5347
# Llama 3 has a consolidated model and tokenizer.
5448
# Otherwise raise an error.
55-
if not model_map_json.is_file():
49+
if not model_map_json_matches:
5650
consolidated_pth = model_dir / "original" / "consolidated.00.pth"
5751
tokenizer_pth = model_dir / "original" / "tokenizer.model"
5852
if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -69,11 +63,29 @@ def convert_hf_checkpoint(
6963
return
7064
else:
7165
raise RuntimeError(
72-
f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
66+
f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
7367
)
7468

75-
with open(model_map_json) as json_map:
76-
bin_index = json.load(json_map)
69+
# Load the json file(s) containing weight mapping
70+
#
71+
# NOTE: If there are multiple index files, there are two possibilities:
72+
# 1. The files could be mapped to different weight format files (e.g. .bin
73+
# vs .safetensors)
74+
# 2. The files could be split subsets of the mappings that need to be
75+
# merged
76+
#
77+
# In either case, we can simply keep the mappings where the target file is
78+
# valid in the model dir.
79+
bin_files = {}
80+
for weight_map_file in model_map_json_matches:
81+
with open(weight_map_file, "r") as handle:
82+
weight_map = json.load(handle)
83+
valid_mappings = {
84+
k: model_dir / v
85+
for (k, v) in weight_map.get("weight_map", {}).items()
86+
if (model_dir / v).is_file()
87+
}
88+
bin_files.update(valid_mappings)
7789

7890
weight_map = {
7991
"model.embed_tokens.weight": "tok_embeddings.weight",
@@ -97,7 +109,6 @@ def convert_hf_checkpoint(
97109
"model.norm.weight": "norm.weight",
98110
"lm_head.weight": "output.weight",
99111
}
100-
bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
101112

102113
def permute(w, n_heads):
103114
return (

0 commit comments

Comments (0)