@@ -39,20 +39,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    if "mistral" not in model_name:
-        assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
     # Otherwise raise an error.
-    if not model_map_json.is_file():
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -69,11 +63,29 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
             )
 
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    #   1. The files could be mapped to different weight format files (e.g. .bin
+    #      vs .safetensors)
+    #   2. The files could be split subsets of the mappings that need to be
+    #      merged
+    #
+    # In either case, we can simply keep the mappings where the target file is
+    # valid in the model dir.
+    bin_files = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+        valid_mappings = {
+            k: model_dir / v
+            for (k, v) in weight_map.get("weight_map", {}).items()
+            if (model_dir / v).is_file()
+        }
+        bin_files.update(valid_mappings)
 
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -97,7 +109,6 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
         return (
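For illustration only (not part of the diff): a minimal, runnable sketch of how the new merging loop behaves when a model directory contains both a stale `.bin` index and a valid `.safetensors` index. All file and tensor names below are hypothetical, and the loop is reproduced standalone under that assumption.

```python
# Minimal sketch (not from this PR): the merge keeps only mappings whose
# target weight file actually exists in model_dir. All paths are hypothetical.
import json
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp)

    # Stale .bin index: its shard is never written, so its entries are dropped.
    (model_dir / "pytorch_model.bin.index.json").write_text(json.dumps(
        {"weight_map": {"model.embed_tokens.weight": "pytorch_model-00001.bin"}}
    ))

    # Valid .safetensors index: its shard exists, so its entries are kept.
    (model_dir / "model.safetensors.index.json").write_text(json.dumps(
        {"weight_map": {"model.embed_tokens.weight": "model-00001.safetensors"}}
    ))
    (model_dir / "model-00001.safetensors").touch()

    # Same logic as the new loop above: union of every *.index.json, filtered
    # to targets that are real files on disk.
    bin_files = {}
    for weight_map_file in model_dir.glob("*.index.json"):
        weight_map = json.loads(weight_map_file.read_text())
        bin_files.update({
            k: model_dir / v
            for k, v in weight_map.get("weight_map", {}).items()
            if (model_dir / v).is_file()
        })

    print(bin_files)
    # -> {'model.embed_tokens.weight': PosixPath('.../model-00001.safetensors')}
```

Because the filter is on file existence rather than on extension, the same loop also covers the split-index case: several index files that each map a disjoint subset of tensors to shards all present on disk simply get merged into one mapping.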