@@ -2723,45 +2723,9 @@ kernel void kernel_leaky_relu_f32(
2723
2723
dst[tpig] = src0[tpig] > 0 .0f ? src0[tpig] : src0[tpig] * slope;
2724
2724
}
2725
2725
2726
- typedef void (flash_attn_ext_t )(
2727
- device const char * q,
2728
- device const char * k,
2729
- device const char * v,
2730
- device const char * mask,
2731
- device float * dst,
2732
- constant int64_t & ne01,
2733
- constant int64_t & ne02,
2734
- constant int64_t & ne03,
2735
- constant uint64_t & nb01,
2736
- constant uint64_t & nb02,
2737
- constant uint64_t & nb03,
2738
- constant int64_t & ne11,
2739
- constant int64_t & ne12,
2740
- constant int64_t & ne13,
2741
- constant uint64_t & nb11,
2742
- constant uint64_t & nb12,
2743
- constant uint64_t & nb13,
2744
- constant uint64_t & nb21,
2745
- constant uint64_t & nb22,
2746
- constant uint64_t & nb23,
2747
- constant uint64_t & nb31,
2748
- constant int64_t & ne1,
2749
- constant int64_t & ne2,
2750
- constant float & scale,
2751
- constant float & max_bias,
2752
- constant float & m0,
2753
- constant float & m1,
2754
- constant uint32_t & n_head_log2,
2755
- constant float & logit_softcap,
2756
- threadgroup half * shared,
2757
- uint3 tgpig[[threadgroup_position_in_grid]],
2758
- uint3 tpitg[[thread_position_in_threadgroup]],
2759
- uint3 ntg[[threads_per_threadgroup]],
2760
- ushort tiisg[[thread_index_in_simdgroup]],
2761
- ushort sgitg[[simdgroup_index_in_threadgroup]]);
2762
-
2763
2726
// ref: https://arxiv.org/pdf/2307.08691.pdf
2764
- template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread half4x4 &), short D, short Q = 8 , short K = 8 , short C = 32 > // head size, queries per threadgroup, cache items per threadgroup
2727
+ // D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
2728
+ template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread half4x4 &), short D, short Q = 8 , short KV = 8 , short C = 32 >
2765
2729
kernel void kernel_flash_attn_ext (
2766
2730
device const char * q,
2767
2731
device const char * k,
@@ -2818,8 +2782,8 @@ kernel void kernel_flash_attn_ext(
2818
2782
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0 *D); // same as above but in half4
2819
2783
threadgroup float * ss = (threadgroup float *) (shared + 2 *sgitg*SH + 1 *D); // scratch buffer for attention and diagonal matrix
2820
2784
2821
- threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4 *16 *K ) + Q*T); // scratch buffer to load K and V in shared memory
2822
- threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4 *16 *K ) + Q*T); // same as above but in half4x4
2785
+ threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4 *16 *KV ) + Q*T); // scratch buffer to load K and V in shared memory
2786
+ threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4 *16 *KV ) + Q*T); // same as above but in half4x4
2823
2787
2824
2788
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
2825
2789
simdgroup_half8x8 lo[D8];
@@ -3179,6 +3143,8 @@ kernel void kernel_flash_attn_ext(
3179
3143
}
3180
3144
}
3181
3145
3146
+ typedef decltype (kernel_flash_attn_ext<half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_t;
3147
+
3182
3148
template [[host_name(" kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1 , dequantize_f16, 64 >;
3183
3149
template [[host_name(" kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1 , dequantize_f16, 80 >;
3184
3150
template [[host_name(" kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1 , dequantize_f16, 96 >;
@@ -3223,7 +3189,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_
3223
3189
3224
3190
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
3225
3191
template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread float4x4 &), short D, short Q = 1 , short C = 32 >
3226
- kernel void flash_attn_ext_vec (
3192
+ kernel void kernel_flash_attn_ext_vec (
3227
3193
device const char * q,
3228
3194
device const char * k,
3229
3195
device const char * v,
@@ -3548,22 +3514,21 @@ kernel void flash_attn_ext_vec(
3548
3514
}
3549
3515
}
3550
3516
3551
- // template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<128>;
3552
- // template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<256>;
3517
+ typedef decltype (kernel_flash_attn_ext_vec<half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_vec_t;
3553
3518
3554
- template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <half4x4, 1 , dequantize_f16, 128 >;
3555
- template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q4_0, 2 , dequantize_q4_0, 128 >;
3556
- template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q4_1, 2 , dequantize_q4_1, 128 >;
3557
- template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q5_0, 2 , dequantize_q5_0, 128 >;
3558
- template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q5_1, 2 , dequantize_q5_1, 128 >;
3559
- template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h128" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q8_0, 2 , dequantize_q8_0, 128 >;
3519
+ template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <half4x4, 1 , dequantize_f16, 128 >;
3520
+ template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q4_0, 2 , dequantize_q4_0, 128 >;
3521
+ template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q4_1, 2 , dequantize_q4_1, 128 >;
3522
+ template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q5_0, 2 , dequantize_q5_0, 128 >;
3523
+ template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q5_1, 2 , dequantize_q5_1, 128 >;
3524
+ template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q8_0, 2 , dequantize_q8_0, 128 >;
3560
3525
3561
- template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <half4x4, 1 , dequantize_f16, 256 >;
3562
- template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q4_0, 2 , dequantize_q4_0, 256 >;
3563
- template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q4_1, 2 , dequantize_q4_1, 256 >;
3564
- template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q5_0, 2 , dequantize_q5_0, 256 >;
3565
- template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q5_1, 2 , dequantize_q5_1, 256 >;
3566
- template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h256" )]] kernel flash_attn_ext_t flash_attn_ext_vec <block_q8_0, 2 , dequantize_q8_0, 256 >;
3526
+ template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <half4x4, 1 , dequantize_f16, 256 >;
3527
+ template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q4_0, 2 , dequantize_q4_0, 256 >;
3528
+ template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q4_1, 2 , dequantize_q4_1, 256 >;
3529
+ template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q5_0, 2 , dequantize_q5_0, 256 >;
3530
+ template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q5_1, 2 , dequantize_q5_1, 256 >;
3531
+ template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec <block_q8_0, 2 , dequantize_q8_0, 256 >;
3567
3532
3568
3533
template <typename T0, typename T1>
3569
3534
kernel void kernel_cpy (
0 commit comments