@@ -191,25 +191,25 @@ inline Vectorized<float> exp_u20(Vectorized<float> data) {
191
191
192
192
// 1) out = exp(a - val)
193
193
// 2) val = sum(out)
194
- template <typename scalar_t >
194
+ template <typename T1, typename T2 >
195
195
inline void _exp_reduce_sum_fusion_kernel (
196
- scalar_t * a,
196
+ T1 * a,
197
197
const int & size,
198
- scalar_t * out,
199
- scalar_t & val) {
200
- auto vec_size = at::vec::Vectorized<scalar_t >::size ();
201
- auto vec_max = at::vec::Vectorized<scalar_t >(val);
202
- scalar_t tmp_sum = 0 ;
203
- auto vec_tmp_sum = at::vec::Vectorized<scalar_t >(tmp_sum);
198
+ T2 * out,
199
+ T1 & val) {
200
+ auto vec_size = at::vec::Vectorized<T1 >::size ();
201
+ auto vec_max = at::vec::Vectorized<T1 >(val);
202
+ T1 tmp_sum = 0 ;
203
+ auto vec_tmp_sum = at::vec::Vectorized<T1 >(tmp_sum);
204
204
for (long i = 0 ; i < vec_size * (size / vec_size); i += vec_size) {
205
- auto tmp0 = at::vec::Vectorized<scalar_t >::loadu (a + i);
205
+ auto tmp0 = at::vec::Vectorized<T1 >::loadu (a + i);
206
206
auto tmp1 = tmp0 - vec_max;
207
207
auto tmp2 = exp_u20 (tmp1);
208
208
vec_tmp_sum += tmp2;
209
209
at::native::_store (out + i, tmp2);
210
210
}
211
- tmp_sum = at::vec::vec_reduce_all<scalar_t >(
212
- [](at::vec::Vectorized<scalar_t >& x, at::vec::Vectorized<scalar_t >& y) {
211
+ tmp_sum = at::vec::vec_reduce_all<T1 >(
212
+ [](at::vec::Vectorized<T1 >& x, at::vec::Vectorized<T1 >& y) {
213
213
return x + y;
214
214
},
215
215
vec_tmp_sum);
@@ -223,27 +223,6 @@ inline void _exp_reduce_sum_fusion_kernel(
223
223
val = tmp_sum;
224
224
}
225
225
226
- // out = a / sum
227
- template <typename T1, typename T2>
228
- inline void _normalization_kernel (
229
- const T1* a,
230
- const T1& sum,
231
- const int & size,
232
- T2* out) {
233
- auto vec_size = at::vec::Vectorized<T1>::size ();
234
- auto vec_sum = at::vec::Vectorized<T1>(sum);
235
- for (long i = 0 ; i < vec_size * (size / vec_size); i += vec_size) {
236
- auto tmp0 = at::vec::Vectorized<T1>::loadu (a + i);
237
- auto tmp1 = tmp0 / vec_sum;
238
- at::native::_store (out + i, tmp1);
239
- }
240
- for (long i = vec_size * (size / vec_size); i < size; i++) {
241
- auto tmp0 = a[i];
242
- auto tmp1 = tmp0 / sum;
243
- out[i] = tmp1;
244
- }
245
- }
246
-
247
226
// 1) out = a * scale
248
227
// 2) max = max(out)
249
228
template <typename scalar_t >
@@ -767,29 +746,19 @@ void cpu_flash_attention(
767
746
_exp_reduce_sum_fusion_kernel (
768
747
qk_data + row * kvBlockSize,
769
748
kvBlockSize,
770
- qk_data + row * kvBlockSize,
749
+ conditional_data_ptr (qk_data, qk_reduced_data) +
750
+ row * kvBlockSize,
771
751
tmp_sum);
772
752
// exp_tmp <- exp(max[row] - max)
773
753
exp_tmp = std::exp (qk_max_data[row] - tmp_max);
774
754
// sum[row] <- sum + exp_tmp * sum[row]
775
755
qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
776
756
// max[row] <- max
777
757
qk_max_data[row] = tmp_max;
778
- // qk <- qk / sum[row]
779
- accum_t sum_new = qk_sum_data[row];
780
- _normalization_kernel (
781
- qk_data + row * kvBlockSize,
782
- sum_new,
783
- kvBlockSize,
784
- conditional_data_ptr (qk_data, qk_reduced_data) +
785
- row * kvBlockSize);
786
- // dst <- dst * sum_old / sum_new * exp_tmp
758
+ // dst <- dst * exp_tmp
787
759
if (n > 0 ) {
788
- accum_t sum_cor = sum_old / sum_new;
789
760
at::vec::map<accum_t >(
790
- [sum_cor, exp_tmp](Vec x) {
791
- return x * Vec (sum_cor) * Vec (exp_tmp);
792
- },
761
+ [exp_tmp](Vec x) { return x * Vec (exp_tmp); },
793
762
dst_data + row * headSize,
794
763
dst_data + row * headSize,
795
764
headSize);
@@ -856,10 +825,12 @@ void cpu_flash_attention(
856
825
headSize);
857
826
}
858
827
}
828
+ // dst <- dst / sum[row]
859
829
// reorder MHA output with strides
860
830
for (int64_t row = 0 ; row < qBlockSize; ++row) {
831
+ accum_t sum_reciprocal = 1 / qk_sum_data[row];
861
832
at::vec::map<scalar_t >(
862
- [](Vec x) { return x; },
833
+ [sum_reciprocal ](Vec x) { return x * Vec (sum_reciprocal) ; },
863
834
out_data + i * oStrideB + j * oStrideH + m * oStrideM +
864
835
row * oStrideM,
865
836
dst_data + row * headSize,
0 commit comments