@@ -420,6 +420,9 @@ cpu_flash_attention(
420
420
int64_t oStrideB = output.stride (0 );
421
421
int64_t oStrideM = output.stride (1 );
422
422
int64_t oStrideH = output.stride (2 );
423
+ int64_t lStrideB = logsumexp.stride (0 );
424
+ int64_t lStrideM = logsumexp.stride (1 );
425
+ int64_t lStrideH = logsumexp.stride (2 );
423
426
int64_t mStrideB =
424
427
(attention_mask.has_value () && attention_mask.value ().size (0 ) > 1 )
425
428
? attention_mask.value ().stride (0 )
@@ -459,6 +462,7 @@ cpu_flash_attention(
459
462
? attention_mask.value ().data_ptr <accum_t >()
460
463
: nullptr ;
461
464
scalar_t * out_data = output.data_ptr <scalar_t >();
465
+ accum_t * lse_data = logsumexp.data_ptr <accum_t >();
462
466
accum_t * buf_data = buf.data_ptr <accum_t >();
463
467
464
468
at::parallel_for (
@@ -616,6 +620,13 @@ cpu_flash_attention(
616
620
dst_data + row * headSize,
617
621
headSize);
618
622
}
623
+ // Store logsumexp for backward
624
+ accum_t * lse_ptr =
625
+ lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
626
+ for (const auto row : c10::irange (qBlockSize)) {
627
+ lse_ptr[row * lStrideM] =
628
+ qk_max_data[row] + std::log (qk_sum_data[row]);
629
+ }
619
630
// Move to the next query
620
631
at::native::data_index_step (i, batchSize, j, num_head, k, qSlice);
621
632
}
@@ -684,6 +695,9 @@ cpu_flash_attention(
684
695
int64_t oStrideB = output.stride (0 );
685
696
int64_t oStrideM = output.stride (1 );
686
697
int64_t oStrideH = output.stride (2 );
698
+ int64_t lStrideB = logsumexp.stride (0 );
699
+ int64_t lStrideM = logsumexp.stride (1 );
700
+ int64_t lStrideH = logsumexp.stride (2 );
687
701
int64_t mStrideB =
688
702
(attention_mask.has_value () && attention_mask.value ().size (0 ) > 1 )
689
703
? attention_mask.value ().stride (0 )
@@ -725,6 +739,7 @@ cpu_flash_attention(
725
739
? attention_mask.value ().data_ptr <accum_t >()
726
740
: nullptr ;
727
741
scalar_t * out_data = output.data_ptr <scalar_t >();
742
+ accum_t * lse_data = logsumexp.data_ptr <accum_t >();
728
743
accum_t * buf_data = buf.data_ptr <accum_t >();
729
744
scalar_t * buf_reduced_data = buf_reduced.data_ptr <scalar_t >();
730
745
@@ -1310,6 +1325,13 @@ cpu_flash_attention(
1310
1325
dst_data + row * headSize,
1311
1326
headSize);
1312
1327
}
1328
+ // Store logsumexp for backward
1329
+ accum_t * lse_ptr =
1330
+ lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
1331
+ for (const auto row : c10::irange (qBlockSize)) {
1332
+ lse_ptr[row * lStrideM] =
1333
+ qk_max_data[row] + std::log (qk_sum_data[row]);
1334
+ }
1313
1335
// Move to the next query
1314
1336
at::native::data_index_step (i, batchSize, j, num_head, k, qSlice);
1315
1337
}
0 commit comments