Skip to content

Commit f88a7d1

Browse files
authored
Optimize softmax via flash attention v2 (#2468)
* optimize softmax as in flash attention v2 (online softmax with deferred normalization)
1 parent e046f5c commit f88a7d1

File tree

1 file changed

+18
-47
lines changed

1 file changed

+18
-47
lines changed

csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp

Lines changed: 18 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -191,25 +191,25 @@ inline Vectorized<float> exp_u20(Vectorized<float> data) {
191191

192192
// 1) out = exp(a - val)
193193
// 2) val = sum(out)
194-
template <typename scalar_t>
194+
template <typename T1, typename T2>
195195
inline void _exp_reduce_sum_fusion_kernel(
196-
scalar_t* a,
196+
T1* a,
197197
const int& size,
198-
scalar_t* out,
199-
scalar_t& val) {
200-
auto vec_size = at::vec::Vectorized<scalar_t>::size();
201-
auto vec_max = at::vec::Vectorized<scalar_t>(val);
202-
scalar_t tmp_sum = 0;
203-
auto vec_tmp_sum = at::vec::Vectorized<scalar_t>(tmp_sum);
198+
T2* out,
199+
T1& val) {
200+
auto vec_size = at::vec::Vectorized<T1>::size();
201+
auto vec_max = at::vec::Vectorized<T1>(val);
202+
T1 tmp_sum = 0;
203+
auto vec_tmp_sum = at::vec::Vectorized<T1>(tmp_sum);
204204
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
205-
auto tmp0 = at::vec::Vectorized<scalar_t>::loadu(a + i);
205+
auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
206206
auto tmp1 = tmp0 - vec_max;
207207
auto tmp2 = exp_u20(tmp1);
208208
vec_tmp_sum += tmp2;
209209
at::native::_store(out + i, tmp2);
210210
}
211-
tmp_sum = at::vec::vec_reduce_all<scalar_t>(
212-
[](at::vec::Vectorized<scalar_t>& x, at::vec::Vectorized<scalar_t>& y) {
211+
tmp_sum = at::vec::vec_reduce_all<T1>(
212+
[](at::vec::Vectorized<T1>& x, at::vec::Vectorized<T1>& y) {
213213
return x + y;
214214
},
215215
vec_tmp_sum);
@@ -223,27 +223,6 @@ inline void _exp_reduce_sum_fusion_kernel(
223223
val = tmp_sum;
224224
}
225225

226-
// out = a / sum
227-
template <typename T1, typename T2>
228-
inline void _normalization_kernel(
229-
const T1* a,
230-
const T1& sum,
231-
const int& size,
232-
T2* out) {
233-
auto vec_size = at::vec::Vectorized<T1>::size();
234-
auto vec_sum = at::vec::Vectorized<T1>(sum);
235-
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
236-
auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
237-
auto tmp1 = tmp0 / vec_sum;
238-
at::native::_store(out + i, tmp1);
239-
}
240-
for (long i = vec_size * (size / vec_size); i < size; i++) {
241-
auto tmp0 = a[i];
242-
auto tmp1 = tmp0 / sum;
243-
out[i] = tmp1;
244-
}
245-
}
246-
247226
// 1) out = a * scale
248227
// 2) max = max(out)
249228
template <typename scalar_t>
@@ -767,29 +746,19 @@ void cpu_flash_attention(
767746
_exp_reduce_sum_fusion_kernel(
768747
qk_data + row * kvBlockSize,
769748
kvBlockSize,
770-
qk_data + row * kvBlockSize,
749+
conditional_data_ptr(qk_data, qk_reduced_data) +
750+
row * kvBlockSize,
771751
tmp_sum);
772752
// exp_tmp <- exp(max[row] - max)
773753
exp_tmp = std::exp(qk_max_data[row] - tmp_max);
774754
// sum[row] <- sum + exp_tmp * sum[row]
775755
qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
776756
// max[row] <- max
777757
qk_max_data[row] = tmp_max;
778-
// qk <- qk / sum[row]
779-
accum_t sum_new = qk_sum_data[row];
780-
_normalization_kernel(
781-
qk_data + row * kvBlockSize,
782-
sum_new,
783-
kvBlockSize,
784-
conditional_data_ptr(qk_data, qk_reduced_data) +
785-
row * kvBlockSize);
786-
// dst <- dst * sum_old / sum_new * exp_tmp
758+
// dst <- dst * exp_tmp
787759
if (n > 0) {
788-
accum_t sum_cor = sum_old / sum_new;
789760
at::vec::map<accum_t>(
790-
[sum_cor, exp_tmp](Vec x) {
791-
return x * Vec(sum_cor) * Vec(exp_tmp);
792-
},
761+
[exp_tmp](Vec x) { return x * Vec(exp_tmp); },
793762
dst_data + row * headSize,
794763
dst_data + row * headSize,
795764
headSize);
@@ -856,10 +825,12 @@ void cpu_flash_attention(
856825
headSize);
857826
}
858827
}
828+
// dst <- dst / sum[row]
859829
// reorder MHA output with strides
860830
for (int64_t row = 0; row < qBlockSize; ++row) {
831+
accum_t sum_reciprocal = 1 / qk_sum_data[row];
861832
at::vec::map<scalar_t>(
862-
[](Vec x) { return x; },
833+
[sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
863834
out_data + i * oStrideB + j * oStrideH + m * oStrideM +
864835
row * oStrideM,
865836
dst_data + row * headSize,

0 commit comments

Comments (0)