Skip to content

Commit e782c9e

Browse files
ikawrakow and Kawrakow authored
Faster Q5_K and Q6_K on Metal (#2294)
* Faster Q6_K on Metal
* Faster Q5_K on Metal
* Another Q5_K speedup

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 785829d commit e782c9e

File tree

2 files changed

+136
-111
lines changed

2 files changed

+136
-111
lines changed

ggml-metal.m

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -703,17 +703,17 @@ void ggml_metal_graph_compute(
703703
GGML_ASSERT(ne02 == 1);
704704
GGML_ASSERT(ne12 == 1);
705705

706-
nth0 = 4;
707-
nth1 = 16;
706+
nth0 = 2;
707+
nth1 = 32;
708708
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
709709
} break;
710710
case GGML_TYPE_Q6_K:
711711
{
712712
GGML_ASSERT(ne02 == 1);
713713
GGML_ASSERT(ne12 == 1);
714714

715-
nth0 = 4;
716-
nth1 = 16;
715+
nth0 = 2;
716+
nth1 = 32;
717717
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
718718
} break;
719719
default:
@@ -743,11 +743,14 @@ void ggml_metal_graph_compute(
743743
src0t == GGML_TYPE_Q4_K) {
744744
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
745745
}
746+
else if (src0t == GGML_TYPE_Q5_K) {
747+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
748+
}
749+
else if (src0t == GGML_TYPE_Q6_K) {
750+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
751+
}
746752
else if (src0t == GGML_TYPE_Q2_K ||
747-
src0t == GGML_TYPE_Q3_K ||
748-
src0t == GGML_TYPE_Q4_K ||
749-
src0t == GGML_TYPE_Q5_K ||
750-
src0t == GGML_TYPE_Q6_K) {
753+
src0t == GGML_TYPE_Q3_K) {
751754
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
752755
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
753756
} else {

ggml-metal.metal

Lines changed: 125 additions & 103 deletions
Original file line number | Diff line number | Diff line change
@@ -1642,39 +1642,39 @@ kernel void kernel_mul_mat_q5_K_f32(
16421642
constant int64_t & ne00,
16431643
constant int64_t & ne10,
16441644
constant int64_t & ne0,
1645-
threadgroup float * sum [[threadgroup(0)]],
16461645
uint2 tgpig[[threadgroup_position_in_grid]],
1647-
uint2 tpitg[[thread_position_in_threadgroup]],
1648-
uint2 tptg[[threads_per_threadgroup]]) {
1646+
uint tiisg[[thread_index_in_simdgroup]],
1647+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
16491648

16501649
const int nb = ne00/QK_K;
16511650

16521651
const int64_t r0 = tgpig.x;
16531652
const int64_t r1 = tgpig.y;
16541653

1655-
device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1654+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1655+
1656+
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
16561657
device const float * yy = (device const float *) src1 + r1*ne10;
16571658

1658-
const int nth = tptg.x*tptg.y;
1659-
const int ith = tptg.y*tpitg.x + tpitg.y;
1659+
float sumf[2]={0.f};
16601660

1661-
float sumf = 0;
1661+
const int step = sizeof(block_q5_K) * nb;
16621662

16631663
#if QK_K == 256
1664+
#
1665+
float yl[16], yh[16];
16641666

16651667
const uint16_t kmask1 = 0x3f3f;
16661668
const uint16_t kmask2 = 0x0f0f;
16671669
const uint16_t kmask3 = 0xc0c0;
16681670

1669-
const int tid = tpitg.y; // 0...16
1670-
const int il = tid/4; // 0...3
1671-
const int ir = tid - 4*il;// 0...3
1672-
const int n = 4;
1673-
1674-
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1675-
const int in = il%2;
1671+
const int tid = tiisg/4;
1672+
const int ix = tiisg%4;
1673+
const int im = tid/4;
1674+
const int ir = tid%4;
1675+
const int n = 8;
16761676

1677-
const int l0 = n*(2*ir + in);
1677+
const int l0 = n*ir;
16781678
const int q_offset = 32*im + l0;
16791679
const int y_offset = 64*im + l0;
16801680

@@ -1683,78 +1683,114 @@ kernel void kernel_mul_mat_q5_K_f32(
16831683
const uint8_t hm3 = hm1 << 4;
16841684
const uint8_t hm4 = hm2 << 4;
16851685

1686-
uchar2 sc1, sc2, sc3, sc4;
1686+
uint16_t sc16[4];
1687+
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
16871688

1688-
for (int i = tpitg.x; i < nb; i += tptg.x) {
1689+
device const float * y1 = yy + ix*QK_K + y_offset;
16891690

1690-
device const uint8_t * q1 = (x + i)->qs + q_offset;
1691-
device const uint8_t * q2 = q1 + 64;
1692-
device const uint8_t * qh = (x + i)->qh + l0;
1693-
device const float * y1 = yy + i*QK_K + y_offset;
1694-
device const float * y2 = y1 + 128;
1691+
for (int i = ix; i < nb; i += 4) {
16951692

1696-
const float dall = (float)((x + i)->d);
1697-
const float dmin = (float)((x + i)->dmin);
1693+
device const uint8_t * q1 = x[i].qs + q_offset;
1694+
device const uint8_t * qh = x[i].qh + l0;
1695+
device const half * dh = &x[i].d;
1696+
device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
16981697

1699-
device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1700-
sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1701-
sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1702-
sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1703-
sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1698+
device const float * y2 = y1 + 128;
1699+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
1700+
for (int l = 0; l < 8; ++l) {
1701+
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
1702+
yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
1703+
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
1704+
yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
1705+
}
17041706

1705-
float4 s = {0.f, 0.f, 0.f, 0.f};
1706-
float smin = 0;
1707-
for (int l = 0; l < n; ++l) {
1707+
for (int row = 0; row < 2; ++row) {
1708+
1709+
device const uint8_t * q2 = q1 + 64;
1710+
1711+
sc16[0] = a[0] & kmask1;
1712+
sc16[1] = a[2] & kmask1;
1713+
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1714+
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1715+
1716+
float4 acc = {0.f, 0.f, 0.f, 0.f};
1717+
for (int l = 0; l < n; ++l) {
1718+
uint8_t h = qh[l];
1719+
acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1720+
acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1721+
acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1722+
acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
1723+
}
1724+
const float dall = dh[0];
1725+
const float dmin = dh[1];
1726+
sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
1727+
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
17081728

1709-
s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
1710-
s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
1711-
s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
1712-
s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
1713-
smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1729+
q1 += step;
1730+
qh += step;
1731+
dh += step/2;
1732+
a += step/2;
17141733

17151734
}
1716-
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1735+
1736+
y1 += 4 * QK_K;
17171737

17181738
}
17191739
#else
1720-
const int il = 4 * tpitg.x; // 0, 4, 8, 12
1721-
const int im = il/8; // 0, 0, 1, 1
1722-
const int in = il%8; // 0, 4, 0, 4
1740+
float yl[8], yh[8];
17231741

1724-
for (int i = tpitg.y; i < nb; i += tptg.y) {
1742+
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
1743+
const int ix = tiisg%8;
1744+
const int im = il/8; // 0, 0, 1, 1
1745+
const int in = il%8; // 0, 4, 0, 4
17251746

1726-
const float d = (float)x[i].d;
1747+
device const float * y = yy + ix*QK_K + il;
1748+
1749+
for (int i = ix; i < nb; i += 8) {
1750+
1751+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
1752+
for (int l = 0; l < 4; ++l) {
1753+
yl[l+0] = y[l+ 0];
1754+
yl[l+4] = y[l+16];
1755+
yh[l+0] = y[l+32];
1756+
yh[l+4] = y[l+48];
1757+
}
1758+
1759+
device const half * dh = &x[i].d;
17271760
device const uint8_t * q = x[i].qs + il;
17281761
device const uint8_t * h = x[i].qh + in;
17291762
device const int8_t * s = x[i].scales;
1730-
device const float * y = yy + i*QK_K + il;
17311763

1732-
for (int l = 0; l < 4; ++l) {
1733-
const uint8_t hl = h[l] >> im;
1734-
sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1735-
+ y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1736-
+ y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1737-
+ y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
1764+
for (int row = 0; row < 2; ++row) {
1765+
1766+
const float d = dh[0];
1767+
1768+
float2 acc = {0.f, 0.f};
1769+
for (int l = 0; l < 4; ++l) {
1770+
const uint8_t hl = h[l] >> im;
1771+
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
1772+
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
1773+
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
1774+
+ yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
1775+
}
1776+
sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
1777+
1778+
q += step;
1779+
h += step;
1780+
s += step;
1781+
dh += step/2;
1782+
17381783
}
1784+
1785+
y += 8 * QK_K;
17391786
}
17401787
#endif
1741-
sum[ith] = sumf;
17421788

1743-
//
1744-
// Accumulate the sum from all threads in the threadgroup
1745-
//
1746-
threadgroup_barrier(mem_flags::mem_threadgroup);
1747-
if (ith%4 == 0) {
1748-
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
1749-
}
1750-
threadgroup_barrier(mem_flags::mem_threadgroup);
1751-
if (ith%16 == 0) {
1752-
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
1753-
}
1754-
threadgroup_barrier(mem_flags::mem_threadgroup);
1755-
if (ith == 0) {
1756-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1757-
dst[r1*ne0 + r0] = sum[0];
1789+
for (int row = 0; row < 2; ++row) {
1790+
const float tot = simd_sum(sumf[row]);
1791+
if (tiisg == 0) {
1792+
dst[r1*ne0 + first_row + row] = tot;
1793+
}
17581794
}
17591795

17601796
}
@@ -1766,10 +1802,9 @@ kernel void kernel_mul_mat_q6_K_f32(
17661802
constant int64_t & ne00,
17671803
constant int64_t & ne10,
17681804
constant int64_t & ne0,
1769-
threadgroup float * sum [[threadgroup(0)]],
17701805
uint2 tgpig[[threadgroup_position_in_grid]],
1771-
uint2 tpitg[[thread_position_in_threadgroup]],
1772-
uint2 tptg[[threads_per_threadgroup]]) {
1806+
uint tiisg[[thread_index_in_simdgroup]],
1807+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
17731808

17741809
const uint8_t kmask1 = 0x03;
17751810
const uint8_t kmask2 = 0x0C;
@@ -1781,19 +1816,18 @@ kernel void kernel_mul_mat_q6_K_f32(
17811816
const int64_t r0 = tgpig.x;
17821817
const int64_t r1 = tgpig.y;
17831818

1784-
device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1785-
device const float * yy = (device const float *) src1 + r1*ne10;
1819+
const int row = 2 * r0 + sgitg;
17861820

1787-
const int nth = tptg.x*tptg.y;
1788-
const int ith = tptg.y*tpitg.x + tpitg.y;
1821+
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
1822+
device const float * yy = (device const float *) src1 + r1*ne10;
17891823

17901824
float sumf = 0;
17911825

17921826
#if QK_K == 256
1793-
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1794-
const int iqs = 16 * tpitg.y;
1795-
const int ip = iqs / 128; // 0 or 1
1796-
const int il = (iqs - 128*ip)/16; // 0...7
1827+
const int tid = tiisg/2;
1828+
const int ix = tiisg%2;
1829+
const int ip = tid/8; // 0 or 1
1830+
const int il = tid%8;
17971831
const int n = 4;
17981832
const int l0 = n*il;
17991833
const int is = 8*ip + l0/16;
@@ -1802,9 +1836,10 @@ kernel void kernel_mul_mat_q6_K_f32(
18021836
const int q_offset_l = 64*ip + l0;
18031837
const int q_offset_h = 32*ip + l0;
18041838

1805-
for (int i = tpitg.x; i < nb; i += tptg.x) {
1839+
for (int i = ix; i < nb; i += 2) {
18061840

1807-
device const uint8_t * ql = x[i].ql + q_offset_l;
1841+
device const uint8_t * q1 = x[i].ql + q_offset_l;
1842+
device const uint8_t * q2 = q1 + 32;
18081843
device const uint8_t * qh = x[i].qh + q_offset_h;
18091844
device const int8_t * sc = x[i].scales + is;
18101845

@@ -1814,19 +1849,21 @@ kernel void kernel_mul_mat_q6_K_f32(
18141849

18151850
float4 sums = {0.f, 0.f, 0.f, 0.f};
18161851
for (int l = 0; l < n; ++l) {
1817-
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1818-
sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1819-
sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1820-
sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1852+
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1853+
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1854+
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1855+
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
18211856
}
18221857

18231858
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
18241859

18251860
}
1861+
18261862
#else
1827-
const int il = 4*tpitg.x; // 0, 4, 8, 12
1863+
const int ix = tiisg/4;
1864+
const int il = 4*(tiisg%4);
18281865

1829-
for (int i = tpitg.y; i < nb; i += tptg.y) {
1866+
for (int i = ix; i < nb; i += 8) {
18301867
device const float * y = yy + i * QK_K + il;
18311868
device const uint8_t * ql = x[i].ql + il;
18321869
device const uint8_t * qh = x[i].qh + il;
@@ -1846,23 +1883,8 @@ kernel void kernel_mul_mat_q6_K_f32(
18461883

18471884
#endif
18481885

1849-
sum[ith] = sumf;
1850-
1851-
//
1852-
// Accumulate the sum from all threads in the threadgroup
1853-
//
1854-
threadgroup_barrier(mem_flags::mem_threadgroup);
1855-
if (ith%4 == 0) {
1856-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1857-
}
1858-
threadgroup_barrier(mem_flags::mem_threadgroup);
1859-
if (ith%16 == 0) {
1860-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1861-
}
1862-
threadgroup_barrier(mem_flags::mem_threadgroup);
1863-
if (ith == 0) {
1864-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1865-
dst[r1*ne0 + r0] = sum[0];
1886+
const float tot = simd_sum(sumf);
1887+
if (tiisg == 0) {
1888+
dst[r1*ne0 + row] = tot;
18661889
}
1867-
18681890
}

0 commit comments

Comments
 (0)