Skip to content

Commit 2155284

Browse files
authored
[Executorch][SDPA] Fix bug in sdpa
This diff fixes two bugs. 1. When doing flash attention, the partial q @ k block may contain some entries that need to be masked out. This logic had a bug. Maybe this bug also exists in PT core; I will look into that to add a test and see if I can prove it. 2. Due to special handling via start_pos in SDPA, it also exposed the bug in 1 when doing really long sequence prefill in a chunked manner. It is probably better to just use a mask, though. The code has detailed comments on the issue and fix. Differential Revision: D70922039
1 parent bd87ca5 commit 2155284

File tree

2 files changed

+65
-20
lines changed

2 files changed

+65
-20
lines changed

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ void cpu_flash_attention(
382382
/* qk_sum */ qSplitSize +
383383
/* dst */ qSplitSize * headSize;
384384

385-
int64_t size_bytes = size_per_thread * num_thread * query.element_size();
385+
int64_t size_bytes = size_per_thread * num_thread * query.element_size() * 4;
386386
std::vector<char> buf_vec(size_bytes);
387387
void* buf = reinterpret_cast<void*>(buf_vec.data());
388388
// Need to double check the following
@@ -452,6 +452,7 @@ void cpu_flash_attention(
452452
// However, lets just fix that as well.
453453
int64_t num_keys =
454454
is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
455+
int64_t m_start_pos = m + start_pos;
455456
auto j_kv = j / num_reps;
456457
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
457458
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
@@ -471,29 +472,62 @@ void cpu_flash_attention(
471472
static_cast<accum_t>(0),
472473
qk_data,
473474
kvBlockSize);
474-
// Apply causal mask, fill unused, i.e. future values, with -inf
475-
// Say you have q @ k.T size = [16, 32]
476-
// With qblock size = 4, say you are processing
477-
// q seq len dim = 8:11.
478-
// Say kvSplitSize = 4
479-
// Then for causal mask, the entries that needs to be
480-
// ignored are
481-
// [8, 9:31], [9, 10:31], [10, 10:31], [11, 11:31]
482-
// Following condition says that num_keys = 8 + 4 =12
483-
// (num_keys - n) <= kvSplitSize
484-
// num_keys <= n + kvSplitSize
485-
// If n + kvSplitSize is larger than 12, then some
486-
// entries need masked out. In our example n = 4
487-
// will qualify for that
488-
if (is_causal && num_keys - n <= kvSplitSize) {
475+
// There are 4 cases that is_causal has to cover to fill
476+
// not-attendable-position with -inf
477+
/* 1. Everything is attended to. This happens when m_start_pos > n +
478+
kvSplitSize e.g m_pos [8:15] and n_pos [0:7]. Since you must attend to
479+
all previous tokens matrix is full
480+
+ + + + + + + +
481+
+ + + + + + + +
482+
+ + + + + + + +
483+
+ + + + + + + +
484+
+ + + + + + + +
485+
+ + + + + + + +
486+
+ + + + + + + +
487+
2. Not everything is attended to. However, only some tokens at the
488+
beginning don't attend to everything. This happens when m_start_pos <= n
489+
+ kvSplitSize but m_start_pos + qBlockSize > n + kvSplitSize m_start_pos
490+
= 8 qBlockSize = 8 n = 4 kvSplitSize = 8 For example m_pos [8:15] but
491+
n_pos is [4:11]
492+
+ + + + + - - -
493+
+ + + + + + - -
494+
+ + + + + + + -
495+
+ + + + + + + +
496+
+ + + + + + + +
497+
+ + + + + + + +
498+
+ + + + + + + +
499+
+ + + + + + + +
500+
3. In this case only the last few tokens have something to attend to.
501+
This happens when m_start_pos < n and m_start_pos + qBlockSize >= n and
502+
m_start_pos + qBlockSize <= n + kvSplitSize m_start_pos = 8 qBlockSize =
503+
8 n = 13 kvSplitSize = 8 For example m_pos [8:15] but n_pos is [13:20]
504+
- - - - - - - -
505+
- - - - - - - -
506+
- - - - - - - -
507+
- - - - - - - -
508+
- - - - - - - -
509+
+ - - - - - - -
510+
+ + - - - - - -
511+
+ + + - - - - -
512+
4. In this case no tokens attend to anything, but we don't really have to
513+
take care of this case because the loop for (int64_t n = 0; n <
514+
num_keys; n += kvSplitSize) will exit before that.
515+
*/
516+
if (is_causal && m_start_pos <= n + kvSplitSize) {
489517
// For this fn to work k_split_size > q_split_size
490-
for (int32_t row = 0; row < qBlockSize; ++row) {
491-
int64_t last_col = m + (row + start_pos) - n;
518+
for (int32_t row = 0;
519+
row < qBlockSize && (m_start_pos + row < n + (kvSplitSize - 1));
520+
++row) {
521+
// When last_col is 0, it means that the entire row is not attended
522+
// to because m_pos is smaller than n_pos. So everything in n is for
523+
// future.
524+
int64_t last_col =
525+
n > (m_start_pos + row) ? 0 : row + m_start_pos + 1 - n;
492526
accum_t* row_ptr = qk_data + row * kvBlockSize;
493527
fill_stub(
494-
row_ptr + last_col + 1,
528+
row_ptr + last_col,
495529
-std::numeric_limits<accum_t>::infinity(),
496-
kvBlockSize - last_col - 1);
530+
kvBlockSize - last_col);
497531
}
498532
}
499533
// Update attention weights with attention mask

extension/llm/custom_ops/test_sdpa_with_kv_cache.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,3 +590,14 @@ def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
590590
self._test_sdpa_common(
591591
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
592592
)
593+
594+
def test_sdpa_to_repro_long_seq_failure(self):
595+
n_heads_kv = 16
596+
n_heads_q = 32
597+
head_dim = 128
598+
max_seq_len = 2048
599+
seq_len = 508
600+
next_iter_seq_len = 127
601+
self._test_sdpa_common(
602+
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
603+
)

0 commit comments

Comments
 (0)