|
76 | 76 | GGML_METAL_DECL_KERNEL(rms_norm);
|
77 | 77 | GGML_METAL_DECL_KERNEL(norm);
|
78 | 78 | GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
| 79 | + GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row); |
79 | 80 | GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
80 | 81 | GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
81 | 82 | GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
@@ -205,6 +206,7 @@ @implementation GGMLMetalClass
|
205 | 206 | GGML_METAL_ADD_KERNEL(rms_norm);
|
206 | 207 | GGML_METAL_ADD_KERNEL(norm);
|
207 | 208 | GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
| 209 | + GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row); |
208 | 210 | GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
209 | 211 | GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
210 | 212 | GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
@@ -270,6 +272,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
270 | 272 | GGML_METAL_DEL_KERNEL(rms_norm);
|
271 | 273 | GGML_METAL_DEL_KERNEL(norm);
|
272 | 274 | GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
| 275 | + GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row); |
273 | 276 | GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
274 | 277 | GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
275 | 278 | GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
@@ -854,7 +857,11 @@ void ggml_metal_graph_compute(
|
854 | 857 | {
|
855 | 858 | nth0 = 32;
|
856 | 859 | nth1 = 1;
|
857 |
| - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; |
| 860 | + if (ne11 * ne12 < 2) { |
| 861 | + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; |
| 862 | + } else { |
| 863 | + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; |
| 864 | + } |
858 | 865 | } break;
|
859 | 866 | case GGML_TYPE_Q4_0:
|
860 | 867 | {
|
@@ -974,8 +981,8 @@ void ggml_metal_graph_compute(
|
974 | 981 | else if (src0t == GGML_TYPE_Q6_K) {
|
975 | 982 | [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
976 | 983 | } else {
|
977 |
| - //[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; |
978 |
| - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne11 + 3)/4, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; |
| 984 | + int64_t ny = (ne11 + 3)/4; |
| 985 | + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; |
979 | 986 | }
|
980 | 987 | }
|
981 | 988 | } break;
|
|
0 commit comments