@@ -2457,6 +2457,8 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f
2457
2457
template [[host_name(" kernel_flash_attn_ext_f16_h128" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128 , 8 , 32 >;
2458
2458
template [[host_name(" kernel_flash_attn_ext_f16_h256" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256 , 8 , 32 >;
2459
2459
2460
+ #define HALF_MAX_HALF half (65504 .0f /2 ) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
2461
+
2460
2462
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
2461
2463
kernel void kernel_flash_attn_ext_vec_f16(
2462
2464
device const char * q,
@@ -2500,6 +2502,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2500
2502
const short iq1 = tgpig[0 ]*Q;
2501
2503
2502
2504
const short D4 = D/4 ;
2505
+ const short D8 = D/8 ;
2503
2506
const short NW = N_SIMDWIDTH;
2504
2507
const short SH = (C + Q); // shared memory per simdgroup in (half)
2505
2508
@@ -2510,6 +2513,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2510
2513
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0 *D); // same as above but in half4
2511
2514
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1 *D); // scratch buffer for attention and diagonal matrix
2512
2515
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1 *D); // same as above but in half4
2516
+ threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
2513
2517
2514
2518
// store the result for all queries in local memory as half4 vectors (the O matrix from the paper)
2515
2519
half4 lo[Q][D4];
@@ -2545,7 +2549,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2545
2549
2546
2550
{
2547
2551
half S[Q] = { [0 ... Q-1 ] = 0 .0h };
2548
- half M[Q] = { [0 ... Q-1 ] = -INFINITY };
2552
+ half M[Q] = { [0 ... Q-1 ] = -HALF_MAX_HALF };
2549
2553
2550
2554
// assume K and V are same shape
2551
2555
const short ne22 = ne12;
@@ -2571,21 +2575,21 @@ kernel void kernel_flash_attn_ext_vec_f16(
2571
2575
const short iv3 = iq3 / rv3;
2572
2576
2573
2577
// load the queries from shared memory into local memory
2574
- half4 mq[Q][D4 ];
2578
+ simdgroup_half8x8 mq[Q][D8 ];
2575
2579
2576
2580
for (short j = 0 ; j < Q; ++j) {
2577
- for (short i = tiisg; i < D4; i += NW) {
2578
- // simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
2579
- mq[j][i] = sq4[j*T4 + i];
2581
+ for (short i = 0 ; i < D8; ++i) {
2582
+ simdgroup_load (mq[j][i], sq + 8 *j*T + i*8 , T);
2580
2583
}
2581
2584
}
2582
2585
2583
2586
// pointer to the mask
2584
- device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
2587
+ // device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
2588
+ device const half * mp = (device const half *) (mask + iq1*nb31);
2585
2589
2586
2590
// prepare diagonal scale matrix
2587
- // simdgroup_half8x8 mscale(scale);
2588
- half mscale (scale);
2591
+ simdgroup_half8x8 mscale (scale);
2592
+ // half mscale(scale);
2589
2593
2590
2594
// loop over the KV cache
2591
2595
// each simdgroup handles blocks of Q rows and C columns
@@ -2595,55 +2599,83 @@ kernel void kernel_flash_attn_ext_vec_f16(
2595
2599
break ;
2596
2600
}
2597
2601
2602
+ // Q*K^T
2603
+ // {
2604
+ // for (short cc = 0; cc < C/4; ++cc) {
2605
+ // half4 mqk[Q];
2606
+ // for (short j = 0; j < Q; ++j) {
2607
+ // mqk[j] = 0.0h;
2608
+ // }
2609
+
2610
+ // device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
2611
+
2612
+ // for (short i = tiisg; i < D4; i += NW) {
2613
+ // half4x4 mk;
2614
+ // mk[0] = pk4[i + 0*(nb11/8)];
2615
+ // mk[1] = pk4[i + 1*(nb11/8)];
2616
+ // mk[2] = pk4[i + 2*(nb11/8)];
2617
+ // mk[3] = pk4[i + 3*(nb11/8)];
2618
+
2619
+ // for (short j = 0; j < Q; ++j) {
2620
+ // mqk[j] += mq[j][i] * mk;
2621
+ // }
2622
+ // }
2623
+
2624
+ // // reduce the results from the threads in the simdgroup
2625
+ // simdgroup_barrier(mem_flags::mem_none);
2626
+
2627
+ // for (short i = NW/2; i > 0; i /= 2) {
2628
+ // if (tiisg < i) {
2629
+ // for (short j = 0; j < Q; ++j) {
2630
+ // mqk[j] += simd_shuffle_down(mqk[j], i);
2631
+ // }
2632
+ // }
2633
+
2634
+ // simdgroup_barrier(mem_flags::mem_none);
2635
+ // }
2636
+
2637
+ // // mqk = mqk*scale + mask
2638
+ // if (tiisg == 0) {
2639
+ // for (short j = 0; j < Q; ++j) {
2640
+ // half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
2641
+ // mqk[j] = mqk[j]*mscale + mm;
2642
+
2643
+ // ss4[j*T4 + cc] = mqk[j];
2644
+ // }
2645
+ // }
2646
+ // }
2647
+ // }
2648
+
2598
2649
// Q*K^T
2599
2650
{
2600
- for (short cc = 0 ; cc < C/4 ; ++cc) {
2601
- half4 mqk[Q];
2651
+ for (short cc = 0 ; cc < C/8 ; ++cc) {
2652
+ simdgroup_half8x8 mqk[Q];
2602
2653
for (short j = 0 ; j < Q; ++j) {
2603
- mqk[j] = 0 .0h ;
2654
+ mqk[j] = make_filled_simdgroup_matrix<half, 8 >( 0 . h ) ;
2604
2655
}
2605
2656
2606
- device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4 *cc)*nb11 + ik2*nb12 + ik3*nb13));
2657
+ device const half * pk = (device const half *) ((device const char *) k + ((ic + 8 *cc)*nb11 + ik2*nb12 + ik3*nb13));
2607
2658
2608
- for (short i = tiisg; i < D4; i += NW) {
2609
- half4x4 mk;
2610
- mk[0 ] = pk4[i + 0 *(nb11/8 )];
2611
- mk[1 ] = pk4[i + 1 *(nb11/8 )];
2612
- mk[2 ] = pk4[i + 2 *(nb11/8 )];
2613
- mk[3 ] = pk4[i + 3 *(nb11/8 )];
2659
+ for (short i = 0 ; i < D8; ++i) {
2660
+ simdgroup_half8x8 mk;
2661
+ simdgroup_load (mk, pk + i*8 , nb11/sizeof (half), 0 , true ); // transpose
2614
2662
2615
2663
for (short j = 0 ; j < Q; ++j) {
2616
- mqk[j] += mq[j][i] * mk;
2617
- }
2618
- }
2619
-
2620
- // reduce the results from the threads in the simdgroup
2621
- simdgroup_barrier (mem_flags::mem_none);
2622
-
2623
- for (short i = NW/2 ; i > 0 ; i /= 2 ) {
2624
- if (tiisg < i) {
2625
- for (short j = 0 ; j < Q; ++j) {
2626
- mqk[j] += simd_shuffle_down (mqk[j], i);
2627
- }
2664
+ simdgroup_multiply_accumulate (mqk[j], mq[j][i], mk, mqk[j]);
2628
2665
}
2629
-
2630
- simdgroup_barrier (mem_flags::mem_none);
2631
2666
}
2632
2667
2633
2668
// mqk = mqk*scale + mask
2634
- if (tiisg == 0 ) {
2635
- for ( short j = 0 ; j < Q; ++j) {
2636
- half4 mm = mp4[( j*(nb31/sizeof (half)) + ic)/ 4 + cc] ;
2637
- mqk[j] = mqk[j]* mscale + mm ;
2669
+ for ( short j = 0 ; j < Q; ++j ) {
2670
+ simdgroup_half8x8 mm;
2671
+ simdgroup_load (mm, mp + 8 * j*(nb31/sizeof (half)) + ic + 8 *cc, nb31/ sizeof (half), 0 , false ) ;
2672
+ simdgroup_multiply_accumulate ( mqk[j], mqk[j], mscale, mm) ;
2638
2673
2639
- ss4[j*T4 + cc] = mqk[j];
2640
- }
2674
+ simdgroup_store (mqk[j], ss + 8 *j*T + 8 *cc, T, 0 , false );
2641
2675
}
2642
2676
}
2643
2677
}
2644
2678
2645
- simdgroup_barrier (mem_flags::mem_threadgroup);
2646
-
2647
2679
// online softmax
2648
2680
half ms[Q];
2649
2681
@@ -2655,8 +2687,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
2655
2687
2656
2688
M[j] = simd_max (max (M[j], s));
2657
2689
2658
- ms[j] = m == -INFINITY ? 0 .0h : exp (m - M[j]);
2659
- const half vs = s == -INFINITY ? 0 .0h : exp (s - M[j]);
2690
+ ms[j] = exp (m - M[j]);
2691
+ const half vs = exp (s - M[j]);
2660
2692
2661
2693
S[j] = S[j]*ms[j] + simd_sum (vs);
2662
2694
@@ -2706,75 +2738,59 @@ kernel void kernel_flash_attn_ext_vec_f16(
2706
2738
}
2707
2739
}
2708
2740
2709
- // reduce the warps sequentially
2710
- for (short sg = 1 ; sg < nsg; ++sg) {
2711
- half S = { 0 .0h };
2712
- half M = { -INFINITY };
2713
-
2714
- threadgroup_barrier (mem_flags::mem_threadgroup);
2741
+ threadgroup_barrier (mem_flags::mem_threadgroup);
2715
2742
2716
- // each simdgroup stores its output to shared memory, reusing sq
2717
- if (sgitg == sg) {
2718
- for (short j = 0 ; j < Q; ++j) {
2719
- for (short i = tiisg; i < D4; i += NW) {
2720
- // simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
2721
- sq4[j*T4 + i] = lo[j][i];
2722
- }
2723
- }
2743
+ // store results to shared memory
2744
+ for (short j = 0 ; j < Q; ++j) {
2745
+ for (short i = tiisg; i < D4; i += NW) {
2746
+ sr4[i] = lo[j][i];
2724
2747
}
2748
+ }
2725
2749
2726
- threadgroup_barrier (mem_flags::mem_threadgroup);
2727
-
2728
- // the first simdgroup accumulates the results from the other simdgroups
2729
- if (sgitg == 0 ) {
2730
- for (short j = 0 ; j < Q; ++j) {
2731
- const half S0 = ss[j*T + 0 ];
2732
- const half S1 = ss[j*T + sg *SH + 0 ];
2750
+ // parallel reduce
2751
+ for ( short r = nsg/ 2 ; r > 0 ; r >>= 1 ) {
2752
+ if (sgitg < r) {
2753
+ if (tiisg == 0 ) {
2754
+ for (short j = 0 ; j < Q; ++j) {
2755
+ const half S0 = ss[j*T + 0 ];
2756
+ const half S1 = ss[j*T + r *SH + 0 ];
2733
2757
2734
- const half M0 = ss[j*T + 1 ];
2735
- const half M1 = ss[j*T + sg *SH + 1 ];
2758
+ const half M0 = ss[j*T + 1 ];
2759
+ const half M1 = ss[j*T + r *SH + 1 ];
2736
2760
2737
- M = max (M0, M1);
2761
+ const half M = max (M0, M1);
2738
2762
2739
- const half ms0 = M0 == -INFINITY ? 0 .0h : exp (M0 - M);
2740
- const half ms1 = M1 == -INFINITY ? 0 .0h : exp (M1 - M);
2763
+ const half ms0 = exp (M0 - M);
2764
+ const half ms1 = exp (M1 - M);
2741
2765
2742
- S = S0*ms0 + S1*ms1;
2766
+ const half S = S0*ms0 + S1*ms1;
2743
2767
2744
- if (tiisg == 0 ) {
2745
2768
ss[j*T + 0 ] = S;
2746
2769
ss[j*T + 1 ] = M;
2747
2770
2748
- ss[j*T + C + j ] = ms0;
2749
- ss[j*T + C + j + sg *SH] = ms1;
2771
+ ss[j*T + C + j ] = ms0;
2772
+ ss[j*T + C + j + r *SH] = ms1;
2750
2773
}
2751
2774
}
2775
+ }
2752
2776
2753
- // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2777
+ threadgroup_barrier (mem_flags::mem_threadgroup);
2778
+
2779
+ if (sgitg < r) {
2754
2780
for (short j = 0 ; j < Q; ++j) {
2755
- for (short i = tiisg; i < D4; i += NW) {
2756
- half4 t = sq4[j*T4 + i];
2757
- half ms0 = ss[j*T + C + j];
2758
- half ms1 = ss[j*T + C + j + sg*SH];
2781
+ const half ms0 = ss[j*T + C + j];
2782
+ const half ms1 = ss[j*T + C + j + r*SH];
2759
2783
2760
- lo[j][i] = lo[j][i]*ms0 + t*ms1;
2784
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2785
+ for (short i = tiisg; i < D4; i += NW) {
2786
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
2761
2787
}
2762
2788
}
2763
2789
}
2764
- }
2765
2790
2766
- // store result to shared memory (reuse sq)
2767
- if (sgitg == 0 ) {
2768
- for (short j = 0 ; j < Q; ++j) {
2769
- for (short i = tiisg; i < D4; i += NW) {
2770
- // simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
2771
- sq4[j*T4 + i] = lo[j][i];
2772
- }
2773
- }
2791
+ threadgroup_barrier (mem_flags::mem_threadgroup);
2774
2792
}
2775
2793
2776
- threadgroup_barrier (mem_flags::mem_threadgroup);
2777
-
2778
2794
device float4 * dst4 = (device float4 *) dst;
2779
2795
2780
2796
// final rescale with 1/S and store to global memory
@@ -2783,7 +2799,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2783
2799
const half S = ss[j*T + 0 ];
2784
2800
2785
2801
for (short i = tiisg; i < D4; i += NW) {
2786
- dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
2802
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sr4[ i]/S;
2787
2803
}
2788
2804
}
2789
2805
}
0 commit comments