@@ -2459,7 +2459,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f
2459
2459
2460
2460
#define HALF_MAX_HALF half (65504 .0f /2 ) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
2461
2461
2462
- template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
2462
+ template<int64_t D, int64_t C> // head size, cache items per threadgroup
2463
2463
kernel void kernel_flash_attn_ext_vec_f16(
2464
2464
device const char * q,
2465
2465
device const char * k,
@@ -2499,12 +2499,12 @@ kernel void kernel_flash_attn_ext_vec_f16(
2499
2499
2500
2500
const short iq3 = tgpig[2 ];
2501
2501
const short iq2 = tgpig[1 ];
2502
- const short iq1 = tgpig[0 ]*Q ;
2502
+ const short iq1 = tgpig[0 ];
2503
2503
2504
2504
const short D4 = D/4 ;
2505
2505
const short D8 = D/8 ;
2506
2506
const short NW = N_SIMDWIDTH;
2507
- const short SH = (C + Q ); // shared memory per simdgroup in (half)
2507
+ const short SH = (C + 1 ); // shared memory per simdgroup in (half)
2508
2508
2509
2509
const short T = D + nsg*SH; // shared memory size per query in (half)
2510
2510
const short T4 = T/4 ; // shared memory size per query in (half4)
@@ -2513,43 +2513,37 @@ kernel void kernel_flash_attn_ext_vec_f16(
2513
2513
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0 *D); // same as above but in half4
2514
2514
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1 *D); // scratch buffer for attention and diagonal matrix
2515
2515
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1 *D); // same as above but in half4
2516
- threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q *T); // scratch buffer for the results
2516
+ threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1 *T); // scratch buffer for the results
2517
2517
2518
2518
// store the result for the single query in local memory as half4 vectors (the O matrix from the paper)
2519
- half4 lo[Q][ D4/NW];
2519
+ half4 lo[D4/NW];
2520
2520
2521
2521
// load heads from Q to shared memory
2522
- for (short j = sgitg; j < Q; j += nsg) {
2523
- device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
2522
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
2524
2523
2525
- for (short i = tiisg; i < D4; i += NW) {
2526
- if (iq1 + j < ne01) {
2527
- sq4[j*T4 + i] = (half4) q4[i];
2528
- } else {
2529
- sq4[j*T4 + i] = 0 .0h;
2530
- }
2524
+ for (short i = tiisg; i < D4; i += NW) {
2525
+ if (iq1 < ne01) {
2526
+ sq4[i] = (half4) q4[i];
2527
+ } else {
2528
+ sq4[i] = 0 .0h;
2531
2529
}
2532
2530
}
2533
2531
2534
2532
// zero out lo
2535
- for (short j = 0 ; j < Q; ++j) {
2536
- for (short i = tiisg; i < D4; i += NW) {
2537
- lo[j][i/NW] = 0 .0h;
2538
- }
2533
+ for (short i = tiisg; i < D4; i += NW) {
2534
+ lo[i/NW] = 0 .0h;
2539
2535
}
2540
2536
2541
2537
// zero out shared memory SH
2542
- for (short j = 0 ; j < Q; ++j) {
2543
- for (short i = tiisg; i < SH/4 ; i += NW) {
2544
- ss4[j*T4 + i] = 0 .0h;
2545
- }
2538
+ for (short i = tiisg; i < SH/4 ; i += NW) {
2539
+ ss4[i] = 0 .0h;
2546
2540
}
2547
2541
2548
2542
threadgroup_barrier (mem_flags::mem_threadgroup);
2549
2543
2550
2544
{
2551
- half S[Q] = { [ 0 ... Q- 1 ] = 0 .0h };
2552
- half M[Q] = { [ 0 ... Q- 1 ] = -HALF_MAX_HALF };
2545
+ half S = { 0 .0h };
2546
+ half M = { -HALF_MAX_HALF };
2553
2547
2554
2548
// assume K and V are same shape
2555
2549
const short ne22 = ne12;
@@ -2575,21 +2569,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
2575
2569
const short iv3 = iq3 / rv3;
2576
2570
2577
2571
// load the queries from shared memory into local memory
2578
- half4 mq[Q][ D4];
2572
+ half4 mq[D4];
2579
2573
2580
- for (short j = 0 ; j < Q; ++j) {
2581
- for (short ii = 0 ; ii < D4; ii += NW) {
2582
- short i = ii + tiisg;
2583
- mq[j][i] = sq4[j*T4 + i];
2584
- }
2574
+ for (short ii = 0 ; ii < D4; ii += NW) {
2575
+ short i = ii + tiisg;
2576
+ mq[i] = sq4[i];
2585
2577
}
2586
2578
2587
2579
// pointer to the mask
2588
2580
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
2589
2581
2590
- // prepare diagonal scale matrix
2591
- half mscale (scale);
2592
-
2593
2582
// loop over the KV cache
2594
2583
// each simdgroup handles blocks of 1 row (a single query) and C columns
2595
2584
for (int ic0 = 0 ; ic0 < ne11; ic0 += C*nsg) {
@@ -2600,8 +2589,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
2600
2589
2601
2590
// Q*K^T
2602
2591
{
2592
+ #pragma unroll
2603
2593
for (short cc = 0 ; cc < C/4 ; ++cc) {
2604
- half4 mqk[Q] = { [ 0 ... Q- 1 ] = 0 .0h };
2594
+ half4 mqk = { 0 .0h };
2605
2595
2606
2596
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4 *cc)*nb11 + ik2*nb12 + ik3*nb13));
2607
2597
@@ -2615,142 +2605,110 @@ kernel void kernel_flash_attn_ext_vec_f16(
2615
2605
mk[2 ] = pk4[i + 2 *(nb11/8 )];
2616
2606
mk[3 ] = pk4[i + 3 *(nb11/8 )];
2617
2607
2618
- for (short j = 0 ; j < Q; ++j) {
2619
- mqk[j] += mq[j][i] * mk;
2620
- }
2608
+ mqk += mq[i] * mk;
2621
2609
}
2622
2610
2623
2611
// reduce the results from the threads in the simdgroup
2624
- for (short j = 0 ; j < Q; ++j) {
2625
- mqk[j] += simd_shuffle_down (mqk[j], 16 );
2626
- mqk[j] += simd_shuffle_down (mqk[j], 8 );
2627
- mqk[j] += simd_shuffle_down (mqk[j], 4 );
2628
- mqk[j] += simd_shuffle_down (mqk[j], 2 );
2629
- mqk[j] += simd_shuffle_down (mqk[j], 1 );
2630
- }
2612
+ mqk += simd_shuffle_down (mqk, 16 );
2613
+ mqk += simd_shuffle_down (mqk, 8 );
2614
+ mqk += simd_shuffle_down (mqk, 4 );
2615
+ mqk += simd_shuffle_down (mqk, 2 );
2616
+ mqk += simd_shuffle_down (mqk, 1 );
2631
2617
2632
2618
// mqk = mqk*scale + mask
2633
2619
if (tiisg == 0 ) {
2634
- for (short j = 0 ; j < Q; ++j) {
2635
- half4 mm = mp4[(j*(nb31/sizeof (half)) + ic)/4 + cc];
2636
- mqk[j] = mqk[j]*mscale + mm;
2620
+ half4 mm = mp4[ic/4 + cc];
2621
+ mqk = mqk*scale + mm;
2637
2622
2638
- ss4[j*T4 + cc] = mqk[j];
2639
- }
2623
+ ss4[cc] = mqk;
2640
2624
}
2641
2625
}
2642
2626
}
2643
2627
2644
2628
// online softmax
2645
- half ms[Q];
2646
-
2647
- for (short j = 0 ; j < Q; ++j) {
2629
+ {
2648
2630
const short p = tiisg;
2649
2631
2650
- const half m = M[j] ;
2651
- const half s = ss[j*T + p];
2632
+ const half m = M;
2633
+ const half s = ss[p];
2652
2634
2653
- M[j] = simd_max (max (M[j] , s));
2635
+ M = simd_max (max (M, s));
2654
2636
2655
- ms[j] = exp (m - M[j] );
2656
- const half vs = exp (s - M[j] );
2637
+ const half ms = exp (m - M);
2638
+ const half vs = exp (s - M);
2657
2639
2658
- S[j] = S[j] *ms[j] + simd_sum (vs);
2640
+ S = S*ms + simd_sum (vs);
2659
2641
2660
2642
// the P matrix from the paper (1 row, C columns)
2661
- ss[j*T + p] = vs;
2662
- }
2663
-
2664
- // create a QxQ diagonal matrix for rescaling the output
2665
- if (tiisg < Q) {
2666
- ss[tiisg*T + C + tiisg] = ms[tiisg];
2667
- }
2668
-
2669
- // O = diag(ms)*O
2670
- for (short j = 0 ; j < Q; ++j) {
2671
- half mm (ss[j*T + C + j]);
2643
+ ss[p] = vs;
2672
2644
2645
+ // O = diag(ms)*O
2673
2646
#pragma unroll
2674
2647
for (short ii = 0 ; ii < D4; ii += NW) {
2675
2648
const short i = ii + tiisg;
2676
- lo[j][ i/NW] = lo[j][i/NW]*mm ;
2649
+ lo[i/NW] *= ms ;
2677
2650
}
2678
2651
}
2679
2652
2680
2653
// O = O + (Q*K^T)*V
2681
2654
{
2682
2655
#pragma unroll
2683
- for (short cc = 0 ; cc < C; ++cc) {
2684
- device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
2656
+ for (short cc = 0 ; cc < C/ 4 ; ++cc) {
2657
+ device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4 * cc)*nb21 + iv2*nb22 + iv3*nb23));
2685
2658
2686
2659
#pragma unroll
2687
2660
for (short ii = 0 ; ii < D4; ii += NW) {
2688
- short i = ii + tiisg;
2689
- for (short j = 0 ; j < Q; ++j) {
2690
- lo[j][i/NW] += pv4[i]*ss[j*T + cc];
2691
- }
2661
+ const short i = ii + tiisg;
2662
+ lo[i/NW] += pv4[i + 0 *(nb21/8 )] * ss[4 *cc + 0 ];
2663
+ lo[i/NW] += pv4[i + 1 *(nb21/8 )] * ss[4 *cc + 1 ];
2664
+ lo[i/NW] += pv4[i + 2 *(nb21/8 )] * ss[4 *cc + 2 ];
2665
+ lo[i/NW] += pv4[i + 3 *(nb21/8 )] * ss[4 *cc + 3 ];
2692
2666
}
2693
2667
}
2694
2668
}
2669
+
2695
2670
}
2696
2671
2697
2672
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
2698
- for (short j = 0 ; j < Q; ++j) {
2699
- if (tiisg == 0 ) {
2700
- ss[j*T + 0 ] = S[j];
2701
- ss[j*T + 1 ] = M[j];
2702
- }
2673
+ if (tiisg == 0 ) {
2674
+ ss[0 ] = S;
2675
+ ss[1 ] = M;
2703
2676
}
2704
2677
}
2705
2678
2706
2679
// store results to shared memory
2707
- for (short j = 0 ; j < Q; ++j) {
2708
- for (short ii = 0 ; ii < D4; ii += NW) {
2709
- short i = ii + tiisg;
2710
- sr4[i] = lo[j][ii/NW];
2711
- }
2680
+ for (short ii = 0 ; ii < D4; ii += NW) {
2681
+ short i = ii + tiisg;
2682
+ sr4[i] = lo[ii/NW];
2712
2683
}
2713
2684
2714
2685
threadgroup_barrier (mem_flags::mem_threadgroup);
2715
2686
2716
2687
// parallel reduce
2717
2688
for (short r = nsg/2 ; r > 0 ; r >>= 1 ) {
2718
2689
if (sgitg < r) {
2719
- if (tiisg == 0 ) {
2720
- for (short j = 0 ; j < Q; ++j) {
2721
- const half S0 = ss[j*T + 0 ];
2722
- const half S1 = ss[j*T + r*SH + 0 ];
2690
+ const half S0 = ss[ 0 ];
2691
+ const half S1 = ss[r*SH + 0 ];
2723
2692
2724
- const half M0 = ss[j*T + 1 ];
2725
- const half M1 = ss[j*T + r*SH + 1 ];
2693
+ const half M0 = ss[ 1 ];
2694
+ const half M1 = ss[r*SH + 1 ];
2726
2695
2727
- const half M = max (M0, M1);
2696
+ const half M = max (M0, M1);
2728
2697
2729
- const half ms0 = exp (M0 - M);
2730
- const half ms1 = exp (M1 - M);
2698
+ const half ms0 = exp (M0 - M);
2699
+ const half ms1 = exp (M1 - M);
2731
2700
2732
- const half S = S0*ms0 + S1*ms1;
2733
-
2734
- ss[j*T + 0 ] = S;
2735
- ss[j*T + 1 ] = M;
2701
+ const half S = S0*ms0 + S1*ms1;
2736
2702
2737
- ss[j*T + C + j ] = ms0;
2738
- ss[j*T + C + j + r*SH ] = ms1 ;
2739
- }
2703
+ if (tiisg == 0 ) {
2704
+ ss[0 ] = S ;
2705
+ ss[ 1 ] = M;
2740
2706
}
2741
- }
2742
-
2743
- threadgroup_barrier (mem_flags::mem_threadgroup);
2744
2707
2745
- if (sgitg < r) {
2746
- for (short j = 0 ; j < Q; ++j) {
2747
- const half ms0 = ss[j*T + C + j];
2748
- const half ms1 = ss[j*T + C + j + r*SH];
2749
-
2750
- // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2751
- for (short i = tiisg; i < D4; i += NW) {
2752
- sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
2753
- }
2708
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2709
+ for (short ii = 0 ; ii < D4; ii += NW) {
2710
+ short i = ii + tiisg;
2711
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
2754
2712
}
2755
2713
}
2756
2714
@@ -2761,18 +2719,17 @@ kernel void kernel_flash_attn_ext_vec_f16(
2761
2719
2762
2720
// final rescale with 1/S and store to global memory
2763
2721
if (sgitg == 0 ) {
2764
- for (short j = 0 ; j < Q && iq1 + j < ne01; ++j) {
2765
- const half S = ss[j*T + 0 ];
2722
+ const half S = ss[0 ];
2766
2723
2767
- for (short i = tiisg; i < D4; i += NW) {
2768
- dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sr4[i]/S ;
2769
- }
2724
+ for (short ii = 0 ; ii < D4; ii += NW) {
2725
+ short i = ii + tiisg ;
2726
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
2770
2727
}
2771
2728
}
2772
2729
}
2773
2730
2774
- template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128 , 1 , 32 >;
2775
- template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256 , 1 , 32 >;
2731
+ template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128 , 32 >;
2732
+ template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256 , 32 >;
2776
2733
2777
2734
kernel void kernel_cpy_f16_f16 (
2778
2735
device const half * src0,
0 commit comments