
Commit 8afb8e1

Match up mha with TT
1 parent 1fe0356 commit 8afb8e1

File tree

  • examples/models/llama2/source_transformation/torchtune/modules

1 file changed: 5 additions, 4 deletions
examples/models/llama2/source_transformation/torchtune/modules/mha.py

Lines changed: 5 additions & 4 deletions
@@ -126,8 +126,6 @@ def __init__(
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.is_causal = is_causal
-        # Number of queries per k, v
-        self.q_per_kv = self.num_heads // self.num_kv_heads
 
         # Set layers
         self.kv_cache = kv_cache
@@ -145,7 +143,7 @@ def __init__(
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
-            q_per_kv=self.q_per_kv,
+            q_per_kv=self.num_heads // self.num_kv_heads,
             attn_dropout=self.attn_dropout if self.training else 0.0,
             is_causal=self.is_causal,
             attention_fn=self._attention_call,
@@ -239,7 +237,10 @@ def forward(
 
         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
-        q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim)
+
+        # number of queries per key/value
+        q_per_kv = self.num_heads // self.num_kv_heads
+        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
 
         # Apply positional embeddings
         if self.pos_embeddings is not None:
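
Note on the change: the commit drops the stored self.q_per_kv attribute and instead computes the number of query heads served by each key/value head inline, both where the attention call is configured and in forward, presumably to line the module up with torchtune's MultiHeadAttention ("TT"). Below is a minimal sketch of what that inline computation and reshape do, using plain PyTorch and assumed example shapes rather than the actual mha.py module:

    # Minimal sketch (hypothetical shapes, not the real mha.py):
    # derive q_per_kv on the fly and reshape the projected queries,
    # mirroring the inline computation introduced in this commit.
    import torch

    b, s_x = 2, 16                          # batch size, sequence length (assumed)
    num_heads, num_kv_heads, head_dim = 8, 2, 64

    # q as produced by q_proj: [b, s_x, num_heads * head_dim]
    q = torch.randn(b, s_x, num_heads * head_dim)

    # Grouped-query attention: each kv head is shared by q_per_kv query heads.
    q_per_kv = num_heads // num_kv_heads    # 8 // 2 = 4
    q = q.view(b, s_x, num_kv_heads * q_per_kv, head_dim)

    print(q.shape)  # torch.Size([2, 16, 8, 64])

Since num_kv_heads * q_per_kv equals num_heads, the reshaped tensor is identical to the one produced by the removed attribute-based code; the change only removes redundant state.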

0 commit comments
