Skip to content

Commit f31b6f4

Browse files
ikawrakow (Iwan Kawrakow) and ggerganov
authored
metal : PP speedup (#3084)
* Minor speed gains for all quantization types * metal: faster kernel_scale via float4 * Various other speedups for "small" kernels * metal: faster soft_max via float4 * metal: faster diagonal infinity Although, to me it looks like one should simply fuse scale + diagonal infinity + soft_max on the KQ tensor. * Another faster f16 x f32 matrix multiply kernel * Reverting the diag infinity change It does work for PP, but somehow it fails for TG. Need to look more into it. * metal: add back faster diagonal infinity This time more carefully * metal : minor (readability) --------- Co-authored-by: Iwan Kawrakow <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 6eeb4d9 commit f31b6f4

File tree

2 files changed

+218
-102
lines changed

2 files changed

+218
-102
lines changed

ggml-metal.m

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@
6363
GGML_METAL_DECL_KERNEL(relu);
6464
GGML_METAL_DECL_KERNEL(gelu);
6565
GGML_METAL_DECL_KERNEL(soft_max);
66+
GGML_METAL_DECL_KERNEL(soft_max_4);
6667
GGML_METAL_DECL_KERNEL(diag_mask_inf);
68+
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
6769
GGML_METAL_DECL_KERNEL(get_rows_f16);
6870
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
6971
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -77,6 +79,7 @@
7779
GGML_METAL_DECL_KERNEL(norm);
7880
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
7981
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
82+
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
8083
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
8184
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
8285
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -218,7 +221,9 @@ @implementation GGMLMetalClass
218221
GGML_METAL_ADD_KERNEL(relu);
219222
GGML_METAL_ADD_KERNEL(gelu);
220223
GGML_METAL_ADD_KERNEL(soft_max);
224+
GGML_METAL_ADD_KERNEL(soft_max_4);
221225
GGML_METAL_ADD_KERNEL(diag_mask_inf);
226+
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
222227
GGML_METAL_ADD_KERNEL(get_rows_f16);
223228
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
224229
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -232,6 +237,7 @@ @implementation GGMLMetalClass
232237
GGML_METAL_ADD_KERNEL(norm);
233238
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
234239
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
240+
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
235241
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
236242
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
237243
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -286,7 +292,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
286292
GGML_METAL_DEL_KERNEL(relu);
287293
GGML_METAL_DEL_KERNEL(gelu);
288294
GGML_METAL_DEL_KERNEL(soft_max);
289-
GGML_METAL_DEL_KERNEL(diag_mask_inf);
295+
GGML_METAL_DEL_KERNEL(soft_max_4);
296+
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
290297
GGML_METAL_DEL_KERNEL(get_rows_f16);
291298
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
292299
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -300,6 +307,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
300307
GGML_METAL_DEL_KERNEL(norm);
301308
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
302309
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
310+
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
303311
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
304312
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
305313
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -767,7 +775,7 @@ void ggml_metal_graph_compute(
767775
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
768776
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
769777

770-
const int64_t n = ggml_nelements(dst);
778+
const int64_t n = ggml_nelements(dst)/4;
771779

772780
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
773781
} break;
@@ -779,7 +787,7 @@ void ggml_metal_graph_compute(
779787
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
780788
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
781789

782-
const int64_t n = ggml_nelements(dst);
790+
const int64_t n = ggml_nelements(dst)/4;
783791

784792
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
785793
} break;
@@ -799,7 +807,7 @@ void ggml_metal_graph_compute(
799807
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
800808
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
801809

802-
const int64_t n = ggml_nelements(dst);
810+
const int64_t n = ggml_nelements(dst)/4;
803811

804812
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
805813
} break;
@@ -813,28 +821,40 @@ void ggml_metal_graph_compute(
813821
{
814822
const int nth = 32;
815823

816-
[encoder setComputePipelineState:ctx->pipeline_soft_max];
824+
if (ne00%4 == 0) {
825+
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
826+
} else {
827+
[encoder setComputePipelineState:ctx->pipeline_soft_max];
828+
}
817829
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
818830
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
819831
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
820832
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
821833
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
822-
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
823834

824835
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
825836
} break;
826837
case GGML_OP_DIAG_MASK_INF:
827838
{
828839
const int n_past = ((int32_t *)(dst->op_params))[0];
829840

830-
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
841+
if (ne00%8 == 0) {
842+
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
843+
} else {
844+
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
845+
}
831846
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
832847
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
833848
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
834849
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
835850
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
836851

837-
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
852+
if (ne00%8 == 0) {
853+
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
854+
}
855+
else {
856+
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
857+
}
838858
} break;
839859
case GGML_OP_MUL_MAT:
840860
{
@@ -881,6 +901,7 @@ void ggml_metal_graph_compute(
881901
} else {
882902
int nth0 = 32;
883903
int nth1 = 1;
904+
int nrows = 1;
884905

885906
// use custom matrix x vector kernel
886907
switch (src0t) {
@@ -890,8 +911,12 @@ void ggml_metal_graph_compute(
890911
nth1 = 1;
891912
if (ne11 * ne12 < 4) {
892913
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
914+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
915+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
916+
nrows = ne11;
893917
} else {
894918
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
919+
nrows = 4;
895920
}
896921
} break;
897922
case GGML_TYPE_Q4_0:
@@ -1012,7 +1037,7 @@ void ggml_metal_graph_compute(
10121037
else if (src0t == GGML_TYPE_Q6_K) {
10131038
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
10141039
} else {
1015-
int64_t ny = (ne11 + 3)/4;
1040+
int64_t ny = (ne11 + nrows - 1)/nrows;
10161041
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
10171042
}
10181043
}

0 commit comments

Comments (0)