Commit bbea338

fix(model): Add support for bias wqkv tensor in Attention
Branch: GraniteCodeSupport
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 12b7d16 commit bbea338

File tree

1 file changed: +10 −8 lines changed


torchchat/model.py

Lines changed: 10 additions & 8 deletions
@@ -769,14 +769,16 @@ def load_hook(self, state_dict, prefix, *args):
         # wv = state_dict.pop(prefix + "wv.weight")
         # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

-        if prefix + "wqkv.weight" in state_dict:
-            wqkv = state_dict.pop(prefix + "wqkv.weight")
-            q_size = self.n_heads * self.head_dim
-            kv_size = self.n_local_heads * self.head_dim
-            wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
-            state_dict[prefix + "wq.weight"] = wq
-            state_dict[prefix + "wk.weight"] = wk
-            state_dict[prefix + "wv.weight"] = wv
+        for tensor_suffix in ["weight", "bias"]:
+            wqkv_key = f"{prefix}wqkv.{tensor_suffix}"
+            if wqkv_key in state_dict:
+                wqkv = state_dict.pop(wqkv_key)
+                q_size = self.n_heads * self.head_dim
+                kv_size = self.n_local_heads * self.head_dim
+                wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
+                state_dict[f"{prefix}wq.{tensor_suffix}"] = wq
+                state_dict[f"{prefix}wk.{tensor_suffix}"] = wk
+                state_dict[f"{prefix}wv.{tensor_suffix}"] = wv

         return
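The change generalizes the existing fused-wqkv splitting in the load hook from the "weight" tensor alone to any tensor suffix, so checkpoints that also carry a fused wqkv bias (as Granite Code models do) get their bias split into wq/wk/wv the same way. Below is a minimal standalone sketch of that splitting logic; the prefix, the example dimensions, and the hand-built state_dict are made up for illustration, while q_size and kv_size mirror the n_heads * head_dim and n_local_heads * head_dim expressions used in the hook.

# Standalone sketch (not part of the commit) of the fused-wqkv split,
# applied to both the weight and the bias tensor of one attention layer.
import torch

n_heads, n_local_heads, head_dim, dim = 4, 2, 8, 32  # illustrative sizes
q_size = n_heads * head_dim          # rows belonging to the query projection
kv_size = n_local_heads * head_dim   # rows belonging to each of key/value

prefix = "layers.0.attention."       # hypothetical module prefix
state_dict = {
    prefix + "wqkv.weight": torch.randn(q_size + 2 * kv_size, dim),
    prefix + "wqkv.bias": torch.randn(q_size + 2 * kv_size),
}

for tensor_suffix in ["weight", "bias"]:
    wqkv_key = f"{prefix}wqkv.{tensor_suffix}"
    if wqkv_key in state_dict:
        # Pop the fused tensor and split it along dim 0 into q/k/v chunks.
        wqkv = state_dict.pop(wqkv_key)
        wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
        state_dict[f"{prefix}wq.{tensor_suffix}"] = wq
        state_dict[f"{prefix}wk.{tensor_suffix}"] = wk
        state_dict[f"{prefix}wv.{tensor_suffix}"] = wv

print(sorted(state_dict.keys()))
# -> wq/wk/wv entries for both .weight and .bias, no wqkv.* keys left

Because the loop simply skips suffixes whose key is absent, weight-only checkpoints behave exactly as before: the "bias" iteration finds no wqkv.bias entry and does nothing.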