Skip to content

Commit 99ed03a

Browse files
committed
metal : improve decoding speed for batches of 2-16
1 parent f1782c6 commit 99ed03a

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

ggml-metal.m

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,14 +993,34 @@ void ggml_metal_graph_compute(
993993
uint gqa = ne12/ne02;
994994
GGML_ASSERT(ne03 == ne13);
995995

996+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
997+
// to the matrix-vector kernel. the numbers below are measure on M2 Ultra
998+
// not sure if this translates across all chips
999+
int ne11_mm_min = 1;
1000+
1001+
switch (src0t) {
1002+
case GGML_TYPE_F16: ne11_mm_min = 2; break;
1003+
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1004+
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1005+
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1006+
case GGML_TYPE_Q4_0:
1007+
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1008+
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1009+
case GGML_TYPE_Q5_0: // not tested yet
1010+
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1011+
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1012+
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1013+
default: ne11_mm_min = 1; break;
1014+
}
1015+
9961016
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
9971017
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
9981018
if (!ggml_is_transposed(src0) &&
9991019
!ggml_is_transposed(src1) &&
10001020
src1t == GGML_TYPE_F32 &&
10011021
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
10021022
ne00%32 == 0 &&
1003-
ne11 > 2) {
1023+
ne11 > ne11_mm_min) {
10041024
switch (src0->type) {
10051025
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
10061026
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;

0 commit comments

Comments
 (0)