Skip to content

Commit f55a2bf

Browse files
Valentine233, jiayisunx, WeizhuoZhang-intel
authored
[flash attention] calculate logsumexp for backward (#2631)
* [flash attention] calculate logsumexp for backward
* [flash attention] calculate logsumexp for backward reduced type

---------

Co-authored-by: jiayisunx <[email protected]>
Co-authored-by: WeizhuoZhang-intel <[email protected]>
1 parent c7b1a8b commit f55a2bf

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -420,6 +420,9 @@ cpu_flash_attention(
420420
int64_t oStrideB = output.stride(0);
421421
int64_t oStrideM = output.stride(1);
422422
int64_t oStrideH = output.stride(2);
423+
int64_t lStrideB = logsumexp.stride(0);
424+
int64_t lStrideM = logsumexp.stride(1);
425+
int64_t lStrideH = logsumexp.stride(2);
423426
int64_t mStrideB =
424427
(attention_mask.has_value() && attention_mask.value().size(0) > 1)
425428
? attention_mask.value().stride(0)
@@ -459,6 +462,7 @@ cpu_flash_attention(
459462
? attention_mask.value().data_ptr<accum_t>()
460463
: nullptr;
461464
scalar_t* out_data = output.data_ptr<scalar_t>();
465+
accum_t* lse_data = logsumexp.data_ptr<accum_t>();
462466
accum_t* buf_data = buf.data_ptr<accum_t>();
463467

464468
at::parallel_for(
@@ -616,6 +620,13 @@ cpu_flash_attention(
616620
dst_data + row * headSize,
617621
headSize);
618622
}
623+
// Store logsumexp for backward
624+
accum_t* lse_ptr =
625+
lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
626+
for (const auto row : c10::irange(qBlockSize)) {
627+
lse_ptr[row * lStrideM] =
628+
qk_max_data[row] + std::log(qk_sum_data[row]);
629+
}
619630
// Move to the next query
620631
at::native::data_index_step(i, batchSize, j, num_head, k, qSlice);
621632
}
@@ -684,6 +695,9 @@ cpu_flash_attention(
684695
int64_t oStrideB = output.stride(0);
685696
int64_t oStrideM = output.stride(1);
686697
int64_t oStrideH = output.stride(2);
698+
int64_t lStrideB = logsumexp.stride(0);
699+
int64_t lStrideM = logsumexp.stride(1);
700+
int64_t lStrideH = logsumexp.stride(2);
687701
int64_t mStrideB =
688702
(attention_mask.has_value() && attention_mask.value().size(0) > 1)
689703
? attention_mask.value().stride(0)
@@ -725,6 +739,7 @@ cpu_flash_attention(
725739
? attention_mask.value().data_ptr<accum_t>()
726740
: nullptr;
727741
scalar_t* out_data = output.data_ptr<scalar_t>();
742+
accum_t* lse_data = logsumexp.data_ptr<accum_t>();
728743
accum_t* buf_data = buf.data_ptr<accum_t>();
729744
scalar_t* buf_reduced_data = buf_reduced.data_ptr<scalar_t>();
730745

@@ -1310,6 +1325,13 @@ cpu_flash_attention(
13101325
dst_data + row * headSize,
13111326
headSize);
13121327
}
1328+
// Store logsumexp for backward
1329+
accum_t* lse_ptr =
1330+
lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
1331+
for (const auto row : c10::irange(qBlockSize)) {
1332+
lse_ptr[row * lStrideM] =
1333+
qk_max_data[row] + std::log(qk_sum_data[row]);
1334+
}
13131335
// Move to the next query
13141336
at::native::data_index_step(i, batchSize, j, num_head, k, qSlice);
13151337
}

0 commit comments

Comments
 (0)