@@ -351,42 +351,49 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
351
351
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
352
352
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
353
353
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
354
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
354
355
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
355
356
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
356
357
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
357
358
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
358
359
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
359
360
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
361
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
360
362
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
361
363
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
362
364
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
363
365
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
364
366
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
365
367
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
368
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
366
369
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
367
370
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
368
371
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
369
372
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
370
373
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
371
374
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
375
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
372
376
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
373
377
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
374
378
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
375
379
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
376
380
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
377
381
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
382
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
378
383
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
379
384
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
380
385
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
381
386
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
382
387
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
383
388
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
389
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
384
390
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
385
391
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
386
392
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
387
393
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
388
394
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
389
395
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
396
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
390
397
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
391
398
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
392
399
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
@@ -980,42 +987,49 @@ @implementation GGMLMetalClass
980
987
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
981
988
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
982
989
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
990
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
983
991
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
984
992
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
985
993
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
986
994
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
987
995
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
988
996
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
997
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
989
998
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
990
999
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
991
1000
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
992
1001
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
993
1002
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
994
1003
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
1004
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
995
1005
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
996
1006
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
997
1007
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
998
1008
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
999
1009
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
1000
1010
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
1011
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
1001
1012
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
1002
1013
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
1003
1014
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
1004
1015
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
1005
1016
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
1006
1017
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
1018
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
1007
1019
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
1008
1020
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
1009
1021
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
1010
1022
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
1011
1023
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
1012
1024
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
1025
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
1013
1026
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
1014
1027
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
1015
1028
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
1016
1029
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
1017
1030
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
1018
1031
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
1032
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
1019
1033
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
1020
1034
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
1021
1035
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
@@ -3789,6 +3803,7 @@ static void ggml_metal_encode_node(
3789
3803
case 96 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline ; break ;
3790
3804
case 112 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline ; break ;
3791
3805
case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline ; break ;
3806
+ case 192 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline ; break ;
3792
3807
case 256 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline ; break ;
3793
3808
default :
3794
3809
{
@@ -3806,6 +3821,7 @@ static void ggml_metal_encode_node(
3806
3821
case 96 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline ; break ;
3807
3822
case 112 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline ; break ;
3808
3823
case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline ; break ;
3824
+ case 192 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline ; break ;
3809
3825
case 256 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline ; break ;
3810
3826
default :
3811
3827
{
@@ -3823,6 +3839,7 @@ static void ggml_metal_encode_node(
3823
3839
case 96 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline ; break ;
3824
3840
case 112 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline ; break ;
3825
3841
case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline ; break ;
3842
+ case 192 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline ; break ;
3826
3843
case 256 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline ; break ;
3827
3844
default :
3828
3845
{
@@ -3840,6 +3857,7 @@ static void ggml_metal_encode_node(
3840
3857
case 96 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline ; break ;
3841
3858
case 112 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline ; break ;
3842
3859
case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline ; break ;
3860
+ case 192 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline ; break ;
3843
3861
case 256 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline ; break ;
3844
3862
default :
3845
3863
{
@@ -3857,6 +3875,7 @@ static void ggml_metal_encode_node(
3857
3875
case 96 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline ; break ;
3858
3876
case 112 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline ; break ;
3859
3877
case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline ; break ;
3878
+ case 192 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline ; break ;
3860
3879
case 256 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline ; break ;
3861
3880
default :
3862
3881
{
@@ -3874,6 +3893,7 @@ static void ggml_metal_encode_node(
3874
3893
case 96 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline ; break ;
3875
3894
case 112 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline ; break ;
3876
3895
case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline ; break ;
3896
+ case 192 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline ; break ;
3877
3897
case 256 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline ; break ;
3878
3898
default :
3879
3899
{
@@ -3891,6 +3911,7 @@ static void ggml_metal_encode_node(
3891
3911
case 96 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline ; break ;
3892
3912
case 112 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline ; break ;
3893
3913
case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline ; break ;
3914
+ case 192 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline ; break ;
3894
3915
case 256 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline ; break ;
3895
3916
default :
3896
3917
{
0 commit comments