@@ -2581,7 +2581,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2581
2581
}
2582
2582
2583
2583
// pointer to the mask
2584
- device const half * mp = (device const half *) (mask + iq1*nb31);
2584
+ device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
2585
2585
2586
2586
// prepare diagonal scale matrix
2587
2587
// simdgroup_half8x8 mscale(scale);
@@ -2597,23 +2597,23 @@ kernel void kernel_flash_attn_ext_vec_f16(
2597
2597
2598
2598
// Q*K^T
2599
2599
{
2600
- for (short cc = 0 ; cc < C; ++cc) {
2601
- half mqk[Q];
2600
+ for (short cc = 0 ; cc < C/ 4 ; ++cc) {
2601
+ half4 mqk[Q];
2602
2602
for (short j = 0 ; j < Q; ++j) {
2603
2603
mqk[j] = 0 .0h;
2604
2604
}
2605
2605
2606
- // device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
2607
- device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13));
2606
+ device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4 *cc)*nb11 + ik2*nb12 + ik3*nb13));
2608
2607
2609
2608
for (short i = tiisg; i < D4; i += NW) {
2610
- // simdgroup_half8x8 mk;
2611
- // simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
2612
- half4 mk = pk4[i];
2609
+ half4x4 mk;
2610
+ mk[0 ] = pk4[i + 0 *(nb11/8 )];
2611
+ mk[1 ] = pk4[i + 1 *(nb11/8 )];
2612
+ mk[2 ] = pk4[i + 2 *(nb11/8 )];
2613
+ mk[3 ] = pk4[i + 3 *(nb11/8 )];
2613
2614
2614
2615
for (short j = 0 ; j < Q; ++j) {
2615
- // simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
2616
- mqk[j] += dot (mq[j][i], mk);
2616
+ mqk[j] += mq[j][i] * mk;
2617
2617
}
2618
2618
}
2619
2619
@@ -2633,85 +2633,40 @@ kernel void kernel_flash_attn_ext_vec_f16(
2633
2633
// mqk = mqk*scale + mask
2634
2634
if (tiisg == 0 ) {
2635
2635
for (short j = 0 ; j < Q; ++j) {
2636
- // simdgroup_half8x8 mm;
2637
- // simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
2638
- // simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
2639
-
2640
- // simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
2641
-
2642
- half mm = mp[j*(nb31/sizeof (half)) + ic + cc];
2636
+ half4 mm = mp4[(j*(nb31/sizeof (half)) + ic)/4 + cc];
2643
2637
mqk[j] = mqk[j]*mscale + mm;
2644
2638
2645
- ss [j*T + cc] = mqk[j];
2639
+ ss4 [j*T4 + cc] = mqk[j];
2646
2640
}
2647
2641
}
2648
2642
}
2649
2643
}
2650
2644
2651
- // threadgroup_barrier (mem_flags::mem_threadgroup);
2645
+ simdgroup_barrier (mem_flags::mem_threadgroup);
2652
2646
2653
2647
// online softmax
2654
- if (C == 32 ) {
2655
- half ms[Q];
2656
-
2657
- for (short j = 0 ; j < Q; ++j) {
2658
- const short p = tiisg;
2659
-
2660
- const half m = M[j];
2661
- const half s = ss[j*T + p];
2662
-
2663
- M[j] = simd_max (max (M[j], s));
2664
-
2665
- ms[j] = m == -INFINITY ? 0 .0h : exp (m - M[j]);
2666
- const half vs = s == -INFINITY ? 0 .0h : exp (s - M[j]);
2667
-
2668
- S[j] = S[j]*ms[j] + simd_sum (vs);
2669
-
2670
- // the P matrix from the paper (Q rows, C columns)
2671
- ss[j*T + p] = vs;
2672
- }
2673
-
2674
- // create a QxQ diagonal matrix for rescaling the output
2675
- if (tiisg < Q) {
2676
- ss[tiisg*T + C + tiisg] = ms[tiisg];
2677
- }
2678
- } else {
2679
- half ms[Q];
2648
+ half ms[Q];
2680
2649
2681
- for (short j = 0 ; j < Q; ++j) {
2682
- const half m = M[j];
2683
-
2684
- for (short p = tiisg; p < C; p += NW) {
2685
- const half s = ss[j*T + p];
2686
-
2687
- M[j] = max (M[j], s);
2688
- }
2689
-
2690
- M[j] = simd_max (M[j]);
2691
-
2692
- ms[j] = m == -INFINITY ? 0 .0h : exp (m - M[j]);
2693
-
2694
- // local sum
2695
- half ls = 0 .0h;
2650
+ for (short j = 0 ; j < Q; ++j) {
2651
+ const short p = tiisg;
2696
2652
2697
- for ( short p = tiisg; p < C; p += NW) {
2698
- const half s = ss[j*T + p];
2653
+ const half m = M[j];
2654
+ const half s = ss[j*T + p];
2699
2655
2700
- const half vs = s == -INFINITY ? 0 .0h : exp (s - M[j]);
2656
+ M[j] = simd_max ( max ( M[j], s) );
2701
2657
2702
- ls += vs;
2658
+ ms[j] = m == -INFINITY ? 0 .0h : exp (m - M[j]);
2659
+ const half vs = s == -INFINITY ? 0 .0h : exp (s - M[j]);
2703
2660
2704
- // the P matrix from the paper (Q rows, C columns)
2705
- ss[j*T + p] = vs;
2706
- }
2661
+ S[j] = S[j]*ms[j] + simd_sum (vs);
2707
2662
2708
- S[j] = S[j]*ms[j] + simd_sum (ls);
2709
- }
2663
+ // the P matrix from the paper (Q rows, C columns)
2664
+ ss[j*T + p] = vs;
2665
+ }
2710
2666
2711
- // create a QxQ diagonal matrix for rescaling the output
2712
- if (tiisg < Q) {
2713
- ss[tiisg*T + C + tiisg] = ms[tiisg];
2714
- }
2667
+ // create a QxQ diagonal matrix for rescaling the output
2668
+ if (tiisg < Q) {
2669
+ ss[tiisg*T + C + tiisg] = ms[tiisg];
2715
2670
}
2716
2671
2717
2672
// threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -2733,7 +2688,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
2733
2688
for (short cc = 0 ; cc < C; ++cc) {
2734
2689
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
2735
2690
2736
- half vsum[Q];
2737
2691
for (short i = tiisg; i < D4; i += NW) {
2738
2692
for (short j = 0 ; j < Q; ++j) {
2739
2693
lo[j][i] += pv4[i]*ss[j*T + cc];
0 commit comments