Skip to content

Commit 363f0bf

Browse files
committed
Massive improvement in token generation (TG) speed for fp16
1 parent 01eed46 commit 363f0bf

File tree

1 file changed

+55
-18
lines changed

1 file changed

+55
-18
lines changed

ggml-metal.metal

Lines changed: 55 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -534,14 +534,27 @@ kernel void kernel_mul_mat_f16_f32_1row(
534534
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
535535

536536
float sumf = 0;
537-
for (int i = tiisg; i < ne00; i += 32) {
538-
sumf += (float) x[i] * (float) y[i];
537+
if (ne00 < 128) {
538+
for (int i = tiisg; i < ne00; i += 32) {
539+
sumf += (float) x[i] * (float) y[i];
540+
}
541+
float all_sum = simd_sum(sumf);
542+
if (tiisg == 0) {
543+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
544+
}
545+
} else {
546+
device const half4 * x4 = (device const half4 *) x;
547+
device const float4 * y4 = (device const float4 *) y;
548+
for (int i = tiisg; i < ne00/4; i += 32) {
549+
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
550+
}
551+
float all_sum = simd_sum(sumf);
552+
if (tiisg == 0) {
553+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
554+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
555+
}
539556
}
540557

541-
float all_sum = simd_sum(sumf);
542-
if (tiisg == 0) {
543-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
544-
}
545558
}
546559

547560
#define N_F16_F32 4
@@ -573,22 +586,46 @@ kernel void kernel_mul_mat_f16_f32(
573586

574587
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
575588

576-
for (int row = 0; row < N_F16_F32; ++row) {
577-
int r1 = rb + row;
578-
if (r1 >= ne11) {
579-
break;
580-
}
589+
if (ne00 < 128) {
590+
for (int row = 0; row < N_F16_F32; ++row) {
591+
int r1 = rb + row;
592+
if (r1 >= ne11) {
593+
break;
594+
}
581595

582-
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
596+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
583597

584-
float sumf = 0;
585-
for (int i = tiisg; i < ne00; i += 32) {
586-
sumf += (float) x[i] * (float) y[i];
598+
float sumf = 0;
599+
for (int i = tiisg; i < ne00; i += 32) {
600+
sumf += (float) x[i] * (float) y[i];
601+
}
602+
603+
float all_sum = simd_sum(sumf);
604+
if (tiisg == 0) {
605+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
606+
}
587607
}
608+
} else {
609+
device const half4 * x4 = (device const half4 *)x;
610+
for (int row = 0; row < N_F16_F32; ++row) {
611+
int r1 = rb + row;
612+
if (r1 >= ne11) {
613+
break;
614+
}
588615

589-
float all_sum = simd_sum(sumf);
590-
if (tiisg == 0) {
591-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
616+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
617+
device const float4 * y4 = (device const float4 *) y;
618+
619+
float sumf = 0;
620+
for (int i = tiisg; i < ne00/4; i += 32) {
621+
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
622+
}
623+
624+
float all_sum = simd_sum(sumf);
625+
if (tiisg == 0) {
626+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
627+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
628+
}
592629
}
593630
}
594631

0 commit comments

Comments
 (0)