Skip to content

Commit 61c8259

Browse files
committed
metal : add mul_mat_q8_0_f32 kernel
1 parent 46a0881 commit 61c8259

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

ggml-metal.m

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
7575
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
7676
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
77+
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
7778
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
7879
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
7980
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -200,6 +201,7 @@ @implementation GGMLMetalClass
200201
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
201202
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
202203
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
204+
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
203205
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
204206
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
205207
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -802,6 +804,15 @@ void ggml_metal_graph_compute(
802804
nth1 = 8;
803805
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
804806
} break;
807+
case GGML_TYPE_Q8_0:
808+
{
809+
GGML_ASSERT(ne02 == 1);
810+
GGML_ASSERT(ne12 == 1);
811+
812+
nth0 = 8;
813+
nth1 = 8;
814+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
815+
} break;
805816
case GGML_TYPE_Q2_K:
806817
{
807818
GGML_ASSERT(ne02 == 1);
@@ -873,7 +884,7 @@ void ggml_metal_graph_compute(
873884
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
874885
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
875886

876-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
887+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
877888
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
878889
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
879890
}

ggml-metal.metal

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
363363
const int first_row = (r0 * nsg + sgitg) * nr;
364364
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
365365
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
366-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
366+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
367367
float yl[16]; // src1 vector cache
368368
float sumf[nr]={0.f};
369369

@@ -435,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
435435
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
436436
}
437437

438+
// Multiplies a Q8_0-quantized matrix (src0) by f32 data (src1) and stores f32 results in dst.
//
// Work split, as expressed by the index math below:
//   - tgpig.x selects a group of rows of src0, tgpig.y selects the src1 row, tgpig.z selects
//     the matrix (`im`); src0 is indexed with im/gqa, so `gqa` consecutive values of `im`
//     share one src0 matrix.
//   - each SIMD group accumulates N_DST consecutive rows starting at `first_row`.
//   - inside a SIMD group, a pair of threads shares one block: tiisg/2 picks the block,
//     tiisg%2 picks which 16-value half of it the thread handles.
//
// Parameters:
//   src0  quantized input, reinterpreted as an array of block_q8_0
//   src1  f32 input
//   dst   f32 output
//   ne00  elements per src0 row (divided by QK8_0 to get blocks per row)
//   ne01  number of src0 rows; rows at or beyond it are neither read nor written
//   ne02  unused by this kernel
//   ne10, ne12, ne0, ne1  extents used for src1/dst addressing (ne12 is unused here)
//   gqa   divisor applied to `im` when locating the src0 matrix
//
// NOTE(review): `yl[16]`, `16*il` and `yb += QK8_0 * 16` presumably rely on QK8_0 == 32 and
// N_SIMDWIDTH == 32 (so that nw/2 == 16 blocks are consumed per iteration) — confirm
// against the definitions of those constants.
kernel void kernel_mul_mat_q8_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    const int nr  = N_DST;
    const int nsg = N_SIMDGROUP;
    const int nw  = N_SIMDWIDTH;

    const int nb = ne00/QK8_0; // blocks per src0 row
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;
    const int first_row = (r0 * nsg + sgitg) * nr;
    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[16];          // this thread's 16 src1 values for the current block half
    float sumf[nr]={0.f};  // per-row partial sums for this thread

    const int ix = tiisg/2; // block index handled by this thread within one iteration
    const int il = tiisg%2; // which half (0 or 1) of the block

    device const float * yb = y + ix * QK8_0 + 16*il;

    // each thread in a SIMD group deals with half a block.
    for (int ib = ix; ib < nb; ib += nw/2) {
        for (int i = 0; i < 16; ++i) {
            yl[i] = yb[i];
        }

        for (int row = 0; row < nr; row++) {
            // Rows past the end of the matrix are never written back (see the guard in the
            // reduction below), so do not read their blocks either: those loads would fall
            // outside the rows belonging to this matrix.
            if (first_row + row >= ne01) {
                break;
            }

            device const block_q8_0 * blk = x + ib + row*nb;
            device const int8_t     * qs  = blk->qs + 16*il;

            float sumq = 0.f;
            for (int iq = 0; iq < 16; ++iq) {
                sumq += qs[iq] * yl[iq];
            }
            // apply the block scale once per half-block
            sumf[row] += sumq*blk->d;
        }

        yb += QK8_0 * 16;
    }

    // Reduce across the SIMD group; every thread must take part in simd_sum, so only the
    // store is restricted to lane 0 and to rows that exist.
    for (int row = 0; row < nr; ++row) {
        const float tot = simd_sum(sumf[row]);
        if (tiisg == 0 && first_row + row < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
        }
    }
}
499+
438500
kernel void kernel_mul_mat_f16_f32(
439501
device const char * src0,
440502
device const char * src1,
@@ -486,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
486548
}
487549
}
488550

489-
490551
kernel void kernel_alibi_f32(
491552
device const float * src0,
492553
device float * dst,
@@ -1653,7 +1714,7 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
16531714

16541715
template <typename type4x4>
16551716
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
1656-
device const uint8_t * qs = ((device const uint8_t *)xb->qs);
1717+
device const int8_t * qs = ((device const int8_t *)xb->qs);
16571718
const half d = xb->d;
16581719

16591720
for (int i=0;i<16;i++) {

0 commit comments

Comments
 (0)