@@ -382,7 +382,7 @@ void cpu_flash_attention(
382
382
/* qk_sum */ qSplitSize +
383
383
/* dst */ qSplitSize * headSize;
384
384
385
- int64_t size_bytes = size_per_thread * num_thread * query.element_size ();
385
+ int64_t size_bytes = size_per_thread * num_thread * query.element_size () * 4 ;
386
386
std::vector<char > buf_vec (size_bytes);
387
387
void * buf = reinterpret_cast <void *>(buf_vec.data ());
388
388
// Need to double check the following
@@ -452,6 +452,7 @@ void cpu_flash_attention(
452
452
// However, lets just fix that as well.
453
453
int64_t num_keys =
454
454
is_causal ? std::min (m + start_pos + qBlockSize, kvSize) : kvSize;
455
+ int64_t m_start_pos = m + start_pos;
455
456
auto j_kv = j / num_reps;
456
457
for (int64_t n = 0 ; n < num_keys; n += kvSplitSize) {
457
458
int64_t kvBlockSize = std::min (kvSplitSize, kvSize - n);
@@ -471,29 +472,62 @@ void cpu_flash_attention(
471
472
static_cast<accum_t>(0 ),
472
473
qk_data,
473
474
kvBlockSize);
474
- // Apply causal mask, fill unused, i.e. future values, with -inf
475
- // Say you have q @ k.T size = [16, 32]
476
- // With qblock size = 4, say you are processing
477
- // q seq len dim = 8:11.
478
- // Say kvSplitSize = 4
479
- // Then for causal mask, the entries that needs to be
480
- // ignored are
481
- // [8, 9:31], [9, 10:31], [10, 10:31], [11, 11:31]
482
- // Following condition says that num_keys = 8 + 4 =12
483
- // (num_keys - n) <= kvSplitSize
484
- // num_keys <= n + kvSplitSize
485
- // If n + kvSplitSize is larger than 12, then some
486
- // entries need masked out. In our example n = 4
487
- // will qualify for that
488
- if (is_causal && num_keys - n <= kvSplitSize) {
475
+ // There are 4 cases that is_causal has to cover to fill
476
+ // not-attendable-position with -inf
477
+ /* 1. Everything is attended to. This happens when m_start_pos > n +
478
+ kvSplitSize e.g m_pos [8:15] and n_pos [0:7]. Since you must attend to
479
+ all previous tokens matrix is full
480
+ + + + + + + + +
481
+ + + + + + + + +
482
+ + + + + + + + +
483
+ + + + + + + + +
484
+ + + + + + + + +
485
+ + + + + + + + +
486
+ + + + + + + + +
487
+ 2. Everything is not attended to. However only some tokens at the
488
+ beginning dont attend to everything. This happens when m_start_pos <= n
489
+ + kvSplitSize but m_start_pos + qBlockSize > n + kvSplitSize m_start_pos
490
+ = 8 qBlockSize = 8 n = 4 kvSplitSize = 8 For example m_pos [8:15] but
491
+ n_pos is [4:11]
492
+ + + + + + - - -
493
+ + + + + + + - -
494
+ + + + + + + + -
495
+ + + + + + + + +
496
+ + + + + + + + +
497
+ + + + + + + + +
498
+ + + + + + + + +
499
+ + + + + + + + +
500
+ 3. In this case only last few tokens have something to attend to.
501
+ This happens when m_start_pos < n and m_start_pos + qBlockSize >= n and
502
+ m_start_pos + qBlockSize <= n + kvSplitSize m_start_pos = 8 qBlockSize =
503
+ 8 n = 13 kvSplitSize = 8 For example m_pos [8:15] but n_pos is [13:20]
504
+ - - - - - - - -
505
+ - - - - - - - -
506
+ - - - - - - - -
507
+ - - - - - - - -
508
+ - - - - - - - -
509
+ + - - - - - - -
510
+ + + - - - - - -
511
+ + + + - - - - -
512
+ 4. In this no tokens attend to anything, but we dont really have to
513
+ take care of this case because the loop for (int64_t n = 0; n <
514
+ num_keys; n += kvSplitSize) will exit before that.
515
+ */
516
+ if (is_causal && m_start_pos <= n + kvSplitSize) {
489
517
// For this fn to work k_split_size > q_split_size
490
- for (int32_t row = 0 ; row < qBlockSize; ++row) {
491
- int64_t last_col = m + (row + start_pos) - n;
518
+ for (int32_t row = 0 ;
519
+ row < qBlockSize && (m_start_pos + row < n + (kvSplitSize - 1 ));
520
+ ++row) {
521
+ // When last_col is 0, it means that the entire row is not attended
522
+ // to because m_pos is smaller than n_pos. So everything in n is for
523
+ // future.
524
+ int64_t last_col =
525
+ n > (m_start_pos + row) ? 0 : row + m_start_pos + 1 - n;
492
526
accum_t * row_ptr = qk_data + row * kvBlockSize;
493
527
fill_stub (
494
- row_ptr + last_col + 1 ,
528
+ row_ptr + last_col,
495
529
-std::numeric_limits<accum_t >::infinity (),
496
- kvBlockSize - last_col - 1 );
530
+ kvBlockSize - last_col);
497
531
}
498
532
}
499
533
// Update attention weights with attention mask
0 commit comments