Commit 904a2e5

chunit-quic and Chun-I Tsai authored

Set member variable to Attention module (#6376)

- Set and use self.n_heads from args.n_heads
- Use self.dim from args.dim

Co-authored-by: Chun-I Tsai <[email protected]>
1 parent f044e91 · commit 904a2e5

1 file changed

examples/models/llama/llama_transformer.py

Lines changed: 10 additions & 9 deletions
@@ -265,21 +265,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
+        self.n_heads = args.n_heads
+        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert self.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.dim // self.n_heads
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 128
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 128
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id
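For context beyond the diff, here is a minimal runnable sketch (not part of the commit) of the head bookkeeping that __init__ performs. The ModelArgs fields below are an assumed subset of the real class in llama_transformer.py, with defaults taken from the comment in the diff; note that 4096 // 32 is 128.

# Hedged sketch, not from the repo: it reproduces only the head/dim arithmetic.
from dataclasses import dataclass
from typing import Optional

import torch.nn as nn


@dataclass
class ModelArgs:  # assumed subset of the real ModelArgs
    dim: int = 4096
    n_heads: int = 32
    n_kv_heads: Optional[int] = None  # None = one KV head per query head (MHA)


args = ModelArgs(n_kv_heads=8)  # fewer KV heads -> grouped-query attention
n_heads = args.n_heads
n_kv_heads = n_heads if args.n_kv_heads is None else args.n_kv_heads
assert n_heads % n_kv_heads == 0
n_rep = n_heads // n_kv_heads   # 4: each KV head is shared by 4 query heads
head_dim = args.dim // n_heads  # 4096 // 32 = 128

# nn.Linear stores its weight as (out_features, in_features)
wq = nn.Linear(args.dim, n_heads * head_dim, bias=False)     # weight (4096, 4096)
wk = nn.Linear(args.dim, n_kv_heads * head_dim, bias=False)  # weight (1024, 4096)

print(n_rep, head_dim, tuple(wq.weight.shape), tuple(wk.weight.shape))
# -> 4 128 (4096, 4096) (1024, 4096)

With the commit applied, these quantities live on the module itself (self.n_heads, self.n_kv_heads, self.head_dim, self.dim), so the forward pass and external callers can read them without holding on to the original ModelArgs.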