@@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext(
3128
3128
const int iq2 = tgpig[1 ];
3129
3129
const int iq1 = tgpig[0 ]*Q;
3130
3130
3131
- const short DK4 = DK/4 ;
3132
- const short DK8 = DK/8 ;
3133
- const short DK16 = DK/16 ;
3134
- const short DV4 = DV/4 ;
3135
- const short DV8 = DV/8 ;
3136
- const short DV16 = DV/16 ;
3137
- const short NW = N_SIMDWIDTH;
3138
- const short SH = (2 *C + Q); // shared memory per simdgroup (s_t == float)
3131
+ constexpr short DK4 = DK/4 ;
3132
+ constexpr short DK8 = DK/8 ;
3133
+ constexpr short DK16 = DK/16 ;
3134
+ constexpr short DV4 = DV/4 ;
3135
+ constexpr short DV8 = DV/8 ;
3136
+ constexpr short DV16 = DV/16 ;
3137
+
3138
+ constexpr short NW = N_SIMDWIDTH;
3139
+ constexpr short SH = (2 *C + Q); // shared memory per simdgroup (s_t == float)
3139
3140
3140
3141
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3141
3142
const short T = DK + 2 *TS; // shared memory size per query in (half)
@@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec(
3641
3642
const int iq2 = tgpig[1 ];
3642
3643
const int iq1 = tgpig[0 ];
3643
3644
3644
- const short DK4 = DK/4 ;
3645
- const short DV4 = DV/4 ;
3646
- const short NW = N_SIMDWIDTH;
3647
- const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3648
- const short SH = 2 *C; // shared memory per simdgroup
3645
+ constexpr short DK4 = DK/4 ;
3646
+ constexpr short DV4 = DV/4 ;
3647
+ constexpr short NW = N_SIMDWIDTH;
3648
+ constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3649
+ constexpr short SH = 2 *C; // shared memory per simdgroup
3649
3650
3650
3651
const short T = DK + nsg*SH; // shared memory size per query in (half)
3651
3652
@@ -3956,7 +3957,7 @@ kernel void kernel_flash_attn_ext_vec(
3956
3957
half, half4, \
3957
3958
half4
3958
3959
3959
- typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 128 >) flash_attn_ext_vec_t;
3960
+ typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >) flash_attn_ext_vec_t;
3960
3961
3961
3962
template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >;
3962
3963
#if defined(GGML_METAL_USE_BF16)
0 commit comments