@@ -769,14 +769,16 @@ def load_hook(self, state_dict, prefix, *args):
769
769
# wv = state_dict.pop(prefix + "wv.weight")
770
770
# state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
771
771
772
- if prefix + "wqkv.weight" in state_dict :
773
- wqkv = state_dict .pop (prefix + "wqkv.weight" )
774
- q_size = self .n_heads * self .head_dim
775
- kv_size = self .n_local_heads * self .head_dim
776
- wq , wk , wv = torch .split (wqkv , (q_size , kv_size , kv_size ), dim = 0 )
777
- state_dict [prefix + "wq.weight" ] = wq
778
- state_dict [prefix + "wk.weight" ] = wk
779
- state_dict [prefix + "wv.weight" ] = wv
772
+ for tensor_suffix in ["weight" , "bias" ]:
773
+ wqkv_key = f"{ prefix } wqkv.{ tensor_suffix } "
774
+ if wqkv_key in state_dict :
775
+ wqkv = state_dict .pop (wqkv_key )
776
+ q_size = self .n_heads * self .head_dim
777
+ kv_size = self .n_local_heads * self .head_dim
778
+ wq , wk , wv = torch .split (wqkv , (q_size , kv_size , kv_size ), dim = 0 )
779
+ state_dict [f"{ prefix } wq.{ tensor_suffix } " ] = wq
780
+ state_dict [f"{ prefix } wk.{ tensor_suffix } " ] = wk
781
+ state_dict [f"{ prefix } wv.{ tensor_suffix } " ] = wv
780
782
781
783
return
782
784
0 commit comments