@@ -2575,21 +2575,20 @@ kernel void kernel_flash_attn_ext_vec_f16(
2575
2575
const short iv3 = iq3 / rv3;
2576
2576
2577
2577
// load the queries from shared memory into local memory
2578
- simdgroup_half8x8 mq[Q][D8 ];
2578
+ half4 mq[Q][D4 ];
2579
2579
2580
2580
for (short j = 0 ; j < Q; ++j) {
2581
- for (short i = 0 ; i < D8; ++i) {
2582
- simdgroup_load (mq[j][i], sq + 8 *j*T + i*8 , T);
2581
+ for (short ii = 0 ; ii < D4; ii += NW) {
2582
+ short i = ii + tiisg;
2583
+ mq[j][i] = sq4[j*T4 + i];
2583
2584
}
2584
2585
}
2585
2586
2586
2587
// pointer to the mask
2587
- // device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
2588
- device const half * mp = (device const half *) (mask + iq1*nb31);
2588
+ device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
2589
2589
2590
2590
// prepare diagonal scale matrix
2591
- simdgroup_half8x8 mscale (scale);
2592
- // half mscale(scale);
2591
+ half mscale (scale);
2593
2592
2594
2593
// loop over the KV cache
2595
2594
// each simdgroup handles blocks of Q rows and C columns
@@ -2599,79 +2598,45 @@ kernel void kernel_flash_attn_ext_vec_f16(
2599
2598
break ;
2600
2599
}
2601
2600
2602
- // Q*K^T
2603
- // {
2604
- // for (short cc = 0; cc < C/4; ++cc) {
2605
- // half4 mqk[Q];
2606
- // for (short j = 0; j < Q; ++j) {
2607
- // mqk[j] = 0.0h;
2608
- // }
2609
-
2610
- // device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
2611
-
2612
- // for (short i = tiisg; i < D4; i += NW) {
2613
- // half4x4 mk;
2614
- // mk[0] = pk4[i + 0*(nb11/8)];
2615
- // mk[1] = pk4[i + 1*(nb11/8)];
2616
- // mk[2] = pk4[i + 2*(nb11/8)];
2617
- // mk[3] = pk4[i + 3*(nb11/8)];
2618
-
2619
- // for (short j = 0; j < Q; ++j) {
2620
- // mqk[j] += mq[j][i] * mk;
2621
- // }
2622
- // }
2623
-
2624
- // // reduce the results from the threads in the simdgroup
2625
- // simdgroup_barrier(mem_flags::mem_none);
2626
-
2627
- // for (short i = NW/2; i > 0; i /= 2) {
2628
- // if (tiisg < i) {
2629
- // for (short j = 0; j < Q; ++j) {
2630
- // mqk[j] += simd_shuffle_down(mqk[j], i);
2631
- // }
2632
- // }
2633
-
2634
- // simdgroup_barrier(mem_flags::mem_none);
2635
- // }
2636
-
2637
- // // mqk = mqk*scale + mask
2638
- // if (tiisg == 0) {
2639
- // for (short j = 0; j < Q; ++j) {
2640
- // half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
2641
- // mqk[j] = mqk[j]*mscale + mm;
2642
-
2643
- // ss4[j*T4 + cc] = mqk[j];
2644
- // }
2645
- // }
2646
- // }
2647
- // }
2648
-
2649
2601
// Q*K^T
2650
2602
{
2651
- for (short cc = 0 ; cc < C/8 ; ++cc) {
2652
- simdgroup_half8x8 mqk[Q];
2653
- for (short j = 0 ; j < Q; ++j) {
2654
- mqk[j] = make_filled_simdgroup_matrix<half, 8 >(0 .h );
2655
- }
2603
+ for (short cc = 0 ; cc < C/4 ; ++cc) {
2604
+ half4 mqk[Q] = { [0 ... Q-1 ] = 0 .0h };
2656
2605
2657
- device const half * pk = (device const half *) ((device const char *) k + ((ic + 8 *cc)*nb11 + ik2*nb12 + ik3*nb13));
2606
+ device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4 *cc)*nb11 + ik2*nb12 + ik3*nb13));
2658
2607
2659
- for (short i = 0 ; i < D8; ++i) {
2660
- simdgroup_half8x8 mk;
2661
- simdgroup_load (mk, pk + i*8 , nb11/sizeof (half), 0 , true ); // transpose
2608
+ #pragma unroll
2609
+ for (short ii = 0 ; ii < D4; ii += NW) {
2610
+ const short i = ii + tiisg;
2611
+
2612
+ half4x4 mk;
2613
+ mk[0 ] = pk4[i + 0 *(nb11/8 )];
2614
+ mk[1 ] = pk4[i + 1 *(nb11/8 )];
2615
+ mk[2 ] = pk4[i + 2 *(nb11/8 )];
2616
+ mk[3 ] = pk4[i + 3 *(nb11/8 )];
2662
2617
2663
2618
for (short j = 0 ; j < Q; ++j) {
2664
- simdgroup_multiply_accumulate ( mqk[j], mq[j][i], mk, mqk[j]) ;
2619
+ mqk[j] += mq[j][i] * mk ;
2665
2620
}
2666
2621
}
2667
2622
2668
- // mqk = mqk*scale + mask
2623
+ // reduce the results from the threads in the simdgroup
2669
2624
for (short j = 0 ; j < Q; ++j) {
2670
- simdgroup_half8x8 mm;
2671
- simdgroup_load (mm, mp + 8 *j*(nb31/sizeof (half)) + ic + 8 *cc, nb31/sizeof (half), 0 , false );
2672
- simdgroup_multiply_accumulate (mqk[j], mqk[j], mscale, mm);
2625
+ mqk[j] += simd_shuffle_down (mqk[j], 16 );
2626
+ mqk[j] += simd_shuffle_down (mqk[j], 8 );
2627
+ mqk[j] += simd_shuffle_down (mqk[j], 4 );
2628
+ mqk[j] += simd_shuffle_down (mqk[j], 2 );
2629
+ mqk[j] += simd_shuffle_down (mqk[j], 1 );
2630
+ }
2673
2631
2674
- simdgroup_store (mqk[j], ss + 8 *j*T + 8 *cc, T, 0 , false );
2632
+ // mqk = mqk*scale + mask
2633
+ if (tiisg == 0 ) {
2634
+ for (short j = 0 ; j < Q; ++j) {
2635
+ half4 mm = mp4[(j*(nb31/sizeof (half)) + ic)/4 + cc];
2636
+ mqk[j] = mqk[j]*mscale + mm;
2637
+
2638
+ ss4[j*T4 + cc] = mqk[j];
2639
+ }
2675
2640
}
2676
2641
}
2677
2642
}
@@ -2701,26 +2666,26 @@ kernel void kernel_flash_attn_ext_vec_f16(
2701
2666
ss[tiisg*T + C + tiisg] = ms[tiisg];
2702
2667
}
2703
2668
2704
- // threadgroup_barrier(mem_flags::mem_threadgroup);
2705
-
2706
2669
// O = diag(ms)*O
2707
2670
for (short j = 0 ; j < Q; ++j) {
2708
- // simdgroup_half8x8 mm;
2709
- // simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
2710
2671
half mm (ss[j*T + C + j]);
2711
2672
2712
- for (short i = tiisg; i < D4; i += NW) {
2713
- // simdgroup_multiply(lo[j][i], mm, lo[j][i]);
2673
+ #pragma unroll
2674
+ for (short ii = 0 ; ii < D4; ii += NW) {
2675
+ const short i = ii + tiisg;
2714
2676
lo[j][i/NW] = lo[j][i/NW]*mm;
2715
2677
}
2716
2678
}
2717
2679
2718
2680
// O = O + (Q*K^T)*V
2719
2681
{
2682
+ #pragma unroll
2720
2683
for (short cc = 0 ; cc < C; ++cc) {
2721
2684
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
2722
2685
2723
- for (short i = tiisg; i < D4; i += NW) {
2686
+ #pragma unroll
2687
+ for (short ii = 0 ; ii < D4; ii += NW) {
2688
+ short i = ii + tiisg;
2724
2689
for (short j = 0 ; j < Q; ++j) {
2725
2690
lo[j][i/NW] += pv4[i]*ss[j*T + cc];
2726
2691
}
@@ -2738,15 +2703,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
2738
2703
}
2739
2704
}
2740
2705
2741
- threadgroup_barrier (mem_flags::mem_threadgroup);
2742
-
2743
2706
// store results to shared memory
2744
2707
for (short j = 0 ; j < Q; ++j) {
2745
- for (short i = tiisg; i < D4; i += NW) {
2746
- sr4[i] = lo[j][i/NW];
2708
+ for (short ii = 0 ; ii < D4; ii += NW) {
2709
+ short i = ii + tiisg;
2710
+ sr4[i] = lo[j][ii/NW];
2747
2711
}
2748
2712
}
2749
2713
2714
+ threadgroup_barrier (mem_flags::mem_threadgroup);
2715
+
2750
2716
// parallel reduce
2751
2717
for (short r = nsg/2 ; r > 0 ; r >>= 1 ) {
2752
2718
if (sgitg < r) {
@@ -2805,10 +2771,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
2805
2771
}
2806
2772
}
2807
2773
2808
- template [[host_name(" kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128 , 2 , 32 >;
2809
- template [[host_name(" kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128 , 3 , 32 >;
2810
- template [[host_name(" kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128 , 4 , 32 >;
2811
- template [[host_name(" kernel_flash_attn_ext_vec_f16_h112" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128 , 5 , 32 >;
2812
2774
template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128 , 1 , 32 >;
2813
2775
template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256 , 1 , 32 >;
2814
2776
0 commit comments