@@ -126,8 +126,6 @@ def __init__(
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.is_causal = is_causal
-        # Number of queries per k, v
-        self.q_per_kv = self.num_heads // self.num_kv_heads

         # Set layers
         self.kv_cache = kv_cache
@@ -145,7 +143,7 @@ def __init__(
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
-            q_per_kv=self.q_per_kv,
+            q_per_kv=self.num_heads // self.num_kv_heads,
             attn_dropout=self.attn_dropout if self.training else 0.0,
             is_causal=self.is_causal,
             attention_fn=self._attention_call,
@@ -239,7 +237,10 @@ def forward(

         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
-        q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim)
+
+        # number of queries per key/value
+        q_per_kv = self.num_heads // self.num_kv_heads
+        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)

         # Apply positional embeddings
         if self.pos_embeddings is not None:
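For context, here is a minimal standalone sketch of the grouped-query reshape this diff touches. The tensor sizes below are illustrative assumptions, not values from the module; only `num_heads`, `num_kv_heads`, `head_dim`, and `q_per_kv` come from the diff itself.

```python
import torch

# Illustrative sizes (assumptions, not taken from the diff)
b, s_x = 2, 16                 # batch size, sequence length
num_heads, num_kv_heads = 8, 2
head_dim = 64

# Number of query heads sharing each key/value head,
# computed inline as in the new code
q_per_kv = num_heads // num_kv_heads  # -> 4

# q starts as [b, s_x, num_heads * head_dim], as in forward()
q = torch.randn(b, s_x, num_heads * head_dim)

# The view groups query heads by their shared KV head;
# num_kv_heads * q_per_kv == num_heads, so no data moves.
q = q.view(b, s_x, num_kv_heads * q_per_kv, head_dim)
assert q.shape == (b, s_x, num_heads, head_dim)
```

Since `q_per_kv` is fully determined by `num_heads` and `num_kv_heads`, computing it where it is used keeps the module from carrying derived state that could drift if either field were later changed.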