Commit 4f9ae32

[llama-mm] Reduce copies in SDPA in MHA (#6917)
Summary: As titled. Align the implementation of SDPA with the torchtune version: https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention.py#L267

Test Plan: Rely on unit tests.
Parent: 3784f06

File tree: 1 file changed, +7 −15 lines

extension/llm/modules/attention.py

Lines changed: 7 additions & 15 deletions
@@ -352,26 +352,18 @@ def forward(
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.
 
-        # k: [bsz, seq_len, n_kv, 1, h_d]
-        # v: [bsz, seq_len, n_kv, 1, h_d]
-        k = k.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
-        v = v.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
-
-        # Expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            k = k.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
-            v = v.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
-
-        # [bsz, s, n_h, h_d]
-        k = k.reshape(bsz, -1, self.num_heads, self.head_dim)
-        v = v.reshape(bsz, -1, self.num_heads, self.head_dim)
-
         # [bsz, n_h, s, h_d]
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
+        # Expand the key and value tensors to have the same shape
+        # as the query tensor by copying values across the relevant dim
+        if self.num_heads != self.num_kv_heads:
+            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
+            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+
         output = self._attention_fn(
             q,
             k,
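For context, here is a minimal standalone sketch (not part of the commit) of the grouped-query expansion the new code performs. The dimensions (bsz, n_kv, q_per_kv, s, h_d) are made-up values, not taken from the diff. It shows that unsqueeze + expand + flatten maps [bsz, n_kv, s, h_d] to [bsz, n_h, s, h_d], with expand itself being a stride-0 view that allocates nothing, and that the result matches repeat_interleave semantics:

import torch

bsz, n_kv, q_per_kv, s, h_d = 2, 4, 2, 16, 64  # hypothetical sizes
n_h = n_kv * q_per_kv

# k as it looks after transpose(1, 2) in the diff: [bsz, n_kv, s, h_d]
k = torch.randn(bsz, n_kv, s, h_d)

# unsqueeze adds a broadcast dim, expand repeats it without copying
# (stride-0 view), flatten merges (n_kv, q_per_kv) into n_h.
expand_shape = (-1, -1, q_per_kv, -1, -1)
k_expanded = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)

assert k_expanded.shape == (bsz, n_h, s, h_d)
# Each kv head is repeated q_per_kv times consecutively along dim 1,
# which is exactly repeat_interleave on that dim:
assert torch.equal(k_expanded, k.repeat_interleave(q_per_kv, dim=1))

Note that the expansion now runs only when num_heads != num_kv_heads and only on the already-transposed tensors, whereas the old path always did the view + reshape round trip for k and v.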
