Skip to content

Commit 27ad57a

Browse files
ikawrakowKawrakow
andauthored
Metal: faster Q4_0 and Q4_1 matrix x vector kernels (#2212)
* 3-5% faster Q4_0 on Metal * 7-25% faster Q4_1 on Metal * Oops, forgot to delete the original Q4_1 kernel --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 32c5411 commit 27ad57a

File tree

2 files changed

+105
-77
lines changed

2 files changed

+105
-77
lines changed

ggml-metal.m

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -739,12 +739,8 @@ void ggml_metal_graph_compute(
739739
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
740740
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
741741

742-
if (src0t == GGML_TYPE_Q4_0) {
743-
[encoder dispatchThreadgroups:MTLSizeMake(ne01 / 8+((ne01 % 8) & 0x01), ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
744-
}
745-
else if (src0t == GGML_TYPE_Q4_1) {
746-
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
747-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
742+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
743+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
748744
}
749745
else if (src0t == GGML_TYPE_Q2_K ||
750746
src0t == GGML_TYPE_Q3_K ||

ggml-metal.metal

Lines changed: 103 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -395,49 +395,63 @@ kernel void kernel_mul_mat_q4_0_f32(
395395
// each thread in a SIMD group deals with 1 block.
396396
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
397397

398+
float sumy = 0;
398399
for (int i = 0; i < QK4_0 / 4; i++) {
399400
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
401+
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
400402
}
403+
sumy *= (-8.f);
401404

402405
for (int row = 0; row < N_DST; row++) {
403406
// prefetch next x block
404407
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
405408

406409
// calculate
407410
float d = qb_curr.d;
408-
float2 acc = {0.0f, 0.0f};
411+
float acc = sumy;
409412
for (int i = 0; i < 16; i++) {
410-
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
411-
acc[1] += yl[i] + yl[i+16];
413+
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
412414
}
413-
sumf[row] += d * (acc[0] - 8.f*acc[1]);
415+
sumf[row] += d * acc;
414416
qb_curr = qb_next;
415417
}
416418
}
417419

418-
for (int i = 0; i < QK4_0 / 4; i++) {
419-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
420-
}
421-
422-
for (int row = 0; row < N_DST; row++) {
423-
// prefetch next x block
424-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
425-
426-
// calculate
427-
float d = qb_curr.d;
428-
float2 acc = {0.0f, 0.0f};
429-
for (int i = 0; i < 16; i++) {
430-
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
431-
acc[1] += yl[i] + yl[i+16];
420+
if (nb % N_SIMDWIDTH == 0) {
421+
for (int row = 0; row < N_DST; ++row) {
422+
all_sum = simd_sum(sumf[row]);
423+
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
424+
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
425+
}
432426
}
433-
if (tiisg < nb % N_SIMDWIDTH) {
434-
sumf[row] += d * (acc[0] - 8.f*acc[1]);
427+
} else {
428+
429+
float sumy = 0;
430+
for (int i = 0; i < QK4_0 / 4; i++) {
431+
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
432+
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
435433
}
436-
qb_curr = qb_next;
434+
sumy *= (-8.f);
435+
436+
for (int row = 0; row < N_DST; row++) {
437+
// prefetch next x block
438+
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
439+
440+
// calculate
441+
float d = qb_curr.d;
442+
float acc = sumy;
443+
for (int i = 0; i < 16; i++) {
444+
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
445+
}
446+
if (tiisg < nb % N_SIMDWIDTH) {
447+
sumf[row] += d * acc;
448+
}
449+
qb_curr = qb_next;
437450

438-
all_sum = simd_sum(sumf[row]);
439-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
440-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
451+
all_sum = simd_sum(sumf[row]);
452+
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
453+
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
454+
}
441455
}
442456
}
443457
}
@@ -449,65 +463,83 @@ kernel void kernel_mul_mat_q4_1_f32(
449463
constant int64_t & ne00,
450464
constant int64_t & ne10,
451465
constant int64_t & ne0,
452-
threadgroup float * sum [[threadgroup(0)]],
466+
constant int64_t & ne01[[buffer(4)]],
453467
uint2 tgpig[[threadgroup_position_in_grid]],
454-
uint2 tpitg[[thread_position_in_threadgroup]],
455-
uint2 tptg[[threads_per_threadgroup]]) {
456-
const int nb = ne00/QK4_1;
457-
458-
const int64_t r0 = tgpig.x;
459-
const int64_t r1 = tgpig.y;
460-
461-
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
468+
uint tiisg[[thread_index_in_simdgroup]],
469+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
470+
const int nb = ne00/QK4_0;
471+
const int r0 = tgpig.x;
472+
const int r1 = tgpig.y;
473+
device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
462474
device const float * y = (device const float *) src1 + r1*ne10;
475+
block_q4_1 qb_curr, qb_next;
476+
float4 y_curr[8]; // src1 vector cache
477+
float sumf[N_DST]={0.f}, all_sum;
478+
thread float * yl=(thread float *)y_curr;
463479

464-
const uint nth = tptg.x*tptg.y;
465-
const uint ith = tptg.y*tpitg.x + tpitg.y;
466-
467-
const int ix = tpitg.y/4; // 0 or 1
468-
const int iy = tpitg.y - 4*ix; // 0...3
469-
470-
const int first = 4 * iy;
471-
472-
float sumf = 0;
473-
474-
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
475-
476-
const float d = (float)x[i].d;
477-
const float m = (float)x[i].m;
480+
// bootstrap
481+
qb_curr = x[tiisg];
482+
// each thread in a SIMD group deals with 1 block.
483+
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
478484

479-
device const uint8_t * xl = x[i].qs + first;
480-
device const float * yl = y + i * QK4_1 + first;
485+
float sumy = 0;
486+
for (int i = 0; i < QK4_0 / 4; i++) {
487+
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
488+
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
489+
}
481490

482-
float2 acc = {0.0f, 0.0f};
491+
for (int row = 0; row < N_DST; row++) {
492+
// prefetch next x block
493+
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
483494

484-
for (int j = 0; j < 4; ++j) {
495+
// calculate
496+
const float d = qb_curr.d;
497+
const float m = qb_curr.m;
498+
float acc = 0.f;
499+
for (int i = 0; i < 16; i++) {
500+
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
501+
}
502+
sumf[row] += d * acc + m * sumy;
503+
qb_curr = qb_next;
504+
}
505+
}
485506

486-
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
487-
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
507+
if (nb % N_SIMDWIDTH == 0) {
508+
for (int row = 0; row < N_DST; ++row) {
509+
all_sum = simd_sum(sumf[row]);
510+
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
511+
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
512+
}
513+
}
514+
} else {
488515

516+
float sumy = 0;
517+
for (int i = 0; i < QK4_0 / 4; i++) {
518+
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
519+
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
489520
}
490521

491-
sumf += acc[0] + acc[1];
492-
}
522+
for (int row = 0; row < N_DST; row++) {
523+
// prefetch next x block
524+
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
493525

494-
sum[ith] = sumf;
526+
// calculate
527+
const float d = qb_curr.d;
528+
const float m = qb_curr.m;
529+
float acc = 0.f;
530+
for (int i = 0; i < 16; i++) {
531+
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
532+
}
533+
if (tiisg < nb % N_SIMDWIDTH) {
534+
sumf[row] += d * acc + m * sumy;
535+
}
536+
qb_curr = qb_next;
495537

496-
//
497-
// Accumulate the sum from all threads in the threadgroup
498-
//
499-
threadgroup_barrier(mem_flags::mem_threadgroup);
500-
if (ith%4 == 0) {
501-
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
502-
}
503-
threadgroup_barrier(mem_flags::mem_threadgroup);
504-
if (ith%16 == 0) {
505-
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
506-
}
507-
threadgroup_barrier(mem_flags::mem_threadgroup);
508-
if (ith == 0) {
509-
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
510-
dst[r1*ne0 + r0] = sum[0];
538+
all_sum = simd_sum(sumf[row]);
539+
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
540+
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
541+
}
542+
}
511543
}
512544
}
513545

0 commit comments

Comments
 (0)