Skip to content

Commit b557bc3

Browse files
committed
Another attempt
1 parent 2b60170 commit b557bc3

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

ggml-metal.m

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
GGML_METAL_DECL_KERNEL(rms_norm);
7777
GGML_METAL_DECL_KERNEL(norm);
7878
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
79+
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
7980
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
8081
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
8182
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -205,6 +206,7 @@ @implementation GGMLMetalClass
205206
GGML_METAL_ADD_KERNEL(rms_norm);
206207
GGML_METAL_ADD_KERNEL(norm);
207208
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
209+
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
208210
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
209211
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
210212
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -270,6 +272,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
270272
GGML_METAL_DEL_KERNEL(rms_norm);
271273
GGML_METAL_DEL_KERNEL(norm);
272274
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
275+
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
273276
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
274277
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
275278
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -854,7 +857,11 @@ void ggml_metal_graph_compute(
854857
{
855858
nth0 = 32;
856859
nth1 = 1;
857-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
860+
if (ne11 * ne12 < 2) {
861+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
862+
} else {
863+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
864+
}
858865
} break;
859866
case GGML_TYPE_Q4_0:
860867
{
@@ -974,8 +981,8 @@ void ggml_metal_graph_compute(
974981
else if (src0t == GGML_TYPE_Q6_K) {
975982
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
976983
} else {
977-
//[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
978-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne11 + 3)/4, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
984+
int64_t ny = (ne11 + 3)/4;
985+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
979986
}
980987
}
981988
} break;

ggml-metal.metal

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,45 @@ kernel void kernel_mul_mat_q8_0_f32(
505505
}
506506
}
507507

508+
kernel void kernel_mul_mat_f16_f32_1row(
509+
device const char * src0,
510+
device const char * src1,
511+
device float * dst,
512+
constant int64_t & ne00,
513+
constant int64_t & ne01,
514+
constant int64_t & ne02,
515+
constant uint64_t & nb00,
516+
constant uint64_t & nb01,
517+
constant uint64_t & nb02,
518+
constant int64_t & ne10,
519+
constant int64_t & ne11,
520+
constant int64_t & ne12,
521+
constant uint64_t & nb10,
522+
constant uint64_t & nb11,
523+
constant uint64_t & nb12,
524+
constant int64_t & ne0,
525+
constant int64_t & ne1,
526+
uint3 tgpig[[threadgroup_position_in_grid]],
527+
uint tiisg[[thread_index_in_simdgroup]]) {
528+
529+
const int64_t r0 = tgpig.x;
530+
const int64_t r1 = tgpig.y;
531+
const int64_t im = tgpig.z;
532+
533+
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
534+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
535+
536+
float sumf = 0;
537+
for (int i = tiisg; i < ne00; i += 32) {
538+
sumf += (float) x[i] * (float) y[i];
539+
}
540+
541+
float all_sum = simd_sum(sumf);
542+
if (tiisg == 0) {
543+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
544+
}
545+
}
546+
508547
#define N_F16_F32 4
509548

510549
kernel void kernel_mul_mat_f16_f32(

0 commit comments

Comments
 (0)