184
184
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
185
185
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
186
186
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
187
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
187
+ // GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
188
188
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
189
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
189
+ // GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
190
190
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
191
191
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
192
192
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -634,9 +634,9 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
634
634
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm );
635
635
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm );
636
636
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm );
637
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm );
637
+ // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
638
638
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction );
639
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction );
639
+ // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
640
640
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
641
641
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
642
642
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
@@ -770,6 +770,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
770
770
case GGML_OP_LEAKY_RELU:
771
771
return true ;
772
772
case GGML_OP_FLASH_ATTN_EXT:
773
+ if (op->src [0 ]->ne [0 ] == 256 ) {
774
+ return false ;
775
+ }
773
776
return ctx->support_simdgroup_mm ; // TODO: over-restricted for vec-kernels
774
777
case GGML_OP_MUL_MAT:
775
778
case GGML_OP_MUL_MAT_ID:
@@ -2573,7 +2576,7 @@ static enum ggml_status ggml_metal_graph_compute(
2573
2576
case 96 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline ; break ;
2574
2577
case 112 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline ; break ;
2575
2578
case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline ; break ;
2576
- case 256 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline ; break ;
2579
+ // case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2577
2580
default :
2578
2581
{
2579
2582
GGML_METAL_LOG_ERROR (" unsupported size: %lld \n " , ne00);
@@ -2586,7 +2589,7 @@ static enum ggml_status ggml_metal_graph_compute(
2586
2589
2587
2590
switch (ne00) {
2588
2591
case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline ; break ;
2589
- case 256 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline ; break ;
2592
+ // case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2590
2593
default :
2591
2594
{
2592
2595
GGML_METAL_LOG_ERROR (" unsupported size: %lld \n " , ne00);
0 commit comments