
Commit 3e99253

Author: Chun-I Tsai (committed)
Set member variable to Attention module
- Set and use self.n_heads from args.n_heads
- Use self.dim from args.dim
1 parent ad0e5e8 · commit 3e99253

File tree

1 file changed: +10 −9 lines

examples/models/llama/llama_transformer.py

Lines changed: 10 additions & 9 deletions
@@ -260,21 +260,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
+        self.n_heads = args.n_heads
+        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert self.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.dim // self.n_heads
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 128
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 128
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id

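As a sanity check on the shape arithmetic in this diff, here is a minimal, self-contained sketch of the grouped-query-attention bookkeeping that n_heads, n_kv_heads, head_dim, and n_rep encode. This is not the repo's code: ModelArgs is a hypothetical stub with illustrative values, and repeating KV heads with repeat_interleave stands in for whatever the real attention forward pass does.

    from dataclasses import dataclass
    from typing import Optional

    import torch
    import torch.nn as nn

    @dataclass
    class ModelArgs:
        # Hypothetical stub of the real ModelArgs; only the fields used below.
        dim: int = 4096
        n_heads: int = 32
        n_kv_heads: Optional[int] = 8  # fewer KV heads than query heads -> GQA

    args = ModelArgs()
    n_heads = args.n_heads
    n_kv_heads = n_heads if args.n_kv_heads is None else args.n_kv_heads
    assert n_heads % n_kv_heads == 0
    head_dim = args.dim // n_heads   # 4096 // 32 = 128
    n_rep = n_heads // n_kv_heads    # each KV head serves n_rep query heads

    wq = nn.Linear(args.dim, n_heads * head_dim, bias=False)
    wk = nn.Linear(args.dim, n_kv_heads * head_dim, bias=False)

    x = torch.randn(1, 16, args.dim)             # (batch, seq_len, dim)
    q = wq(x).view(1, 16, n_heads, head_dim)     # 32 query heads
    k = wk(x).view(1, 16, n_kv_heads, head_dim)  # 8 KV heads
    k = k.repeat_interleave(n_rep, dim=2)        # repeat KV heads to match q
    assert k.shape == q.shape                    # (1, 16, 32, 128)

Caching the values on self rather than re-reading args at each use keeps the projection shapes and the grouping invariant (n_heads % n_kv_heads == 0) anchored to one set of attributes on the module.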