
Commit bca5d0c

metal: template for mat-vec multiplication kernels
1 parent 06abf8e commit bca5d0c

File tree

2 files changed: +632, -1279 lines


ggml-metal.m

Lines changed: 64 additions & 135 deletions
@@ -76,14 +76,15 @@
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
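
Note: the first three hunks in this file only rename call sites of the GGML_METAL_DECL_KERNEL / GGML_METAL_ADD_KERNEL / GGML_METAL_DEL_KERNEL macros; the macro definitions themselves are not part of this diff. As a rough, illustrative sketch of the kind of declaration/creation/teardown such a trio typically expands to in ggml-metal.m (field names and the omitted error handling are assumptions, not the file's exact definitions):

    // Illustrative sketch only, not taken from this commit.
    #define GGML_METAL_DECL_KERNEL(name) \
        id<MTLFunction>             function_##name; \
        id<MTLComputePipelineState> pipeline_##name

    #define GGML_METAL_ADD_KERNEL(name) \
        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_" #name]; \
        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]

    #define GGML_METAL_DEL_KERNEL(name) \
        [ctx->function_##name release]; \
        [ctx->pipeline_##name release]

Under that reading, renaming mul_mat_q4_0_f32 to mul_mv_q4_0_f32 (and so on) renames both the pipeline_* fields used later in ggml_metal_graph_compute and the kernel_mul_mv_* entry points looked up in the compiled Metal library.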
@@ -205,14 +206,15 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(rms_norm);
     GGML_METAL_ADD_KERNEL(norm);
     GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
-    GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
-    GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
-    GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
-    GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
-    GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
-    GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
-    GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
     GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
     GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@@ -270,14 +272,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@@ -832,97 +835,42 @@ void ggml_metal_graph_compute(
                 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
-            } else {
-                int nth0 = 32;
-                int nth1 = 1;
-
+            } else if (ggml_is_contiguous(src0) &&
+                       ggml_is_contiguous(src1) &&
+                       src1t == GGML_TYPE_F32 &&
+                       ne00%32 == 0) {
                 // use custom matrix x vector kernel
-                switch (src0t) {
-                    case GGML_TYPE_F16:
-                        {
-                            nth0 = 64;
-                            nth1 = 1;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
-                        } break;
-                    case GGML_TYPE_Q4_0:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 8;
-                            nth1 = 8;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
-                        } break;
-                    case GGML_TYPE_Q4_1:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 8;
-                            nth1 = 8;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
-                        } break;
-                    case GGML_TYPE_Q8_0:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 8;
-                            nth1 = 8;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
-                        } break;
-                    case GGML_TYPE_Q2_K:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 2;
-                            nth1 = 32;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
-                        } break;
-                    case GGML_TYPE_Q3_K:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 2;
-                            nth1 = 32;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
-                        } break;
-                    case GGML_TYPE_Q4_K:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 2;
-                            nth1 = 32;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
-                        } break;
-                    case GGML_TYPE_Q5_K:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 2;
-                            nth1 = 32;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
-                        } break;
-                    case GGML_TYPE_Q6_K:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 2;
-                            nth1 = 32;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
-                        } break;
-                    default:
-                        {
-                            metal_printf("Asserting on type %d\n",(int)src0t);
-                            GGML_ASSERT(false && "not implemented");
-                        }
-                };
-
+                switch (src0->type) {
+                    case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; break;
+                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32]; break;
+                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; break;
+                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32]; break;
+                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32]; break;
+                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32]; break;
+                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32]; break;
+                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32]; break;
+                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32]; break;
+                    default: GGML_ASSERT(false && "MUL MAT-VEC not implemented");
+                }
+                int buffer_size_aligned = (512 / ggml_blck_size(src0t) * ggml_element_size(src0) + 31) / 32 * 32;
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:6];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
+                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
+                [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+                [encoder setThreadgroupMemoryLength:8 * buffer_size_aligned atIndex:0];
+                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
+            } else {
+                switch (src0->type) {
+                    case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; break;
+                    default: GGML_ASSERT(false && " not implemented");
+                }
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
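
The new mat-vec branch sizes its threadgroup memory from the src0 block layout rather than the fixed 8192 bytes of the matrix-matrix path: buffer_size_aligned is the byte count of 512 src0 elements rounded up to a multiple of 32, and 8 such buffers are reserved per threadgroup, matching the 8 rows of src0 each threadgroup covers in the dispatch. A self-contained C sketch of that arithmetic follows; the F16 and Q4_0 block sizes and per-block byte counts are hard-coded assumptions about ggml's layouts, not values read from this commit.

    // Standalone sketch of the shared-memory sizing in the new mat-vec path.
    #include <stdio.h>

    static int buffer_size_aligned(int blck_size, int type_size) {
        // bytes needed to hold 512 src0 elements, rounded up to a multiple of 32
        return (512 / blck_size * type_size + 31) / 32 * 32;
    }

    int main(void) {
        // F16: 1 element per "block", 2 bytes each -> 1024 B, 8*1024 = 8192 B per threadgroup
        printf("f16 : %d bytes\n", 8 * buffer_size_aligned(1, 2));
        // Q4_0 (assumed 32 elements per 18-byte block) -> 288 B, 8*288 = 2304 B per threadgroup
        printf("q4_0: %d bytes\n", 8 * buffer_size_aligned(32, 18));
        return 0;
    }

For F16 this reproduces the 8192-byte figure used by the matrix-matrix branch above; quantized types need considerably less.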
@@ -941,27 +889,8 @@ void ggml_metal_graph_compute(
                 [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
                 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
                 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
-
-                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
-                    src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q3_K) {
-#ifdef GGML_QKK_64
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#else
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#endif
-                }
-                else if (src0t == GGML_TYPE_Q5_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q6_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                } else {
-                    [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
+                [encoder setThreadgroupMemoryLength:64*sizeof(float) atIndex:0];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
             }
         } break;
         case GGML_OP_GET_ROWS:
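
With the per-type nth0/nth1 bookkeeping removed, the quantized mat-vec path above always dispatches (ne01 + 7)/8 threadgroups along x (8 rows of src0 per threadgroup), ne11 along y and ne12 along z, with 64 threads per threadgroup; only the remaining fallback branch still dispatches one row per threadgroup. A quick arithmetic check of the grid shape, using made-up tensor sizes:

    // Illustrative only: grid shape of the new mat-vec dispatch.
    // The ne01/ne11/ne12 values below are hypothetical.
    #include <stdio.h>

    int main(void) {
        const int ne01 = 4096; // rows of src0
        const int ne11 = 1;    // src1 columns: 1 for a single vector
        const int ne12 = 1;    // batch dimension
        printf("threadgroups = (%d, %d, %d), 64 threads each\n",
               (ne01 + 7)/8, ne11, ne12); // -> (512, 1, 1)
        return 0;
    }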
