@@ -3333,8 +3333,6 @@ kernel void kernel_flash_attn_ext(
3333
3333
3334
3334
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0 *DK); // holds the query data
3335
3335
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3336
- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3337
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0 *DK); // same as above but in o4_t
3338
3336
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2 *sgitg*SH + 2 *Q*DK); // scratch buffer for attention, mask and diagonal matrix
3339
3337
3340
3338
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4 *16 *KV) + Q*T); // scratch buffer to load K in shared memory
@@ -3548,20 +3546,20 @@ kernel void kernel_flash_attn_ext(
3548
3546
3549
3547
// O = diag(ms)*O
3550
3548
{
3551
- s8x8_t mm ;
3552
- simdgroup_load (mm , ss + 2 *C, TS, 0 , false );
3549
+ s8x8_t ms ;
3550
+ simdgroup_load (ms , ss + 2 *C, TS, 0 , false );
3553
3551
3554
3552
#pragma unroll(DV8)
3555
3553
for (short i = 0 ; i < DV8; ++i) {
3556
- simdgroup_multiply (lo[i], mm , lo[i]);
3554
+ simdgroup_multiply (lo[i], ms , lo[i]);
3557
3555
}
3558
3556
}
3559
3557
3560
3558
// O = O + (Q*K^T)*V
3561
3559
{
3562
3560
for (short cc = 0 ; cc < C/8 ; ++cc) {
3563
- s8x8_t ms ;
3564
- simdgroup_load (ms , ss + 8 *cc, TS, 0 , false );
3561
+ s8x8_t vs ;
3562
+ simdgroup_load (vs , ss + 8 *cc, TS, 0 , false );
3565
3563
3566
3564
if (is_same<vd4x4_t , v4x4_t >::value) {
3567
3565
// we can read directly from global memory
@@ -3572,7 +3570,7 @@ kernel void kernel_flash_attn_ext(
3572
3570
v8x8_t mv;
3573
3571
simdgroup_load (mv, pv + i*8 , args.nb21 /sizeof (v_t ), 0 , false ); // TODO: use ne20
3574
3572
3575
- simdgroup_multiply_accumulate (lo[i], ms , mv, lo[i]);
3573
+ simdgroup_multiply_accumulate (lo[i], vs , mv, lo[i]);
3576
3574
}
3577
3575
} else {
3578
3576
for (short ii = 0 ; ii < DV16; ii += 4 ) {
@@ -3593,10 +3591,10 @@ kernel void kernel_flash_attn_ext(
3593
3591
v8x8_t mv;
3594
3592
3595
3593
simdgroup_load (mv, sv + 16 *k + 0 *8 , 4 *16 , 0 , false );
3596
- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], ms , mv, lo[2 *(ii + k) + 0 ]);
3594
+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], vs , mv, lo[2 *(ii + k) + 0 ]);
3597
3595
3598
3596
simdgroup_load (mv, sv + 16 *k + 1 *8 , 4 *16 , 0 , false );
3599
- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], ms , mv, lo[2 *(ii + k) + 1 ]);
3597
+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], vs , mv, lo[2 *(ii + k) + 1 ]);
3600
3598
}
3601
3599
} else {
3602
3600
if (ii + tx < DV16) {
@@ -3611,10 +3609,10 @@ kernel void kernel_flash_attn_ext(
3611
3609
v8x8_t mv;
3612
3610
3613
3611
simdgroup_load (mv, sv + 16 *k + 0 *8 , 4 *16 , 0 , false );
3614
- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], ms , mv, lo[2 *(ii + k) + 0 ]);
3612
+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], vs , mv, lo[2 *(ii + k) + 0 ]);
3615
3613
3616
3614
simdgroup_load (mv, sv + 16 *k + 1 *8 , 4 *16 , 0 , false );
3617
- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], ms , mv, lo[2 *(ii + k) + 1 ]);
3615
+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], vs , mv, lo[2 *(ii + k) + 1 ]);
3618
3616
}
3619
3617
}
3620
3618
}
@@ -3624,83 +3622,80 @@ kernel void kernel_flash_attn_ext(
3624
3622
}
3625
3623
3626
3624
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
3627
- for (short j = 0 ; j < Q; ++j) {
3628
- if (tiisg == 0 ) {
3629
- ss[j*TS + 0 ] = S[j];
3630
- ss[j*TS + 1 ] = M[j];
3631
- }
3625
+ for (short j = tiisg; j < Q; j += NW) {
3626
+ ss[j*TS + 0 ] = S[j];
3627
+ ss[j*TS + 1 ] = M[j];
3632
3628
}
3633
3629
}
3634
3630
3635
- // reduce the warps sequentially
3636
- for (ushort sg = 1 ; sg < nsg; ++sg) {
3637
- threadgroup_barrier (mem_flags::mem_threadgroup);
3631
+ threadgroup_barrier (mem_flags::mem_threadgroup);
3638
3632
3639
- // each simdgroup stores its output to shared memory, reusing sq
3640
- if (sgitg == sg) {
3641
- for (short i = 0 ; i < DV8; ++i) {
3642
- simdgroup_store (lo[i], so + i*8 , DV, 0 , false );
3643
- }
3633
+ threadgroup float * so = (threadgroup float *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3634
+ threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0 *DK);
3635
+
3636
+ // store result to shared memory in F32
3637
+ if (sgitg == 0 ) {
3638
+ for (short i = 0 ; i < DV8; ++i) {
3639
+ // simdgroup_store(lo[i], so + i*8, DV, 0, false);
3640
+ simdgroup_float8x8 t (1 .0f );
3641
+ simdgroup_multiply (t, lo[i], t);
3642
+ simdgroup_store (t, so + i*8 , DV, 0 , false );
3644
3643
}
3644
+ }
3645
3645
3646
- threadgroup_barrier (mem_flags::mem_threadgroup);
3646
+ threadgroup_barrier (mem_flags::mem_threadgroup);
3647
3647
3648
- // the first simdgroup accumulates the results from the other simdgroups
3649
- if (sgitg == 0 ) {
3650
- for (short j = 0 ; j < Q; ++j) {
3651
- const float S0 = ss[j*TS + 0 ];
3652
- const float S1 = ss[j*TS + sg*SH + 0 ];
3648
+ // reduce the warps sequentially
3649
+ for (ushort sg = 1 ; sg < nsg; ++sg) {
3650
+ if (sgitg == sg) {
3651
+ for (short j = tiisg; j < Q; j += NW) {
3652
+ const float S0 = ss[j*TS - 1 *SH + 0 ];
3653
+ const float S1 = ss[j*TS + 0 ];
3653
3654
3654
- const float M0 = ss[j*TS + 1 ];
3655
- const float M1 = ss[j*TS + sg*SH + 1 ];
3655
+ const float M0 = ss[j*TS - 1 *SH + 1 ];
3656
+ const float M1 = ss[j*TS + 1 ];
3656
3657
3657
3658
const float M = max (M0, M1);
3658
3659
3659
- const float ms0 = exp (M0 - M);
3660
- const float ms1 = exp (M1 - M);
3660
+ float ms0 = exp (M0 - M);
3661
+ float ms1 = exp (M1 - M);
3661
3662
3662
3663
const float S = S0*ms0 + S1*ms1;
3663
3664
3664
- if (tiisg == 0 ) {
3665
- ss[j*TS + 0 ] = S;
3666
- ss[j*TS + 1 ] = M;
3665
+ ss[j*TS + 0 ] = S;
3666
+ ss[j*TS + 1 ] = M;
3667
3667
3668
- ss[j*TS + 2 *C + j ] = ms0;
3669
- ss[j*TS + 2 *C + j + sg*SH] = ms1;
3670
- }
3668
+ ss[j*TS + 2 *C + j - 1 *SH] = ms0;
3669
+ ss[j*TS + 2 *C + j ] = ms1;
3671
3670
}
3672
3671
3672
+ // simdgroup_barrier(mem_flags::mem_threadgroup);
3673
+
3673
3674
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
3674
3675
{
3675
3676
s8x8_t ms0;
3676
3677
s8x8_t ms1;
3677
3678
3678
- simdgroup_load (ms0, ss + 2 *C, TS, 0 , false );
3679
- simdgroup_load (ms1, ss + 2 *C + sg*SH, TS, 0 , false );
3679
+ simdgroup_load (ms0, ss + 2 *C - 1 *SH, TS, 0 , false );
3680
+ simdgroup_load (ms1, ss + 2 *C, TS, 0 , false );
3680
3681
3681
3682
#pragma unroll(DV8)
3682
3683
for (short i = 0 ; i < DV8; ++i) {
3683
- o8x8_t t;
3684
+ simdgroup_float8x8 t;
3684
3685
3685
3686
simdgroup_load (t, so + i*8 , DV, 0 , false );
3686
- simdgroup_multiply (t, ms1 , t);
3687
+ simdgroup_multiply (t, ms0 , t);
3687
3688
3688
- simdgroup_multiply_accumulate (lo[i], ms0, lo[i], t);
3689
+ simdgroup_multiply_accumulate (t, ms1, lo[i], t);
3690
+ simdgroup_store (t, so + i*8 , DV, 0 , false );
3689
3691
}
3690
3692
}
3691
3693
}
3692
- }
3693
3694
3694
- // store result to shared memory (reuse sq)
3695
- if (sgitg == 0 ) {
3696
- for (short i = 0 ; i < DV8; ++i) {
3697
- simdgroup_store (lo[i], so + i*8 , DV, 0 , false );
3698
- }
3695
+ threadgroup_barrier (mem_flags::mem_threadgroup);
3699
3696
}
3700
3697
3701
- threadgroup_barrier (mem_flags::mem_threadgroup);
3702
-
3703
- threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2 *Q*DK);
3698
+ threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2 *(nsg-1 )*SH + 2 *Q*DK);
3704
3699
3705
3700
// final rescale with 1/S and store to global memory
3706
3701
for (short j = sgitg; j < Q && iq1 + j < args.ne01 ; j += nsg) {
@@ -3723,17 +3718,17 @@ kernel void kernel_flash_attn_ext(
3723
3718
half, half4x4, simdgroup_half8x8, \
3724
3719
float , simdgroup_float8x8, \
3725
3720
float , simdgroup_float8x8, \
3726
- float , float4 , simdgroup_float8x8
3727
- // half , half4 , simdgroup_half8x8
3721
+ half , half4 , simdgroup_half8x8
3722
+ // float , float4 , simdgroup_float8x8
3728
3723
3729
3724
#define FA_TYPES_BF \
3730
3725
bfloat, bfloat4, simdgroup_bfloat8x8, \
3731
3726
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3732
3727
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3733
3728
float , simdgroup_float8x8, \
3734
3729
float , simdgroup_float8x8, \
3735
- float , float4 , simdgroup_float8x8
3736
- // half , half4 , simdgroup_half8x8
3730
+ half , half4 , simdgroup_half8x8
3731
+ // float , float4 , simdgroup_float8x8
3737
3732
3738
3733
typedef decltype (kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 , 64 >) flash_attn_ext_t;
3739
3734
0 commit comments