
Commit 095bd2c

[executorch][llama] support mqa
Pull Request resolved: #3080

This diff adds support for multi-query attention (MQA) to the sdpa_with_kv_cache custom op.

Differential Revision: [D56228316](https://our.internmc.facebook.com/intern/diff/D56228316/)
ghstack-source-id: 222855405
1 parent 458d743 commit 095bd2c
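
For context, multi-query / grouped-query attention lets the key/value tensors carry fewer heads than the query tensor; each kv head is then shared by num_heads_q // num_heads_kv query heads. A minimal eager-PyTorch sketch of that behavior (this is not the custom op changed by this diff; sizes and variable names are illustrative, and the kv expansion mirrors the reference path used in the new test below):

import torch
import torch.nn.functional as F

# Illustrative sizes: 8 query heads share 4 kv heads, i.e. 2 query heads per kv head.
bsz, seq_len, head_dim = 1, 16, 64
n_heads_q, n_heads_kv = 8, 4

q = torch.rand(bsz, n_heads_q, seq_len, head_dim)
k = torch.rand(bsz, n_heads_kv, seq_len, head_dim)
v = torch.rand(bsz, n_heads_kv, seq_len, head_dim)

# Expand kv heads so every query head sees a matching kv head
# (repeat_interleave on the head dimension, as in the test's reference implementation).
n_reps = n_heads_q // n_heads_kv
k_expanded = k.repeat_interleave(n_reps, dim=1)
v_expanded = v.repeat_interleave(n_reps, dim=1)

out = F.scaled_dot_product_attention(q, k_expanded, v_expanded, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])

The kernel change in op_sdpa.cpp avoids this materialized expansion and instead indexes the shared kv head directly (see the sketch after that diff).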

File tree: 3 files changed (+236, -2 lines changed)

Lines changed: 14 additions & 0 deletions
@@ -1,8 +1,22 @@
 # Any targets that should be shared between fbcode and xplat must be defined in
 # targets.bzl. This file can contain fbcode-only targets.
 
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load(":targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
 define_common_targets()
+
+runtime.python_test(
+    name = "test_sdpa_with_kv_cache",
+    srcs = [
+        "test_sdpa_with_kv_cache.py",
+    ],
+    preload_deps = [
+        ":custom_ops_aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)

examples/models/llama2/custom_ops/op_sdpa.cpp

Lines changed: 19 additions & 2 deletions
@@ -219,13 +219,29 @@ void cpu_flash_attention(
   int64_t qSize = query.size(2);
   int64_t headSize = query.size(3);
   int64_t kvSize = value.size(2);
+  int64_t num_heads_kv = key.size(1);
 
   if (is_with_kv_cache) {
     num_head = query.size(2);
+    num_heads_kv = key.size(2);
     qSize = query.size(1);
     kvSize = value.size(1);
   }
 
+  ET_CHECK_MSG(
+      num_heads_kv <= num_head,
+      "FlashAttention does not support num kv heads > num query heads. Got num query heads=%" PRId64
+      " num key heads:%" PRId64,
+      num_head,
+      num_heads_kv);
+  ET_CHECK_MSG(
+      num_head % num_heads_kv == 0,
+      "FlashAttention: num query heads must be divisible by num kv heads but got num query heads=%" PRId64
+      " and num kv heads=%" PRId64,
+      num_head,
+      num_heads_kv);
+  int64_t num_reps = num_head / num_heads_kv;
+
   bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
   if (has_attn_mask) {
     /*
@@ -365,6 +381,7 @@ void cpu_flash_attention(
       fill_stub(
           qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
       int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
+      auto j_kv = j / num_reps;
       for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
         int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
         // Calculate scale * q @ k.T
@@ -376,7 +393,7 @@
             qBlockSize,
             headSize,
             static_cast<accum_t>(1),
-            k_data + i * kStrideB + j * kStrideH + n * kStrideN,
+            k_data + i * kStrideB + j_kv * kStrideH + n * kStrideN,
             kStrideN,
             q_data + i * qStrideB + j * qStrideH + m * qStrideM,
             qStrideM,
@@ -460,7 +477,7 @@
             qBlockSize,
             kvBlockSize,
             static_cast<accum_t>(1),
-            v_data + i * vStrideB + j * vStrideH + n * vStrideN,
+            v_data + i * vStrideB + j_kv * vStrideH + n * vStrideN,
             vStrideN,
             conditional_data_ptr(qk_data, qk_reduced_data),
             kvBlockSize,
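
The indexing change above is the whole of the kernel-side MQA support: the query-head loop index j is mapped to a kv-head index j_kv = j / num_reps, so the k and v pointers read from the smaller set of kv heads without copying them. A rough Python analogue of that mapping (illustrative shapes and names, not the ExecuTorch kernel itself):

import torch
import torch.nn.functional as F

# 8 query heads, 2 kv heads -> each kv head serves num_reps = 4 query heads.
num_head, num_heads_kv, seq, head_dim = 8, 2, 5, 4
num_reps = num_head // num_heads_kv

q = torch.rand(num_head, seq, head_dim)
k = torch.rand(num_heads_kv, seq, head_dim)
v = torch.rand(num_heads_kv, seq, head_dim)

outs = []
for j in range(num_head):    # loop over query heads, as the kernel does
    j_kv = j // num_reps     # shared kv head index (j / num_reps in the C++)
    outs.append(
        F.scaled_dot_product_attention(
            q[j : j + 1], k[j_kv : j_kv + 1], v[j_kv : j_kv + 1]
        )
    )
out = torch.cat(outs, dim=0)  # (num_head, seq, head_dim)

# Equivalent to expanding the kv heads up front:
out_ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(num_reps, dim=0),
    v.repeat_interleave(num_reps, dim=0),
)
assert torch.allclose(out, out_ref, atol=1e-6)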
examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+import torch.nn.functional as F
+
+
+class SDPATest(unittest.TestCase):
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.k_cache = torch.zeros((1, 5, 8, 4))
+        self.v_cache = torch.zeros((1, 5, 8, 4))
+        self.mask = torch.full(
+            (5, 5),
+            float("-inf"),
+        )
+        self.mask = torch.triu(self.mask, diagonal=1)
+
+    def _sdpa_with_kv_cache_ref(self, q, k, v, k_cache, v_cache, mask, start_pos):
+        print(f"at start_pos:{start_pos}")
+        print(q)
+        print(k)
+        print(v)
+        attn_mask = mask[start_pos].view((1, -1))
+        attn_mask = attn_mask[:, : start_pos + 1]
+        q = q.transpose(1, 2)
+        k_cache[:, start_pos] = k
+        v_cache[:, start_pos] = v
+        sliced_k_cache = k_cache[:, : start_pos + 1, :, :]
+        sliced_v_cache = v_cache[:, : start_pos + 1, :, :]
+        sliced_k_cache = sliced_k_cache.transpose(1, 2)
+        sliced_v_cache = sliced_v_cache.transpose(1, 2)
+        # print(sliced_k_cache.size())
+        # print(torch.matmul(q, sliced_k_cache.transpose(2, 3)))
+        # print("q @ k")
+        # qk = torch.matmul(q, sliced_k_cache.transpose(2, 3))
+        # qk_softmax = torch.softmax(qk, dim=-1)
+        # qkv = torch.matmul(qk_softmax, sliced_v_cache)
+        # print(qk)
+        # print(qk_softmax)
+        # print(qkv)
+        out = F.scaled_dot_product_attention(
+            q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
+        )
+        out = out.transpose(1, 2)
+        print(out)
+        print(f"-------- start pos {start_pos} done -----")
+        return out
+
+    def test_sdpa_with_cache_no_mqa_1(self):
+        q = torch.rand((1, 1, 8, 4))
+        k = torch.rand((1, 1, 8, 4))
+        v = torch.rand((1, 1, 8, 4))
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 0
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 0, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_no_mqa_2(self):
+        q = torch.rand((1, 1, 8, 4))
+        k = torch.rand((1, 1, 8, 4))
+        v = torch.rand((1, 1, 8, 4))
+
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 1
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_no_mqa_3(self):
+        q = torch.rand((1, 1, 8, 4))
+        k = torch.rand((1, 1, 8, 4))
+        v = torch.rand((1, 1, 8, 4))
+
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 2
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 2, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_no_mqa_4(self):
+        q = torch.rand((1, 1, 8, 4))
+        k = torch.rand((1, 1, 8, 4))
+        v = torch.rand((1, 1, 8, 4))
+
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 3
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 3, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+
+class SDPATestWithMQA(unittest.TestCase):
+
+    def setup_caches(self):
+        self.k_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
+        self.v_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.n_heads_kv = 4
+        self.n_heads_q = 8
+        self.setup_caches()
+        self.mask = torch.full(
+            (5, 5),
+            float("-inf"),
+        )
+        self.mask = torch.triu(self.mask, diagonal=1)
+
+    def _sdpa_with_kv_cache_ref(self, q, k, v, k_cache, v_cache, mask, start_pos):
+        print(f"at start_pos:{start_pos}")
+        print(q)
+        print(k)
+        print(v)
+        attn_mask = mask[start_pos].view((1, -1))
+        attn_mask = attn_mask[:, : start_pos + 1]
+        q = q.transpose(1, 2)
+        k_cache[:, start_pos] = k
+        v_cache[:, start_pos] = v
+        sliced_k_cache = k_cache[:, : start_pos + 1, :, :]
+        sliced_v_cache = v_cache[:, : start_pos + 1, :, :]
+        sliced_k_cache = sliced_k_cache.transpose(1, 2)
+        sliced_v_cache = sliced_v_cache.transpose(1, 2)
+        # print(sliced_k_cache.size())
+        # print(torch.matmul(q, sliced_k_cache.transpose(2, 3)))
+        # print("q @ k")
+        # qk = torch.matmul(q, sliced_k_cache.transpose(2, 3))
+        # qk_softmax = torch.softmax(qk, dim=-1)
+        # qkv = torch.matmul(qk_softmax, sliced_v_cache)
+        # print(qk)
+        # print(qk_softmax)
+        # print(qkv)
+        num_heads_q = q.size(1)
+        num_heads_kv = sliced_k_cache.size(1)
+        if num_heads_q != num_heads_kv:
+            assert (
+                num_heads_q % num_heads_kv == 0
+            ), f"{num_heads_q} not divisible by {num_heads_kv}"
+        n_reps = num_heads_q // num_heads_kv
+        if n_reps > 1:
+            sliced_k_cache = sliced_k_cache.repeat_interleave(n_reps, dim=1)
+            sliced_v_cache = sliced_v_cache.repeat_interleave(n_reps, dim=1)
+        out = F.scaled_dot_product_attention(
+            q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
+        )
+        out = out.transpose(1, 2)
+        print(out)
+        print(f"-------- start pos {start_pos} done -----")
+        return out
+
+    def test_sdpa_with_cache_mqa_1(self):
+        q = torch.rand((1, 1, self.n_heads_q, 4))
+        k = torch.rand((1, 1, self.n_heads_kv, 4))
+        v = torch.rand((1, 1, self.n_heads_kv, 4))
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 0
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 0, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_mqa_2(self):
+        q = torch.rand((1, 1, self.n_heads_q, 4))
+        k = torch.rand((1, 1, self.n_heads_kv, 4))
+        v = torch.rand((1, 1, self.n_heads_kv, 4))
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 1
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_mqa_3(self):
+        self.n_heads_q = 14
+        self.n_heads_kv = 7
+        self.setup_caches()
+        q = torch.rand((1, 1, self.n_heads_q, 4))
+        k = torch.rand((1, 1, self.n_heads_kv, 4))
+        v = torch.rand((1, 1, self.n_heads_kv, 4))
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 1
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
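
With these changes the op accepts k and v with fewer heads than q, so the kv cache can be allocated with the smaller head count. A minimal usage sketch, assuming the custom_ops_aot_lib that the Buck test preloads has already been loaded into the process (e.g. via torch.ops.load_library on its shared-library path); the positional arguments follow the test calls above, and the argument names in the comment are descriptive guesses rather than the op's declared schema:

import torch

# Shapes follow SDPATestWithMQA: (batch, seq_len, n_heads, head_dim),
# with 8 query heads sharing 4 kv heads and a max cache length of 5.
n_heads_q, n_heads_kv, head_dim, max_seq = 8, 4, 4, 5
q = torch.rand((1, 1, n_heads_q, head_dim))
k = torch.rand((1, 1, n_heads_kv, head_dim))
v = torch.rand((1, 1, n_heads_kv, head_dim))
k_cache = torch.zeros((1, max_seq, n_heads_kv, head_dim))
v_cache = torch.zeros((1, max_seq, n_heads_kv, head_dim))

# Arguments (guessed names): q, k, v, k_cache, v_cache, start_pos, seq_len,
# attn_mask, dropout_p, is_causal.
out = torch.ops.llama.sdpa_with_kv_cache(
    q, k, v, k_cache, v_cache, 0, 1, None, 0, False
)
print(out.shape)  # expected: torch.Size([1, 1, 8, 4])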
