Skip to content

Commit e68c96f

Browse files
ikawrakow and Kawrakow authored
Faster Q2_K on Metal (#2297)
* Faster Q2_K on Metal
* Deleting unnoticed and dangerous trailing whitespace
* Fixed bug in new Metal Q2_K implementation

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 9cf022a commit e68c96f

File tree

2 files changed

+104
-80
lines changed

2 files changed

+104
-80
lines changed

ggml-metal.m

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -676,8 +676,8 @@ void ggml_metal_graph_compute(
676676
GGML_ASSERT(ne02 == 1);
677677
GGML_ASSERT(ne12 == 1);
678678

679-
nth0 = 4;
680-
nth1 = 16;
679+
nth0 = 2;
680+
nth1 = 32;
681681
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
682682
} break;
683683
case GGML_TYPE_Q3_K:
@@ -740,7 +740,7 @@ void ggml_metal_graph_compute(
740740
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
741741

742742
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
743-
src0t == GGML_TYPE_Q4_K) {
743+
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
744744
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
745745
}
746746
else if (src0t == GGML_TYPE_Q5_K) {
@@ -749,8 +749,7 @@ void ggml_metal_graph_compute(
749749
else if (src0t == GGML_TYPE_Q6_K) {
750750
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
751751
}
752-
else if (src0t == GGML_TYPE_Q2_K ||
753-
src0t == GGML_TYPE_Q3_K) {
752+
else if (src0t == GGML_TYPE_Q3_K) {
754753
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
755754
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
756755
} else {

ggml-metal.metal

Lines changed: 100 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -1209,108 +1209,133 @@ kernel void kernel_mul_mat_q2_K_f32(
12091209
constant int64_t & ne00,
12101210
constant int64_t & ne10,
12111211
constant int64_t & ne0,
1212-
threadgroup float * sum [[threadgroup(0)]],
1212+
constant int64_t & ne01[[buffer(4)]],
12131213
uint2 tgpig[[threadgroup_position_in_grid]],
1214-
uint2 tpitg[[thread_position_in_threadgroup]],
1215-
uint2 tptg[[threads_per_threadgroup]]) {
1214+
uint tiisg[[thread_index_in_simdgroup]],
1215+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
12161216

12171217
const int nb = ne00/QK_K;
1218+
const int r0 = tgpig.x;
1219+
const int r1 = tgpig.y;
12181220

1219-
const int64_t r0 = tgpig.x;
1220-
const int64_t r1 = tgpig.y;
1221-
1222-
device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
1223-
device const float * yy = (device const float *) src1 + r1*ne10;
1224-
1225-
const int nth = tptg.x*tptg.y;
1226-
const int ith = tptg.y*tpitg.x + tpitg.y;
1221+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1222+
const int ib_row = first_row * nb;
1223+
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
1224+
device const float * y = (device const float *) src1 + r1*ne10;
1225+
float yl[32];
1226+
float sumf[N_DST]={0.f}, all_sum;
12271227

1228-
float sumf = 0;
1228+
const int step = sizeof(block_q2_K) * nb;
12291229

12301230
#if QK_K == 256
1231-
const int tid = tpitg.y; // 0...16
1232-
const int il = tid/4; // 0...3
1233-
const int ir = tid%4; // 0...3
1234-
const int ip = il/2; // 0 or 1
1235-
const int shift1 = 4*(il%2);// 0 or 4
1236-
const int shift2 = shift1+2;// 2 or 6
1237-
const int n = 8;
1238-
const int is = 4*il + (n*ir)/16;
1231+
const int ix = tiisg/8; // 0...3
1232+
const int it = tiisg%8; // 0...7
1233+
const int im = it/4; // 0 or 1
1234+
const int ir = it%4; // 0...3
1235+
const int is = (8*ir)/16;// 0 or 1
12391236

1240-
const int y_offset = 64*il + n*ir;
1241-
const int q_offset = 32*ip + n*ir;
1237+
device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
12421238

1243-
for (int i = tpitg.x; i < nb; i += tptg.x) {
1239+
for (int ib = ix; ib < nb; ib += 4) {
12441240

1245-
device const uint8_t * q = x[i].qs + q_offset;
1246-
device const uint8_t * scales = x[i].scales + is;
1241+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
1242+
for (int i = 0; i < 8; ++i) {
1243+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1244+
yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
1245+
yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
1246+
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1247+
}
12471248

1248-
uint8_t d1 = scales[0] & 0xF;
1249-
uint8_t d2 = scales[2] & 0xF;
1250-
uint8_t m1 = scales[0] >> 4;
1251-
uint8_t m2 = scales[2] >> 4;
1249+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1250+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1251+
device const half * dh = &x[ib].d;
12521252

1253-
device const float * y = yy + i*QK_K + y_offset;
1253+
for (int row = 0; row < N_DST; row++) {
12541254

1255-
float2 s = {0.f, 0.f};
1256-
float smin = 0;
1257-
for (int l = 0; l < n; ++l) {
1258-
s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
1259-
s[1] += y[l+32] * ((q[l] >> shift2) & 3);
1260-
smin += y[l+ 0] * m1 + y[l+32] * m2;
1255+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1256+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1257+
for (int i = 0; i < 8; i += 2) {
1258+
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1259+
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1260+
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1261+
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1262+
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1263+
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1264+
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1265+
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1266+
}
1267+
float dall = dh[0];
1268+
float dmin = dh[1] * 1.f/16.f;
1269+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1270+
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
1271+
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
1272+
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
1273+
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
1274+
1275+
qs += step/2;
1276+
sc += step;
1277+
dh += step/2;
12611278
}
12621279

1263-
const float dall = (float)x[i].d;
1264-
const float dmin = (float)x[i].dmin;
1265-
1266-
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1267-
1280+
y4 += 4 * QK_K;
12681281
}
12691282
#else
1270-
const int il = 4 * tpitg.x;
1283+
const int ix = tiisg/2; // 0...15
1284+
const int it = tiisg%2; // 0...1
12711285

1272-
uint32_t aux[2];
1273-
thread const uint8_t * d = (thread const uint8_t *)aux;
1274-
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1286+
device const float * y4 = y + ix * QK_K + 8 * it;
12751287

1276-
for (int i = tpitg.y; i < nb; i += tptg.y) {
1288+
for (int ib = ix; ib < nb; ib += 16) {
12771289

1278-
device const uint8_t * q = x[i].qs + il;
1279-
device const float * y = yy + i*QK_K + il;
1290+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
1291+
for (int i = 0; i < 8; ++i) {
1292+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1293+
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
1294+
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
1295+
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
1296+
}
12801297

1281-
const float dall = (float)x[i].d;
1282-
const float dmin = (float)x[i].dmin;
1298+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
1299+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1300+
device const half * dh = &x[ib].d;
12831301

1284-
device const uint32_t * a = (device const uint32_t *)x[i].scales;
1285-
aux[0] = a[0] & 0x0f0f0f0f;
1286-
aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1302+
for (int row = 0; row < N_DST; row++) {
12871303

1288-
for (int l = 0; l < 4; ++l) {
1289-
sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1290-
+ y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1291-
+ y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1292-
+ y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1304+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1305+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1306+
for (int i = 0; i < 8; i += 2) {
1307+
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1308+
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1309+
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1310+
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1311+
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1312+
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1313+
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1314+
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1315+
}
1316+
1317+
float dall = dh[0];
1318+
float dmin = dh[1];
1319+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1320+
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
1321+
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
1322+
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
1323+
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
1324+
1325+
qs += step/2;
1326+
sc += step;
1327+
dh += step/2;
12931328
}
1329+
1330+
y4 += 16 * QK_K;
12941331
}
12951332
#endif
12961333

1297-
sum[ith] = sumf;
1298-
1299-
//
1300-
// Accumulate the sum from all threads in the threadgroup
1301-
//
1302-
threadgroup_barrier(mem_flags::mem_threadgroup);
1303-
if (ith%4 == 0) {
1304-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1305-
}
1306-
threadgroup_barrier(mem_flags::mem_threadgroup);
1307-
if (ith%16 == 0) {
1308-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1309-
}
1310-
threadgroup_barrier(mem_flags::mem_threadgroup);
1311-
if (ith == 0) {
1312-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1313-
dst[r1*ne0 + r0] = sum[0];
1334+
for (int row = 0; row < N_DST; ++row) {
1335+
all_sum = simd_sum(sumf[row]);
1336+
if (tiisg == 0) {
1337+
dst[r1*ne0 + first_row + row] = all_sum;
1338+
}
13141339
}
13151340
}
13161341

0 commit comments

Comments
 (0)