Skip to content

Commit 12b7d16

Browse files
committed
fix(convert): Add support for permuted kvq bias weights in HF conversion
Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 88705cb commit 12b7d16

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

torchchat/cli/convert_hf_checkpoint.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -100,11 +100,10 @@ def convert_hf_checkpoint(
100100
bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
101101

102102
def permute(w, n_heads):
103-
dim = config.dim
104103
return (
105-
w.view(n_heads, 2, config.head_dim // 2, dim)
104+
w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
106105
.transpose(1, 2)
107-
.reshape(config.head_dim * n_heads, dim)
106+
.reshape(w.shape)
108107
)
109108

110109
merged_result = {}
@@ -137,6 +136,7 @@ def load_safetensors():
137136
continue
138137
assert state_dict is not None, f"Unable to load tensors from {file}"
139138
merged_result.update(state_dict)
139+
140140
final_result = {}
141141
for key, value in merged_result.items():
142142
if "layers" in key:
@@ -150,16 +150,20 @@ def load_safetensors():
150150
final_result[new_key] = value
151151

152152
for key in tuple(final_result.keys()):
153-
if "wq.weight" in key:
153+
if "wq.weight" in key or "wq.bias" in key:
154+
wk_key = key.replace("wq", "wk")
155+
wv_key = key.replace("wq", "wv")
154156
q = final_result[key]
155-
k = final_result[key.replace("wq", "wk")]
156-
v = final_result[key.replace("wq", "wv")]
157+
k = final_result[wk_key]
158+
v = final_result[wv_key]
159+
print(key)
157160
q = permute(q, config.n_heads)
161+
print(wk_key)
158162
k = permute(k, config.n_local_heads)
159163
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
160164
del final_result[key]
161-
del final_result[key.replace("wq", "wk")]
162-
del final_result[key.replace("wq", "wv")]
165+
del final_result[wk_key]
166+
del final_result[wv_key]
163167
print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
164168
torch.save(final_result, model_dir / "model.pth")
165169
print("Done.")

0 commit comments

Comments
 (0)