@@ -260,21 +260,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
+        self.n_heads = args.n_heads
+        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert self.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.dim // self.n_heads
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 128
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 128
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id

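For reference, a minimal sketch of the sizes this constructor derives, assuming the values from the in-code comment (dim = 4096, n_heads = 32) plus a hypothetical n_kv_heads = 8 for grouped-query attention; the actual ModelArgs values may differ:

# Sketch of the derived attention sizes, under the assumptions stated above
# (dim and n_heads come from the in-code comment; n_kv_heads = 8 is assumed).
dim, n_heads, n_kv_heads = 4096, 32, 8
model_parallel_size = 1

head_dim = dim // n_heads                               # 4096 // 32 = 128
n_local_heads = n_heads // model_parallel_size          # 32
n_local_kv_heads = n_kv_heads // model_parallel_size    # 8
n_rep = n_local_heads // n_local_kv_heads               # 4 query heads share each KV head

# Resulting projection shapes (in_features -> out_features):
#   wq: dim -> n_heads    * head_dim  = 4096 -> 4096
#   wk: dim -> n_kv_heads * head_dim  = 4096 -> 1024
#   wv: dim -> n_kv_heads * head_dim  = 4096 -> 1024
#   wo: n_heads * head_dim -> dim     = 4096 -> 4096
print(head_dim, n_rep)  # 128 4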