Skip to content

Commit ac33a92

Browse files
committed
metal : add FA vector kernels for heads K 192 and V 128
ggml-ci
1 parent fa5f91d commit ac33a92

File tree

3 files changed

+154
-57
lines changed

3 files changed

+154
-57
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,20 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
409409
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
410410
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
411411
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
412+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,
413+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,
414+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,
415+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,
416+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,
417+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,
418+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,
419+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
420+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,
421+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,
422+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,
423+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,
424+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,
425+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
412426
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
413427
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
414428
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
@@ -1052,6 +1066,20 @@ @implementation GGMLMetalClass
10521066
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
10531067
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
10541068
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
1069+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction);
1070+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat);
1071+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction);
1072+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction);
1073+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction);
1074+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction);
1075+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction);
1076+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction);
1077+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat);
1078+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction);
1079+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction);
1080+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction);
1081+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction);
1082+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction);
10551083
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
10561084
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
10571085
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
@@ -3809,7 +3837,7 @@ static void ggml_metal_encode_node(
38093837

38103838
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
38113839
// for now avoiding mainly to keep the number of templates/kernels a bit lower
3812-
if (ne01 >= 4 || (ne00%128 != 0)) {
3840+
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
38133841
switch (src1->type) {
38143842
case GGML_TYPE_F16:
38153843
{
@@ -3994,6 +4022,42 @@ static void ggml_metal_encode_node(
39944022
}
39954023
}
39964024
} break;
4025+
case 192:
4026+
{
4027+
if (ne20 == 128) {
4028+
switch (src1->type) {
4029+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break;
4030+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break;
4031+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break;
4032+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break;
4033+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break;
4034+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break;
4035+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break;
4036+
default:
4037+
{
4038+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4039+
GGML_LOG_ERROR("add template specialization for this type\n");
4040+
GGML_ABORT("add template specialization for this type");
4041+
}
4042+
}
4043+
} else {
4044+
switch (src1->type) {
4045+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break;
4046+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break;
4047+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break;
4048+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break;
4049+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break;
4050+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break;
4051+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break;
4052+
default:
4053+
{
4054+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4055+
GGML_LOG_ERROR("add template specialization for this type\n");
4056+
GGML_ABORT("add template specialization for this type");
4057+
}
4058+
}
4059+
}
4060+
} break;
39974061
case 256:
39984062
{
39994063
switch (src1->type) {

0 commit comments

Comments
 (0)