@@ -3406,8 +3406,6 @@ kernel void kernel_flash_attn_ext(
3406
3406
3407
3407
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0 *DK); // holds the query data
3408
3408
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3409
- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3410
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0 *DK); // same as above but in o4_t
3411
3409
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2 *sgitg*SH + 2 *Q*DK); // scratch buffer for attention, mask and diagonal matrix
3412
3410
3413
3411
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4 *16 *KV) + Q*T); // scratch buffer to load K in shared memory
@@ -3621,20 +3619,20 @@ kernel void kernel_flash_attn_ext(
3621
3619
3622
3620
// O = diag(ms)*O
3623
3621
{
3624
- s8x8_t mm ;
3625
- simdgroup_load (mm , ss + 2 *C, TS, 0 , false );
3622
+ s8x8_t ms ;
3623
+ simdgroup_load (ms , ss + 2 *C, TS, 0 , false );
3626
3624
3627
3625
#pragma unroll(DV8)
3628
3626
for (short i = 0 ; i < DV8; ++i) {
3629
- simdgroup_multiply (lo[i], mm , lo[i]);
3627
+ simdgroup_multiply (lo[i], ms , lo[i]);
3630
3628
}
3631
3629
}
3632
3630
3633
3631
// O = O + (Q*K^T)*V
3634
3632
{
3635
3633
for (short cc = 0 ; cc < C/8 ; ++cc) {
3636
- s8x8_t ms ;
3637
- simdgroup_load (ms , ss + 8 *cc, TS, 0 , false );
3634
+ s8x8_t vs ;
3635
+ simdgroup_load (vs , ss + 8 *cc, TS, 0 , false );
3638
3636
3639
3637
if (is_same<vd4x4_t , v4x4_t >::value) {
3640
3638
// we can read directly from global memory
@@ -3645,7 +3643,7 @@ kernel void kernel_flash_attn_ext(
3645
3643
v8x8_t mv;
3646
3644
simdgroup_load (mv, pv + i*8 , args.nb21 /sizeof (v_t ), 0 , false ); // TODO: use ne20
3647
3645
3648
- simdgroup_multiply_accumulate (lo[i], ms , mv, lo[i]);
3646
+ simdgroup_multiply_accumulate (lo[i], vs , mv, lo[i]);
3649
3647
}
3650
3648
} else {
3651
3649
for (short ii = 0 ; ii < DV16; ii += 4 ) {
@@ -3666,10 +3664,10 @@ kernel void kernel_flash_attn_ext(
3666
3664
v8x8_t mv;
3667
3665
3668
3666
simdgroup_load (mv, sv + 16 *k + 0 *8 , 4 *16 , 0 , false );
3669
- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], ms , mv, lo[2 *(ii + k) + 0 ]);
3667
+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], vs , mv, lo[2 *(ii + k) + 0 ]);
3670
3668
3671
3669
simdgroup_load (mv, sv + 16 *k + 1 *8 , 4 *16 , 0 , false );
3672
- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], ms , mv, lo[2 *(ii + k) + 1 ]);
3670
+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], vs , mv, lo[2 *(ii + k) + 1 ]);
3673
3671
}
3674
3672
} else {
3675
3673
if (ii + tx < DV16) {
@@ -3684,10 +3682,10 @@ kernel void kernel_flash_attn_ext(
3684
3682
v8x8_t mv;
3685
3683
3686
3684
simdgroup_load (mv, sv + 16 *k + 0 *8 , 4 *16 , 0 , false );
3687
- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], ms , mv, lo[2 *(ii + k) + 0 ]);
3685
+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], vs , mv, lo[2 *(ii + k) + 0 ]);
3688
3686
3689
3687
simdgroup_load (mv, sv + 16 *k + 1 *8 , 4 *16 , 0 , false );
3690
- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], ms , mv, lo[2 *(ii + k) + 1 ]);
3688
+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], vs , mv, lo[2 *(ii + k) + 1 ]);
3691
3689
}
3692
3690
}
3693
3691
}
@@ -3697,83 +3695,80 @@ kernel void kernel_flash_attn_ext(
3697
3695
}
3698
3696
3699
3697
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
3700
- for (short j = 0 ; j < Q; ++j) {
3701
- if (tiisg == 0 ) {
3702
- ss[j*TS + 0 ] = S[j];
3703
- ss[j*TS + 1 ] = M[j];
3704
- }
3698
+ for (short j = tiisg; j < Q; j += NW) {
3699
+ ss[j*TS + 0 ] = S[j];
3700
+ ss[j*TS + 1 ] = M[j];
3705
3701
}
3706
3702
}
3707
3703
3708
- // reduce the warps sequentially
3709
- for (ushort sg = 1 ; sg < nsg; ++sg) {
3710
- threadgroup_barrier (mem_flags::mem_threadgroup);
3704
+ threadgroup_barrier (mem_flags::mem_threadgroup);
3711
3705
3712
- // each simdgroup stores its output to shared memory, reusing sq
3713
- if (sgitg == sg) {
3714
- for (short i = 0 ; i < DV8; ++i) {
3715
- simdgroup_store (lo[i], so + i*8 , DV, 0 , false );
3716
- }
3706
+ threadgroup float * so = (threadgroup float *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3707
+ threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0 *DK);
3708
+
3709
+ // store result to shared memory in F32
3710
+ if (sgitg == 0 ) {
3711
+ for (short i = 0 ; i < DV8; ++i) {
3712
+ // simdgroup_store(lo[i], so + i*8, DV, 0, false);
3713
+ simdgroup_float8x8 t (1 .0f );
3714
+ simdgroup_multiply (t, lo[i], t);
3715
+ simdgroup_store (t, so + i*8 , DV, 0 , false );
3717
3716
}
3717
+ }
3718
3718
3719
- threadgroup_barrier (mem_flags::mem_threadgroup);
3719
+ threadgroup_barrier (mem_flags::mem_threadgroup);
3720
3720
3721
- // the first simdgroup accumulates the results from the other simdgroups
3722
- if (sgitg == 0 ) {
3723
- for (short j = 0 ; j < Q; ++j) {
3724
- const float S0 = ss[j*TS + 0 ];
3725
- const float S1 = ss[j*TS + sg*SH + 0 ];
3721
+ // reduce the warps sequentially
3722
+ for (ushort sg = 1 ; sg < nsg; ++sg) {
3723
+ if (sgitg == sg) {
3724
+ for (short j = tiisg; j < Q; j += NW) {
3725
+ const float S0 = ss[j*TS - 1 *SH + 0 ];
3726
+ const float S1 = ss[j*TS + 0 ];
3726
3727
3727
- const float M0 = ss[j*TS + 1 ];
3728
- const float M1 = ss[j*TS + sg*SH + 1 ];
3728
+ const float M0 = ss[j*TS - 1 *SH + 1 ];
3729
+ const float M1 = ss[j*TS + 1 ];
3729
3730
3730
3731
const float M = max (M0, M1);
3731
3732
3732
- const float ms0 = exp (M0 - M);
3733
- const float ms1 = exp (M1 - M);
3733
+ float ms0 = exp (M0 - M);
3734
+ float ms1 = exp (M1 - M);
3734
3735
3735
3736
const float S = S0*ms0 + S1*ms1;
3736
3737
3737
- if (tiisg == 0 ) {
3738
- ss[j*TS + 0 ] = S;
3739
- ss[j*TS + 1 ] = M;
3738
+ ss[j*TS + 0 ] = S;
3739
+ ss[j*TS + 1 ] = M;
3740
3740
3741
- ss[j*TS + 2 *C + j ] = ms0;
3742
- ss[j*TS + 2 *C + j + sg*SH] = ms1;
3743
- }
3741
+ ss[j*TS + 2 *C + j - 1 *SH] = ms0;
3742
+ ss[j*TS + 2 *C + j ] = ms1;
3744
3743
}
3745
3744
3745
+ // simdgroup_barrier(mem_flags::mem_threadgroup);
3746
+
3746
3747
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
3747
3748
{
3748
3749
s8x8_t ms0;
3749
3750
s8x8_t ms1;
3750
3751
3751
- simdgroup_load (ms0, ss + 2 *C, TS, 0 , false );
3752
- simdgroup_load (ms1, ss + 2 *C + sg*SH, TS, 0 , false );
3752
+ simdgroup_load (ms0, ss + 2 *C - 1 *SH, TS, 0 , false );
3753
+ simdgroup_load (ms1, ss + 2 *C, TS, 0 , false );
3753
3754
3754
3755
#pragma unroll(DV8)
3755
3756
for (short i = 0 ; i < DV8; ++i) {
3756
- o8x8_t t;
3757
+ simdgroup_float8x8 t;
3757
3758
3758
3759
simdgroup_load (t, so + i*8 , DV, 0 , false );
3759
- simdgroup_multiply (t, ms1 , t);
3760
+ simdgroup_multiply (t, ms0 , t);
3760
3761
3761
- simdgroup_multiply_accumulate (lo[i], ms0, lo[i], t);
3762
+ simdgroup_multiply_accumulate (t, ms1, lo[i], t);
3763
+ simdgroup_store (t, so + i*8 , DV, 0 , false );
3762
3764
}
3763
3765
}
3764
3766
}
3765
- }
3766
3767
3767
- // store result to shared memory (reuse sq)
3768
- if (sgitg == 0 ) {
3769
- for (short i = 0 ; i < DV8; ++i) {
3770
- simdgroup_store (lo[i], so + i*8 , DV, 0 , false );
3771
- }
3768
+ threadgroup_barrier (mem_flags::mem_threadgroup);
3772
3769
}
3773
3770
3774
- threadgroup_barrier (mem_flags::mem_threadgroup);
3775
-
3776
- threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2 *Q*DK);
3771
+ threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2 *(nsg-1 )*SH + 2 *Q*DK);
3777
3772
3778
3773
// final rescale with 1/S and store to global memory
3779
3774
for (short j = sgitg; j < Q && iq1 + j < args.ne01 ; j += nsg) {
@@ -3796,17 +3791,17 @@ kernel void kernel_flash_attn_ext(
3796
3791
half, half4x4, simdgroup_half8x8, \
3797
3792
float , simdgroup_float8x8, \
3798
3793
float , simdgroup_float8x8, \
3799
- float , float4 , simdgroup_float8x8
3800
- // half , half4 , simdgroup_half8x8
3794
+ half , half4 , simdgroup_half8x8
3795
+ // float , float4 , simdgroup_float8x8
3801
3796
3802
3797
#define FA_TYPES_BF \
3803
3798
bfloat, bfloat4, simdgroup_bfloat8x8, \
3804
3799
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3805
3800
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3806
3801
float , simdgroup_float8x8, \
3807
3802
float , simdgroup_float8x8, \
3808
- float , float4 , simdgroup_float8x8
3809
- // half , half4 , simdgroup_half8x8
3803
+ half , half4 , simdgroup_half8x8
3804
+ // float , float4 , simdgroup_float8x8
3810
3805
3811
3806
typedef decltype (kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 , 64 >) flash_attn_ext_t;
3812
3807
0 commit comments