@@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec(
3356
3356
const short D4 = D/4 ;
3357
3357
const short D16 = D/16 ;
3358
3358
const short NW = N_SIMDWIDTH;
3359
- const short NL = NW/4 ;
3360
- const short SH = 2 *C; // shared memory per simdgroup
3359
+ const short NL = NW/4 ; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
3360
+ const short SH = 2 *C; // shared memory per simdgroup
3361
3361
3362
3362
const short T = D + nsg*SH; // shared memory size per query in (half)
3363
3363
@@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(
3448
3448
3449
3449
// Q*K^T
3450
3450
{
3451
- // each simdgroup processes 1 query and 4 keys
3451
+ // each simdgroup processes 1 query and 4 (NW/NL) keys
3452
3452
for (short cc = 0 ; cc < C/4 ; ++cc) {
3453
3453
qk_t mqka[4 ] = { 0.0 , 0.0 , 0.0 , 0.0 };
3454
3454
@@ -3646,7 +3646,7 @@ kernel void kernel_flash_attn_ext_vec(
3646
3646
half, half4, half4x4, \
3647
3647
half4x4
3648
3648
3649
- typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_vec_t;
3649
+ typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 128 >) flash_attn_ext_vec_t;
3650
3650
3651
3651
template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 128 >;
3652
3652
#if defined(GGML_METAL_USE_BF16)
0 commit comments