@@ -265,21 +265,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
+        self.n_heads = args.n_heads
+        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert self.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.dim // self.n_heads
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 128
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id
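For reference, a minimal standalone sketch (not part of the commit) of how these head counts and projection widths relate. dim = 4096 and n_heads = 32 come from the comment above; n_kv_heads = 8 is an assumed example value chosen only to illustrate grouped-query attention.

# Standalone sketch; n_kv_heads = 8 is an assumed example value.
dim, n_heads, n_kv_heads = 4096, 32, 8
assert n_heads % n_kv_heads == 0
head_dim = dim // n_heads        # 4096 // 32 = 128
n_rep = n_heads // n_kv_heads    # 4 query heads share each KV head
# Projection output widths, matching wq/wk/wv/wo above:
#   wq: dim -> n_heads * head_dim    = 4096
#   wk: dim -> n_kv_heads * head_dim = 1024
#   wv: dim -> n_kv_heads * head_dim = 1024
#   wo: n_heads * head_dim -> dim    = 4096
print(head_dim, n_rep)  # 128 4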