Skip to content

Commit e8d9158

Browse files
ikawrakowKawrakow
andauthored
metal: somewhat faster f16 x f32 matrix multiply kernel (#2951)
* Somewhat faster f16 x f32 matrix multiply kernel * Better use 32 thread groups for f16 x f32 --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent bce1fef commit e8d9158

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

ggml-metal.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ void ggml_metal_graph_compute(
840840
switch (src0t) {
841841
case GGML_TYPE_F16:
842842
{
843-
nth0 = 64;
843+
nth0 = 32;
844844
nth1 = 1;
845845
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
846846
} break;

ggml-metal.metal

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32(
528528
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
529529
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
530530

531-
sum[tpitg.x] = 0.0f;
531+
uint ith = tpitg.x;
532+
uint nth = tptg.x;
532533

533-
for (int i = tpitg.x; i < ne00; i += tptg.x) {
534-
sum[tpitg.x] += (float) x[i] * (float) y[i];
534+
sum[ith] = 0.0f;
535+
536+
for (int i = ith; i < ne00; i += nth) {
537+
sum[ith] += (float) x[i] * (float) y[i];
535538
}
536539

537540
// accumulate the sum from all threads in the threadgroup
538541
threadgroup_barrier(mem_flags::mem_threadgroup);
539-
for (uint i = tptg.x/2; i > 0; i /= 2) {
540-
if (tpitg.x < i) {
541-
sum[tpitg.x] += sum[tpitg.x + i];
542-
}
543-
threadgroup_barrier(mem_flags::mem_threadgroup);
542+
if (ith%4 == 0) {
543+
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
544544
}
545-
546-
if (tpitg.x == 0) {
545+
threadgroup_barrier(mem_flags::mem_threadgroup);
546+
if (ith%16 == 0) {
547+
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
548+
}
549+
threadgroup_barrier(mem_flags::mem_threadgroup);
550+
if (ith == 0) {
551+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
547552
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
548553
}
554+
555+
// Original implementation. Left behind commented out for now
556+
//threadgroup_barrier(mem_flags::mem_threadgroup);
557+
//for (uint i = tptg.x/2; i > 0; i /= 2) {
558+
// if (tpitg.x < i) {
559+
// sum[tpitg.x] += sum[tpitg.x + i];
560+
// }
561+
// threadgroup_barrier(mem_flags::mem_threadgroup);
562+
//}
563+
//
564+
//if (tpitg.x == 0) {
565+
// dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
566+
//}
549567
}
550568

551569
kernel void kernel_alibi_f32(

0 commit comments

Comments
 (0)