@@ -100,11 +100,10 @@ def convert_hf_checkpoint(
100
100
bin_files = {model_dir / bin for bin in bin_index ["weight_map" ].values ()}
101
101
102
102
def permute(w, n_heads):
    """Reorder the leading dimension of a Q/K projection tensor, head by head.

    The first dimension of ``w`` is viewed as ``(n_heads, 2, head_dim // 2)``,
    the two inner axes are swapped, and the result is reshaped back to the
    original shape of ``w``. Any trailing dimensions are carried through
    unchanged, so the same helper serves both 2-D ``weight`` tensors and
    1-D ``bias`` tensors.

    Args:
        w: Tensor to permute. ``w.shape[0]`` must equal
            ``n_heads * 2 * (config.head_dim // 2)`` or ``view`` will raise.
        n_heads: Number of heads packed along the first dimension of ``w``.

    Returns:
        A tensor with the same shape as ``w`` and its rows reordered.

    Note:
        ``config`` is not a parameter; it is read from the enclosing scope.
    """
    return (
        # Trailing dims come from w itself (not config.dim) so that 1-D
        # biases and 2-D weights go through the same code path.
        w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
        .transpose(1, 2)
        # transpose leaves a non-contiguous view; reshape copies as needed
        # and restores the caller's original shape.
        .reshape(w.shape)
    )
109
108
110
109
merged_result = {}
@@ -137,6 +136,7 @@ def load_safetensors():
137
136
continue
138
137
assert state_dict is not None , f"Unable to load tensors from { file } "
139
138
merged_result .update (state_dict )
139
+
140
140
final_result = {}
141
141
for key , value in merged_result .items ():
142
142
if "layers" in key :
@@ -150,16 +150,20 @@ def load_safetensors():
150
150
final_result [new_key ] = value
151
151
152
152
for key in tuple (final_result .keys ()):
153
- if "wq.weight" in key :
153
+ if "wq.weight" in key or "wq.bias" in key :
154
+ wk_key = key .replace ("wq" , "wk" )
155
+ wv_key = key .replace ("wq" , "wv" )
154
156
q = final_result [key ]
155
- k = final_result [key .replace ("wq" , "wk" )]
156
- v = final_result [key .replace ("wq" , "wv" )]
157
+ k = final_result [wk_key ]
158
+ v = final_result [wv_key ]
159
+ print (key )
157
160
q = permute (q , config .n_heads )
161
+ print (wk_key )
158
162
k = permute (k , config .n_local_heads )
159
163
final_result [key .replace ("wq" , "wqkv" )] = torch .cat ([q , k , v ])
160
164
del final_result [key ]
161
- del final_result [key . replace ( "wq" , "wk" ) ]
162
- del final_result [key . replace ( "wq" , "wv" ) ]
165
+ del final_result [wk_key ]
166
+ del final_result [wv_key ]
163
167
print (f"Saving checkpoint to { model_dir / 'model.pth' } . This may take a while." )
164
168
torch .save (final_result , model_dir / "model.pth" )
165
169
print ("Done." )
0 commit comments