
Commit 49cec58

metal: use template to reduce size
The template mul_vec_q_n_f32 contains code that aims to maximize q4_0 and q4_1 throughput, but it shouldn't affect future q5_0 and q5_1 implementations.
1 parent 4088df1 commit 49cec58
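The refactor follows the standard overload-dispatch pattern: the throughput-oriented logic lives once in a templated worker, each block type supplies its own overload of a small dot-product helper, and the public kernels shrink to one-line instantiations. A minimal standalone C++ sketch of that shape (simplified stand-in types and values, not the Metal code itself):

#include <cstdio>

// stand-ins for the real quant block types (illustrative only)
struct block_q4_0 { float d; };
struct block_q4_1 { float d, m; };

// per-type overloads, like block_q_n_dot_y in the diff: ordinary overload
// resolution picks the right one when the template is instantiated
float block_dot(block_q4_0 b, float sumy) { return b.d * sumy * -8.f; }
float block_dot(block_q4_1 b, float sumy) { return b.d + b.m * sumy; }

// the shared worker: written once, independent of the block type
template <typename block_q_type>
float mul_vec_q_n(block_q_type b, float sumy) {
    return block_dot(b, sumy);
}

int main() {
    // the public entry points reduce to one-line instantiations
    printf("%f\n", mul_vec_q_n<block_q4_0>({0.5f}, 4.f));
    printf("%f\n", mul_vec_q_n<block_q4_1>({0.5f, 0.25f}, 4.f));
}

Adding a q5_0 or q5_1 path later would follow the same recipe: one more overload of the helper plus one more thin kernel wrapper.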

File tree

1 file changed: +50 −123 lines changed

ggml-metal.metal

Lines changed: 50 additions & 123 deletions
@@ -384,27 +384,44 @@ kernel void kernel_rms_norm(
     }
 }
 
+// function for calculating the inner product between a q4_0 block and 32 floats (yl); sumy is SUM(yl[i])
+float block_q_n_dot_y(block_q4_0 qb_curr, float sumy, thread float * yl) {
+    float d = qb_curr.d;
+    float acc = sumy * -8.f;
+    for (int i = 0; i < 16; i+=2) {
+        acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+        acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
+    }
+    return d * acc;
+}
+
+// function for calculating the inner product between a q4_1 block and 32 floats (yl); sumy is SUM(yl[i])
+float block_q_n_dot_y(block_q4_1 qb_curr, float sumy, thread float * yl) {
+    float d = qb_curr.d;
+    float m = qb_curr.m;
+    float acc = 0.f;
+    for (int i = 0; i < 16; i+=2) {
+        acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+        acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
+    }
+    return d * acc + m * sumy;
+}
+
 // putting them in the kernel causes a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-kernel void kernel_mul_mat_q4_0_f32(
-        device const void * src0,
-        device const float * src1,
-        device float * dst,
-        constant int64_t & ne00,
-        constant int64_t & ne10,
-        constant int64_t & ne0,
-        constant int64_t & ne01[[buffer(4)]],
-        uint2 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+template<typename block_q_type>
+void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
+                     int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
+                     uint2 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
     device const float * y = (device const float *) src1 + r1*ne10;
-    block_q4_0 qb_curr, qb_next;
+    block_q_type qb_curr;
     float4 y_curr[8]; // src1 vector cache
     float sumf[N_DST]={0.f}, all_sum;
     thread float * yl=(thread float *)y_curr;
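A note on the q4_0 overload above: a q4_0 weight dequantizes as d * (q - 8), so the block dot product factors as d * (SUM(q[i] * y[i]) - 8 * SUM(y[i])), which is why acc starts at sumy * -8.f instead of the offset being subtracted per weight. A small C++ check of that fold (hypothetical values, quants already unpacked):

#include <cstdio>

int main() {
    const float d    = 0.125f;
    const int   q[4] = {1, 15, 7, 0};           // 4-bit quants, already unpacked
    const float y[4] = {0.5f, -1.f, 2.f, 3.f};

    float ref = 0.f, sumy = 0.f, acc = 0.f;
    for (int i = 0; i < 4; i++) {
        ref  += d * (q[i] - 8) * y[i];          // dequantize-and-dot reference
        sumy += y[i];
        acc  += q[i] * y[i];                    // what the masked-nibble loop accumulates
    }
    acc += sumy * -8.f;                         // the fold, as in the q4_0 overload
    printf("ref=%f folded=%f\n", ref, d * acc); // both print -4.562500
}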
@@ -419,25 +436,15 @@ kernel void kernel_mul_mat_q4_0_f32(
             y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
-        sumy *= (-8.f);
         // we don't right shift packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate for this.
+        // this design is q4_0 and q4_1 centered, but I think most people use these two quantizations.
         for (int i = 0; i < 32; i++) {
             yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
         }
 
         for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            float d = qb_curr.d;
-            float acc = sumy;
-            for (int i = 0; i < 16; i+=2) {
-                acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-                acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-            }
-            sumf[row] += d * acc;
-            qb_curr = qb_next;
+            sumf[row] += block_q_n_dot_y(qb_curr, sumy, yl);
+            qb_curr = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
         }
     }
 
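The scaling loop above implements the trick the comment describes: the packed 4-bit weights are never shifted right, only masked in place (0x000F, 0x00F0, 0x0F00, 0xF000), and the corresponding y values are pre-divided by 16, 256, or 4096 to compensate; the exponent 2 * (i % 2) + i / 16 encodes which nibble position each yl entry pairs with. A standalone C++ check of the equivalence, using a plain sequential nibble layout rather than the kernel's interleaved one:

#include <cstdint>
#include <cstdio>
#include <cmath>

int main() {
    const uint16_t packed = 0xB3A1;                 // four 4-bit weights: 1, 0xA, 3, 0xB
    const float    y[4]   = {0.5f, -1.25f, 2.f, 0.75f};

    float ref = 0.f, noshift = 0.f;
    for (int n = 0; n < 4; n++) {
        const int q = (packed >> (4 * n)) & 0xF;    // conventional: shift, then mask
        ref += y[n] * q;

        const int mask = 0xF << (4 * n);            // mask the nibble in place...
        noshift += y[n] * std::pow(1.f / 16.f, n) * (packed & mask); // ...and pre-scale y
    }
    printf("ref=%f noshift=%f\n", ref, noshift);    // both print 2.250000
}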
@@ -449,32 +456,20 @@ kernel void kernel_mul_mat_q4_0_f32(
             }
         }
     } else {
-
         float sumy = 0;
         for (int i = 0; i < QK4_0 / 4; i++) {
             y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
-        sumy *= (-8.f);
         for (int i = 0; i < 32; i++) {
             yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
         }
 
         for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            float d = qb_curr.d;
-            float acc = sumy;
-            for (int i = 0; i < 16; i+=2) {
-                acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-                acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-            }
             if (tiisg < nb % N_SIMDWIDTH) {
-                sumf[row] += d * acc;
+                sumf[row] += block_q_n_dot_y(qb_curr, sumy, yl);
             }
-            qb_curr = qb_next;
+            qb_curr = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
 
             all_sum = simd_sum(sumf[row]);
             if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
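This second branch covers the tail where nb is not a multiple of N_SIMDWIDTH: every lane still executes the loop so that simd_sum reduces over the whole SIMD group, but only lanes whose index falls below nb % N_SIMDWIDTH accumulate, so out-of-range lanes contribute zero. A scalar C++ sketch of that guard pattern (a 4-lane toy stand-in, not the Metal simd_sum API):

#include <cstdio>

int main() {
    const int WIDTH = 4;                  // stand-in for N_SIMDWIDTH
    const int nb    = 6;                  // 6 blocks -> remainder of 2 on the last pass
    float partial[WIDTH] = {};

    for (int lane = 0; lane < WIDTH; lane++) {
        if (lane < nb % WIDTH) {          // the guard: tiisg < nb % N_SIMDWIDTH
            partial[lane] = 1.f;          // pretend each in-range lane produced 1.0
        }
    }

    float all_sum = 0.f;                  // plays the role of simd_sum(sumf[row])
    for (int lane = 0; lane < WIDTH; lane++) {
        all_sum += partial[lane];
    }
    printf("all_sum=%f\n", all_sum);      // prints 2.000000: only remainder lanes counted
}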
@@ -484,7 +479,7 @@ kernel void kernel_mul_mat_q4_0_f32(
     }
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mat_q4_0_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -495,89 +490,21 @@ kernel void kernel_mul_mat_q4_1_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int nb = ne00/QK4_0;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
-    device const float * y = (device const float *) src1 + r1*ne10;
-    block_q4_1 qb_curr, qb_next;
-    float4 y_curr[8]; // src1 vector cache
-    float sumf[N_DST]={0.f}, all_sum;
-    thread float * yl=(thread float *)y_curr;
-
-    // bootstrap
-    qb_curr = x[tiisg];
-    // each thread in a SIMD group deals with 1 block.
-    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
-        // we don't right shift packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate for this.
-        for (int i = 0; i < 32; i++) {
-            yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
-        }
-
-        for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            const float d = qb_curr.d;
-            const float m = qb_curr.m;
-            float acc = 0.f;
-            for (int i = 0; i < 16; i+=2) {
-                acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-                acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-            }
-            sumf[row] += d * acc + m * sumy;
-            qb_curr = qb_next;
-        }
-    }
-
-    if (nb % N_SIMDWIDTH == 0) {
-        for (int row = 0; row < N_DST; ++row) {
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
-        }
-    } else {
-
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
-        for (int i = 0; i < 32; i++) {
-            yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
-        }
-
-        for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            const float d = qb_curr.d;
-            const float m = qb_curr.m;
-            float acc = 0.f;
-            for (int i = 0; i < 16; i+=2) {
-                acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-                acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-            }
-            if (tiisg < nb % N_SIMDWIDTH) {
-                sumf[row] += d * acc + m * sumy;
-            }
-            qb_curr = qb_next;
+    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+}
 
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
-        }
-    }
+kernel void kernel_mul_mat_q4_1_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne10,
+        constant int64_t & ne0,
+        constant int64_t & ne01[[buffer(4)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_f16_f32(
