@@ -24,23 +24,39 @@ class ModelArgs:
     block_size: int = 2048
     vocab_size: int = 32000
     n_layer: int = 32
-    n_head: int = 32
+    # n_head in gpt-fast
+    n_heads: int = 32
     dim: int = 4096
-    intermediate_size: int = None
+    # hidden_dim is intermediate_size in gpt-fast
+    hidden_dim: int = None
     n_local_heads: int = -1
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
-
+    multiple_of = 256
+    ffn_dim_multiplier = None
+
     def __post_init__(self):
         if self.n_local_heads == -1:
-            self.n_local_heads = self.n_head
-        if self.intermediate_size is None:
+            self.n_local_heads = self.n_heads
+        if self.hidden_dim is None:
+            # If hidden_dim is not explicitly set in the ModelArgs,
+            # then calculate it implicitly based on dim, rounded up
+            # to a multiple of `args.multiple_of`.
+            multiple_of = self.multiple_of
             hidden_dim = 4 * self.dim
-            n_hidden = int(2 * hidden_dim / 3)
-            self.intermediate_size = find_multiple(n_hidden, 256)
-        self.head_dim = self.dim // self.n_head
+            hidden_dim = int(2 * hidden_dim / 3)
+            if self.ffn_dim_multiplier is not None:
+                hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
+            self.hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        self.head_dim = self.dim // self.n_heads

+    @classmethod
+    def from_params(cls, params_path):
+        with open(params_path, "r") as f:
+            params = json.loads(f.read())
+        return cls(**params)
+
     @classmethod
     def from_name(cls, name: str):
         print(f"name {name}")
@@ -70,47 +86,47 @@ def from_name(cls, name: str):
     "CodeLlama-7b-Python-hf": dict(
         block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
     ),
-    "7B": dict(n_layer=32, n_head=32, dim=4096),
-    "13B": dict(n_layer=40, n_head=40, dim=5120),
-    "30B": dict(n_layer=60, n_head=52, dim=6656),
+    "7B": dict(n_layer=32, n_heads=32, dim=4096),
+    "13B": dict(n_layer=40, n_heads=40, dim=5120),
+    "30B": dict(n_layer=60, n_heads=52, dim=6656),
     "34B": dict(
         n_layer=48,
-        n_head=64,
+        n_heads=64,
         dim=8192,
         vocab_size=32000,
         n_local_heads=8,
-        intermediate_size=22016,
+        hidden_dim=22016,
         rope_base=1000000,
     ),  # CodeLlama-34B-Python-hf
     "70B": dict(
-        n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672
+        n_layer=80, n_heads=64, dim=8192, n_local_heads=8, hidden_dim=28672
     ),
     "Mistral-7B": dict(
         n_layer=32,
-        n_head=32,
+        n_heads=32,
         n_local_heads=8,
         dim=4096,
-        intermediate_size=14336,
+        hidden_dim=14336,
         vocab_size=32000,
     ),
     "Mistral-7B-Instruct-v0.1": dict(
         n_layer=32,
-        n_head=32,
+        n_heads=32,
         n_local_heads=8,
         dim=4096,
-        intermediate_size=14336,
+        hidden_dim=14336,
         vocab_size=32000,
     ),
     "Mistral-7B-Instruct-v0.2": dict(
         n_layer=32,
-        n_head=32,
+        n_heads=32,
         n_local_heads=8,
         dim=4096,
-        intermediate_size=14336,
+        hidden_dim=14336,
         vocab_size=32000,
     ),
-    "stories15M": dict(n_layer=6, n_head=6, dim=288),
-    "stories110M": dict(n_layer=12, n_head=12, dim=768),
+    "stories15M": dict(n_layer=6, n_heads=6, dim=288),
+    "stories110M": dict(n_layer=12, n_heads=12, dim=768),
 }
@@ -160,7 +176,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             and self.max_batch_size >= max_batch_size
         ):
             return
-        head_dim = self.config.dim // self.config.n_head
+        head_dim = self.config.dim // self.config.n_heads
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
@@ -170,8 +186,8 @@ def setup_caches(self, max_batch_size, max_seq_length):
             )

         freqs_cis = precompute_freqs_cis(
-            self.config.block_size,
-            self.config.dim // self.config.n_head,
+            self.config.dim // self.config.n_heads,
+            self.config.block_size * 2,
             self.config.rope_base,
         )
         self.register_buffer("freqs_cis", freqs_cis, persistent=True)
@@ -202,6 +218,10 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
     def from_name(cls, name: str):
         return cls(ModelArgs.from_name(name))

+    @classmethod
+    def from_params(cls, params_path: str):
+        return cls(ModelArgs.from_params(params_path))
+

 class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
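A hypothetical usage sketch of the new constructor (the path and JSON contents here are assumptions, not from the commit; the keys must match the ModelArgs field names, since from_params simply unpacks the parsed file into the dataclass):

# params.json (hypothetical contents):
# {"dim": 4096, "n_heads": 32, "n_layer": 32, "vocab_size": 32000, "norm_eps": 1e-05}
model = Transformer.from_params("path/to/params.json")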
@@ -222,19 +242,19 @@ def forward(
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
-        assert config.dim % config.n_head == 0
+        assert config.dim % config.n_heads == 0

         # key, query, value projections for all heads, but in a batch
-        # total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+        # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
         # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
         self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
         self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)

         self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.kv_cache = None

-        self.n_head = config.n_head
+        self.n_heads = config.n_heads
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
@@ -263,7 +283,7 @@ def forward(
         # kv_size = self.n_local_heads * self.head_dim
         # q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

-        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
@@ -275,8 +295,8 @@ def forward(
         if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
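A quick shape sketch (illustrative values, not from the commit) of the grouped-query expansion above, assuming k has already been transposed so that heads sit on dim 1, as the dim=1 argument implies:

import torch

bsz, seqlen, n_heads, n_local_heads, head_dim = 1, 16, 32, 8, 128
k = torch.randn(bsz, n_local_heads, seqlen, head_dim)     # 8 KV heads
k = k.repeat_interleave(n_heads // n_local_heads, dim=1)  # repeat each KV head 4x
print(k.shape)  # torch.Size([1, 32, 16, 128]) -- now matches the 32 query heads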
@@ -288,9 +308,9 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
-        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)

     def forward(self, x: Tensor) -> Tensor:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
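A self-contained shape check of the SwiGLU feed-forward above (toy sizes chosen for illustration; the 7B config uses dim=4096 and hidden_dim=11008):

import torch
import torch.nn.functional as F

dim, hidden_dim = 8, 32
w1 = torch.nn.Linear(dim, hidden_dim, bias=False)  # gate projection
w3 = torch.nn.Linear(dim, hidden_dim, bias=False)  # up projection
w2 = torch.nn.Linear(hidden_dim, dim, bias=False)  # down projection

x = torch.randn(2, dim)
y = w2(F.silu(w1(x)) * w3(x))
print(y.shape)  # torch.Size([2, 8])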
@@ -309,8 +329,8 @@ def forward(self, x: Tensor) -> Tensor:
         output = self._norm(x.float()).type_as(x)
         return output * self.weight

-
-def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+# transposed first two arguments to align with model in ET
+def precompute_freqs_cis(n_elem: int, seq_len: int, base: int = 10000) -> Tensor:
     freqs = 1.0 / (
         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
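For intuition (illustrative, not part of the commit), the freqs line above produces one inverse frequency per rotated pair of dimensions; for a head_dim of 128:

import torch

n_elem, base = 128, 10000
freqs = 1.0 / (
    base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
)
print(freqs.shape)       # torch.Size([64])
print(freqs[0].item())   # 1.0 for the fastest-rotating pair
print(freqs[-1].item())  # ~1.2e-4 for the slowest pair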