Commit 741d7ba

guangy10 authored and malfet committed
Fix loading checkpoints with fused wqkv weights (#158)
1 parent 7429578 commit 741d7ba

File tree

model.py
scripts/convert_hf_checkpoint.py
scripts/test_flow.sh

3 files changed: +36 -17 lines changed

model.py

Lines changed: 31 additions & 12 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional

 import torch
 import torch.nn as nn
@@ -35,7 +35,7 @@ class ModelArgs:
     norm_eps: float = 1e-5
     multiple_of = 256
     ffn_dim_multiplier = None
-
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_heads
@@ -56,7 +56,7 @@ def from_params(cls, params_path):
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         return cls(**params)
-
+
     @classmethod
     def from_name(cls, name: str):
         print(f"name {name}")
@@ -221,7 +221,7 @@ def from_name(cls, name: str):
     @classmethod
     def from_params(cls, params_path: str):
         return cls(ModelArgs.from_params(params_path))
-
+

 class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
@@ -258,14 +258,33 @@ def __init__(self, config: ModelArgs):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
-        # self._register_load_state_dict_pre_hook(self.load_hook)
-
-    # def load_hook(self, state_dict, prefix, *args):
-    #     if prefix + "wq.weight" in state_dict:
-    #         wq = state_dict.pop(prefix + "wq.weight")
-    #         wk = state_dict.pop(prefix + "wk.weight")
-    #         wv = state_dict.pop(prefix + "wv.weight")
-    #         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        # if prefix + "wq.weight" in state_dict:
+        #     wq = state_dict.pop(prefix + "wq.weight")
+        #     wk = state_dict.pop(prefix + "wk.weight")
+        #     wv = state_dict.pop(prefix + "wv.weight")
+        #     state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+        def _unfuse_wqkv_state_dict(
+            state_dict: Dict[str, torch.Tensor],
+            dim: int,
+        ):
+            for key in list(state_dict):
+                if key.endswith("wqkv.weight"):
+                    tensor = state_dict[key]
+                    wq_key = key.replace("wqkv.weight", "wq.weight")
+                    state_dict[wq_key] = tensor[: dim]
+                    wk_key = key.replace("wqkv.weight", "wk.weight")
+                    wv_key = key.replace("wqkv.weight", "wv.weight")
+                    wk, wv = tensor[dim :].chunk(2, 0)
+                    state_dict[wk_key] = wk
+                    state_dict[wv_key] = wv
+                    state_dict.pop(key)
+                else:
+                    continue
+        _unfuse_wqkv_state_dict(state_dict, self.dim)

     def forward(
         self,
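
The hook relies on the layout produced by scripts/convert_hf_checkpoint.py: wqkv.weight is the row-wise concatenation of wq, wk and wv, so the first dim rows are the query projection and the remaining rows split evenly between key and value. Below is a minimal, self-contained sketch of that slicing, using made-up Llama-2-7B-style sizes (dim=4096, 32 heads of 128) rather than a real checkpoint; all names here are illustrative only.

import torch

# Illustrative sizes only (Llama-2-7B-like); no real checkpoint is read.
dim, n_heads, n_local_heads, head_dim = 4096, 32, 32, 128

# Fused projection as written by convert_hf_checkpoint.py: rows stacked [wq; wk; wv].
wq = torch.randn(n_heads * head_dim, dim)        # n_heads * head_dim == dim here
wk = torch.randn(n_local_heads * head_dim, dim)
wv = torch.randn(n_local_heads * head_dim, dim)
wqkv = torch.cat([wq, wk, wv])

# The same slicing load_hook performs: q takes the first `dim` rows,
# k and v split the remaining rows into two equal chunks.
q = wqkv[:dim]
k, v = wqkv[dim:].chunk(2, 0)

assert torch.equal(q, wq) and torch.equal(k, wk) and torch.equal(v, wv)

Because the rewrite happens in a load_state_dict pre-hook, checkpoints that were already converted with fused wqkv weights keep loading without being re-converted.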

scripts/convert_hf_checkpoint.py

Lines changed: 4 additions & 4 deletions
@@ -55,12 +55,12 @@ def convert_hf_checkpoint(
     }
     bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}

-    def permute(w, n_head):
+    def permute(w, n_heads):
         dim = config.dim
         return (
-            w.view(n_head, 2, config.head_dim // 2, dim)
+            w.view(n_heads, 2, config.head_dim // 2, dim)
             .transpose(1, 2)
-            .reshape(config.head_dim * n_head, dim)
+            .reshape(config.head_dim * n_heads, dim)
         )

     merged_result = {}
@@ -86,7 +86,7 @@ def permute(w, n_head):
             q = final_result[key]
             k = final_result[key.replace("wq", "wk")]
             v = final_result[key.replace("wq", "wv")]
-            q = permute(q, config.n_head)
+            q = permute(q, config.n_heads)
             k = permute(k, config.n_local_heads)
             final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
             del final_result[key]
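
For context, permute only reorders rows within each attention head so that the rotary-embedding halves stored in the Hugging Face checkpoint line up with the layout model.py expects; the overall shape is unchanged. A toy-sized sketch with made-up dimensions (not the real 7B shapes), passing head_dim and dim explicitly instead of reading them from config:

import torch

def permute(w, n_heads, head_dim, dim):
    # Same reshaping as in convert_hf_checkpoint.py; signature is illustrative.
    return (
        w.view(n_heads, 2, head_dim // 2, dim)
        .transpose(1, 2)
        .reshape(head_dim * n_heads, dim)
    )

# Toy shapes: 2 heads of size 4 projecting from dim=8.
n_heads, head_dim, dim = 2, 4, 8
w = torch.arange(n_heads * head_dim * dim, dtype=torch.float32).view(-1, dim)

out = permute(w, n_heads, head_dim, dim)
print(out.shape)    # torch.Size([8, 8]) -- shape preserved
print(out[:4, 0])   # tensor([ 0., 16.,  8., 24.]) -- head-0 rows reordered to 0, 2, 1, 3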

scripts/test_flow.sh

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
 rm -r checkpoints/$MODEL_REPO
 python scripts/download.py --repo-id $MODEL_REPO
 python scripts/convert_hf_checkpoint.py --checkpoint-dir checkpoints/$MODEL_REPO
-python generate.py --compile --checkpoint-path checkpoints/$MODEL_REPO/model.pth --max_new_tokens 100
+python generate.py --compile --checkpoint-path checkpoints/$MODEL_REPO/model.pth --max-new-tokens 100
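
The flow test exercises the fix indirectly: generate.py loads the converted checkpoint via load_state_dict, which is where the pre-hook registered in Attention.__init__ fires and rewrites the fused key before PyTorch matches keys against the module. A toy demonstration of that mechanism; the module below and its even three-way split are simplified stand-ins, not the repo's real Attention:

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    # Simplified stand-in for the repo's Attention: three separate projections.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        # Same mechanism as the commit: rewrite incoming keys before they are
        # matched against this module's parameters.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        key = prefix + "wqkv.weight"
        if key in state_dict:
            # Even three-way split for illustration; the real hook slices the
            # first `dim` rows for wq and chunks the remainder for wk and wv.
            wq, wk, wv = state_dict.pop(key).chunk(3, 0)
            state_dict[prefix + "wq.weight"] = wq
            state_dict[prefix + "wk.weight"] = wk
            state_dict[prefix + "wv.weight"] = wv

m = TinyAttention()
fused_checkpoint = {"wqkv.weight": torch.randn(3 * 8, 8)}
m.load_state_dict(fused_checkpoint)  # loads cleanly; the hook unfused the weight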
