
Commit 6dbb4dc

kimishpatel authored and facebook-github-bot committed
Fix sdpa flash attention op for et llama deployment (#4322)
Summary: Pull Request resolved: #4322 We retrofitted flash attention CPU from ATen. The retrofit made it calculate attention for a) batched prefill and b) decode with different start_pos. For b, there was a bug when the KV cache's seqlen dim is split; as a result the attention calculation is not right. There is a detail in the code explaining the issue. bypass-github-export-checks ghstack-source-id: 234634902 Reviewed By: larryliu0820 Differential Revision: D60011925 fbshipit-source-id: 50921846b329e449a4a767cf28c7a55d507217bd
1 parent 9d85965 commit 6dbb4dc
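
For context on what "batched prefill" and "decode with different start_pos" mean here, a hedged usage sketch of the op. Shapes and argument order are taken from the test calls below; it assumes the custom-op library that registers torch.ops.llama.sdpa_with_kv_cache has already been loaded, as it is in the test file.

import torch

# Argument order assumed from the tests below:
# (q, k, v, k_cache, v_cache, start_pos, seq_len, attn_mask, dropout_p, is_causal)
n_heads, head_dim, max_seq_len = 32, 128, 2048
k_cache = torch.zeros((1, max_seq_len, n_heads, head_dim))
v_cache = torch.zeros((1, max_seq_len, n_heads, head_dim))

# a) batched prefill: 130 tokens written to cache positions [0, 130)
q = torch.rand((1, 130, n_heads, head_dim))
k = torch.rand((1, 130, n_heads, head_dim))
v = torch.rand((1, 130, n_heads, head_dim))
out = torch.ops.llama.sdpa_with_kv_cache(q, k, v, k_cache, v_cache, 0, 130, None, 0, True)

# b) decode: one token at start_pos = 130; attention must now cover all 131 cached keys,
#    which is the case that broke when the kv sequence dim is split (e.g. kv_split_size = 128).
q1 = torch.rand((1, 1, n_heads, head_dim))
k1 = torch.rand((1, 1, n_heads, head_dim))
v1 = torch.rand((1, 1, n_heads, head_dim))
out1 = torch.ops.llama.sdpa_with_kv_cache(q1, k1, v1, k_cache, v_cache, 130, 1, None, 0, True)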

2 files changed: +168 -1 lines changed


examples/models/llama2/custom_ops/op_sdpa.cpp

Lines changed: 36 additions & 1 deletion
@@ -239,6 +239,12 @@ void cpu_flash_attention(
   at::Tensor value = v.transpose(1, 2);
   */

+  // Without this we have out-of-bounds writes for
+  // causal masking
+  static_assert(
+      kv_split_size > q_split_size,
+      "KV_split_size must be greater than q_split_size");
+
   constexpr bool is_reduced_type =
       torch::executor::is_reduced_floating_point<scalar_t>::value;
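
The assertion above guards the causal-masking step in the second hunk below, where last_col = m + (row + start_pos) - n marks the last visible key column of the current kv block. A minimal plain-Python sketch of that index arithmetic, illustrative only and not the kernel, with positions taken from the decode example discussed in the next hunk:

def masked_columns(m, start_pos, n, q_block_size, kv_block_size):
    # For query row `row` of the q block starting at m, keys after absolute
    # position m + row + start_pos are hidden under causal attention.
    cols = []
    for row in range(q_block_size):
        last_col = m + (row + start_pos) - n  # last visible column in this kv block
        cols.append([c for c in range(kv_block_size) if c > last_col])
    return cols

# Decode step from the comments below: start_pos = 130, one query row, and the
# second kv split starting at n = 128 holding the last 3 cached keys.
print(masked_columns(0, 130, 128, 1, 3))  # [[]] -- nothing to mask in this block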

@@ -417,7 +423,35 @@ void cpu_flash_attention(
         // Initialize max and sum
         fill_stub(
             qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
-        int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
+        // The original flash SDPA was not really meant to be used
+        // for decode the way we are using it via start_pos here.
+        // Thus, when num_keys is 1 during the decode phase, we
+        // still need to iterate through all the kv splits.
+        // Take start_pos = 130 and kv_split_size = 128.
+        // Here we have to produce [1 x 130] of q @ k.T
+        // when seq_len = 1.
+        // But if num_keys = 1 then we don't really loop over
+        // all kv splits.
+        // When kv_split_size > 130, this is not an issue because
+        // there is only one iteration of the following loop anyway.
+        // Outside of determining how many loop iterations are needed,
+        // num_keys participates only in causal attention.
+        // The rest of the calculation of q @ k.T and @ v.T is the same.
+        // We don't run into this bug when kv_split_size > start_pos + seq_len,
+        // since there is only one iteration and it applies
+        // causal attention correctly.
+        // However, when kv_split_size < start_pos + seq_len, we need
+        // more than one iteration, yet if we don't adjust num_keys
+        // we don't get more than one iteration.
+        // This is unique to this deployment of flash attention, since
+        // the original implementation was not deployed this way.
+
+        // Some of these bugs could be resolved by relying on the attention mask,
+        // but that requires storing the attention mask in float, as the current
+        // code doesn't support a bool attention mask.
+        // However, let's just fix it here as well.
+        int64_t num_keys =
+            is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
         auto j_kv = j / num_reps;
         for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
           int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
@@ -452,6 +486,7 @@ void cpu_flash_attention(
           // entries need masked out. In our example n = 4
           // will qualify for that
           if (is_causal && num_keys - n <= kvSplitSize) {
+            // For this fn to work, kv_split_size > q_split_size
            for (int32_t row = 0; row < qBlockSize; ++row) {
              int64_t last_col = m + (row + start_pos) - n;
              accum_t* row_ptr = qk_data + row * kvBlockSize;
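
A small arithmetic sketch of the bound change above, in plain Python rather than the kernel, using the numbers from the comment: start_pos = 130, seq_len = 1, kv_split_size = 128, and assuming the query must attend to start_pos + seq_len cached keys.

def covered_keys(num_keys, kv_size, kv_split_size):
    # Mirrors: for (int64_t n = 0; n < num_keys; n += kvSplitSize)
    covered = 0
    for n in range(0, num_keys, kv_split_size):
        covered += min(kv_split_size, kv_size - n)  # kvBlockSize of this iteration
    return covered

start_pos, seq_len, kv_split_size = 130, 1, 128
m, q_block_size = 0, 1                   # single query block during decode
kv_size = start_pos + seq_len            # 131 cached keys

old_num_keys = min(m + q_block_size, kv_size)              # 1
new_num_keys = min(m + start_pos + q_block_size, kv_size)  # 131
print(covered_keys(old_num_keys, kv_size, kv_split_size))  # 128 -> last keys never attended
print(covered_keys(new_num_keys, kv_size, kv_split_size))  # 131 -> all cached keys attended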

examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py

Lines changed: 132 additions & 0 deletions
@@ -365,3 +365,135 @@ def test_sdpa_with_cache_mqa_3(self):
             q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
         )
         self.assertTrue(torch.allclose(ref_output, op_output))
+
+
+class SDPATestForLargeSeqLength(unittest.TestCase):
+
+    def setup_caches(self):
+        self.k_cache = torch.zeros(
+            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+        )
+        self.v_cache = torch.zeros(
+            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+        )
+        self.mask = torch.full(
+            (self.max_seq_len, self.max_seq_len),
+            float("-inf"),
+        )
+        self.mask = torch.triu(self.mask, diagonal=1)
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.n_heads_kv = 32
+        self.n_heads_q = 32
+        self.head_dim = 128
+        self.max_seq_len = 2048
+        self.setup_caches()
+
+    def test_sdpa_with_cache_seq_len_130(self):
+        self.n_heads_kv = 32
+        self.n_heads_q = 32
+        self.head_dim = 128
+        self.max_seq_len = 2048
+        self.setup_caches()
+        seq_len = 130
+        q = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        start_pos = 0
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+        q = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        k = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        start_pos = seq_len
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_seq_len_small(self):
+        self.n_heads_kv = 4
+        self.n_heads_q = 4
+        self.head_dim = 4
+        self.max_seq_len = 8
+        self.setup_caches()
+        q = torch.rand((1, 4, self.n_heads_q, 4))
+        k = torch.rand((1, 4, self.n_heads_q, 4))
+        v = torch.rand((1, 4, self.n_heads_q, 4))
+        start_pos = 0
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+        q = torch.rand((1, 1, self.n_heads_q, 4))
+        k = torch.rand((1, 1, self.n_heads_q, 4))
+        v = torch.rand((1, 1, self.n_heads_q, 4))
+        start_pos = 4
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_seq_len_llava_example(self):
+        self.n_heads_kv = 32
+        self.n_heads_q = 32
+        self.head_dim = 128
+        self.max_seq_len = 2048
+        self.setup_caches()
+        seq_len = 634
+        q = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        start_pos = 0
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+        q = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        k = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        start_pos = seq_len
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
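
The tests compare against _sdpa_with_kv_cache_ref, which is defined elsewhere in this test file and is not part of the diff. As a rough, hypothetical sketch only, assuming the reference writes the new k/v into the caches at start_pos and then runs full-precision SDPA with the sliced additive mask over the filled prefix (the actual helper may differ):

import torch

def sdpa_with_kv_cache_ref_sketch(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):
    # Hypothetical stand-in for _sdpa_with_kv_cache_ref; tensors use the
    # (1, seq, heads, head_dim) layout from the tests above.
    k_cache[:, start_pos : start_pos + seq_len] = k  # write new entries into the caches
    v_cache[:, start_pos : start_pos + seq_len] = v
    attended = start_pos + seq_len                   # attend over the filled prefix
    q_t = q.transpose(1, 2)                          # (1, heads, seq_len, head_dim)
    k_t = k_cache[:, :attended].transpose(1, 2)
    v_t = v_cache[:, :attended].transpose(1, 2)
    out = torch.nn.functional.scaled_dot_product_attention(q_t, k_t, v_t, attn_mask=attn_mask)
    return out.transpose(1, 2)                       # back to (1, seq_len, heads, head_dim)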
