179
179
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
180
180
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
181
181
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
182
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
183
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80,
184
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
185
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112,
186
182
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
187
183
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
188
184
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
@@ -625,10 +621,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
625
621
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true );
626
622
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true );
627
623
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true );
628
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, true );
629
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, true );
630
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, true );
631
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, true );
632
624
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true );
633
625
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true );
634
626
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
@@ -2521,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute(
2521
2513
2522
2514
id <MTLComputePipelineState > pipeline = nil ;
2523
2515
2524
- if (ne01 > 1 ) {
2516
+ if (ne01 > 1 || (ne00% 128 != 0 ) ) {
2525
2517
switch (ne00) {
2526
2518
case 64 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline ; break ;
2527
2519
case 80 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline ; break ;
@@ -2538,10 +2530,6 @@ static enum ggml_status ggml_metal_graph_compute(
2538
2530
}
2539
2531
} else {
2540
2532
switch (ne00) {
2541
- case 64 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64 ].pipeline ; break ;
2542
- case 80 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80 ].pipeline ; break ;
2543
- case 96 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96 ].pipeline ; break ;
2544
- case 112 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline ; break ;
2545
2533
case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline ; break ;
2546
2534
case 256 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline ; break ;
2547
2535
default :
0 commit comments