@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional

 import torch
 import torch.nn as nn
@@ -35,7 +35,7 @@ class ModelArgs:
     norm_eps: float = 1e-5
     multiple_of = 256
     ffn_dim_multiplier = None
-
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_heads
@@ -56,7 +56,7 @@ def from_params(cls, params_path):
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         return cls(**params)
-
+
     @classmethod
     def from_name(cls, name: str):
         print(f"name {name}")
@@ -221,7 +221,7 @@ def from_name(cls, name: str):
     @classmethod
     def from_params(cls, params_path: str):
         return cls(ModelArgs.from_params(params_path))
-
+

 class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
@@ -258,14 +258,33 @@ def __init__(self, config: ModelArgs):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
-        # self._register_load_state_dict_pre_hook(self.load_hook)
-
-    # def load_hook(self, state_dict, prefix, *args):
-    #     if prefix + "wq.weight" in state_dict:
-    #         wq = state_dict.pop(prefix + "wq.weight")
-    #         wk = state_dict.pop(prefix + "wk.weight")
-    #         wv = state_dict.pop(prefix + "wv.weight")
-    #         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        # if prefix + "wq.weight" in state_dict:
+        #     wq = state_dict.pop(prefix + "wq.weight")
+        #     wk = state_dict.pop(prefix + "wk.weight")
+        #     wv = state_dict.pop(prefix + "wv.weight")
+        #     state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+        def _unfuse_wqkv_state_dict(
+            state_dict: Dict[str, torch.Tensor],
+            dim: int,
+        ):
+            for key in list(state_dict):
+                if key.endswith("wqkv.weight"):
+                    tensor = state_dict[key]
+                    wq_key = key.replace("wqkv.weight", "wq.weight")
+                    state_dict[wq_key] = tensor[:dim]
+                    wk_key = key.replace("wqkv.weight", "wk.weight")
+                    wv_key = key.replace("wqkv.weight", "wv.weight")
+                    wk, wv = tensor[dim:].chunk(2, 0)
+                    state_dict[wk_key] = wk
+                    state_dict[wv_key] = wv
+                    state_dict.pop(key)
+                else:
+                    continue
+        _unfuse_wqkv_state_dict(state_dict, self.dim)

     def forward(
         self,
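
For context on the hook introduced in the last hunk: _register_load_state_dict_pre_hook arranges for load_hook to run on the raw checkpoint dict before load_state_dict matches keys against the module's parameters, so a checkpoint saved with a fused wqkv.weight can be split back into separate wq/wk/wv entries in place. Below is a minimal, self-contained sketch of that mechanism, not part of the commit; TinyAttention and its dimensions are invented for illustration, and the fused row layout (wq rows first, then wk and wv in equal halves) is the one the diff's tensor[:dim] / tensor[dim:].chunk(2, 0) split assumes.

import torch
import torch.nn as nn


class TinyAttention(nn.Module):
    """Hypothetical module illustrating the unfuse-on-load pattern."""

    def __init__(self, dim: int = 8, kv_dim: int = 4):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, kv_dim, bias=False)
        self.wv = nn.Linear(dim, kv_dim, bias=False)
        # Same call as in the diff: the hook sees the state_dict before
        # key matching, so it can rewrite checkpoint keys in place.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        key = prefix + "wqkv.weight"
        if key in state_dict:
            wqkv = state_dict.pop(key)
            # Assumed fused layout: the first self.dim rows are wq, and the
            # remaining rows split into equal wk and wv halves.
            state_dict[prefix + "wq.weight"] = wqkv[: self.dim]
            wk, wv = wqkv[self.dim :].chunk(2, dim=0)
            state_dict[prefix + "wk.weight"] = wk
            state_dict[prefix + "wv.weight"] = wv


attn = TinyAttention()
fused = {"wqkv.weight": torch.randn(8 + 2 * 4, 8)}
attn.load_state_dict(fused)  # strict load passes: the hook unfused the key

Mutating the dict in place (popping the fused key and inserting the three unfused ones) is what lets a strict load succeed, since key matching happens only after every registered pre-hook has run.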