Skip to content

Commit 9789cd2

Browse files
ggerganov and akawrykow
authored and committed
metal : add Q8_0 support (ggml-org#2763)
* metal : add dequantize_q8_0 kernel
* metal : add mul_mat_q8_0_f32 kernel
* metal : add Q8_0 mul_mm kernel
1 parent 8282d03 commit 9789cd2

File tree

2 files changed

+106
-10
lines changed

2 files changed

+106
-10
lines changed

ggml-metal.m

Lines changed: 20 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@
6363
GGML_METAL_DECL_KERNEL(get_rows_f16);
6464
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
6565
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
66+
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
6667
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
6768
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
6869
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -73,6 +74,7 @@
7374
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
7475
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
7576
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
77+
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
7678
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
7779
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
7880
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -81,6 +83,7 @@
8183
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
8284
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
8385
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
86+
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
8487
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
8588
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
8689
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -188,6 +191,7 @@ @implementation GGMLMetalClass
188191
GGML_METAL_ADD_KERNEL(get_rows_f16);
189192
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
190193
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
194+
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
191195
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
192196
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
193197
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -198,13 +202,15 @@ @implementation GGMLMetalClass
198202
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
199203
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
200204
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
205+
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
201206
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
202207
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
203208
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
204209
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
205210
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
206211
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
207212
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
213+
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
208214
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
209215
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
210216
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -747,9 +753,10 @@ void ggml_metal_graph_compute(
747753
ne00%32 == 0 &&
748754
ne11 > 1) {
749755
switch (src0->type) {
750-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
756+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
751757
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
752758
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
759+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
753760
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
754761
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
755762
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
@@ -800,6 +807,15 @@ void ggml_metal_graph_compute(
800807
nth1 = 8;
801808
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
802809
} break;
810+
case GGML_TYPE_Q8_0:
811+
{
812+
GGML_ASSERT(ne02 == 1);
813+
GGML_ASSERT(ne12 == 1);
814+
815+
nth0 = 8;
816+
nth1 = 8;
817+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
818+
} break;
803819
case GGML_TYPE_Q2_K:
804820
{
805821
GGML_ASSERT(ne02 == 1);
@@ -871,7 +887,7 @@ void ggml_metal_graph_compute(
871887
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
872888
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
873889

874-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
890+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
875891
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
876892
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
877893
}
@@ -896,9 +912,10 @@ void ggml_metal_graph_compute(
896912
case GGML_OP_GET_ROWS:
897913
{
898914
switch (src0->type) {
899-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
915+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
900916
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
901917
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
918+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
902919
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
903920
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
904921
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;

ggml-metal.metal

Lines changed: 86 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,12 @@ typedef struct {
1818
uint8_t qs[QK4_1 / 2]; // nibbles / quants
1919
} block_q4_1;
2020

21+
#define QK8_0 32 // number of quants stored in one block_q8_0
22+
typedef struct { // one Q8_0 block: a half-precision scale followed by QK8_0 signed 8-bit quants
23+
half d; // delta: per-block scale; the kernels in this file multiply the quants by it
24+
int8_t qs[QK8_0]; // quants: QK8_0 signed 8-bit values
25+
} block_q8_0;
26+
2127
kernel void kernel_add(
2228
device const float * src0,
2329
device const float * src1,
@@ -357,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
357363
const int first_row = (r0 * nsg + sgitg) * nr;
358364
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
359365
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
360-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
366+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
361367
float yl[16]; // src1 vector cache
362368
float sumf[nr]={0.f};
363369

@@ -429,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
429435
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
430436
}
431437

438+
kernel void kernel_mul_mat_q8_0_f32( // dot products of block_q8_0 rows of src0 with f32 rows of src1, written to dst
439+
device const void * src0, // reinterpreted below as an array of block_q8_0
440+
device const float * src1, // f32 input, indexed by r1 and im below
441+
device float * dst, // f32 output, one value per (row, r1, im)
442+
constant int64_t & ne00, // elements per src0 row; divided by QK8_0 to get the block count
443+
constant int64_t & ne01[[buffer(4)]], // row bound checked before the final write
444+
constant int64_t & ne02[[buffer(5)]], // not read in this kernel body
445+
constant int64_t & ne10[[buffer(9)]], // stride (in floats) between src1 rows selected by r1
446+
constant int64_t & ne12[[buffer(11)]], // not read in this kernel body
447+
constant int64_t & ne0[[buffer(15)]], // dst row stride; also used in the src0 offset for im
448+
constant int64_t & ne1[[buffer(16)]], // used in the per-im offsets of src1 and dst
449+
constant uint & gqa[[buffer(17)]], // divisor applied to im when offsetting into src0
450+
uint3 tgpig[[threadgroup_position_in_grid]],
451+
uint tiisg[[thread_index_in_simdgroup]],
452+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
453+
const int nr = N_DST; // rows accumulated by each SIMD group
454+
const int nsg = N_SIMDGROUP; // SIMD-group count used to stride first_row across threadgroups
455+
const int nw = N_SIMDWIDTH; // the block loop below advances by nw/2 blocks per step
456+
457+
const int nb = ne00/QK8_0; // Q8_0 blocks per src0 row
458+
const int r0 = tgpig.x; // threadgroup index along the src0 rows
459+
const int r1 = tgpig.y; // selects the src1 row / dst row
460+
const int im = tgpig.z; // third grid index, used for the per-matrix offsets
461+
const int first_row = (r0 * nsg + sgitg) * nr; // first of the nr consecutive rows this SIMD group handles
462+
const uint offset0 = first_row * nb + im/gqa*(nb*ne0); // offset into src0, counted in blocks
463+
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
464+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
465+
466+
float yl[16]; // local copy of the 16 src1 values for the current half block
467+
float sumf[nr]={0.f}; // per-thread partial sums, one per row
468+
469+
const int ix = tiisg/2; // block index this thread starts at (two threads share a block)
470+
const int il = tiisg%2; // which 16-element half of the block this thread reads (0 or 1)
471+
472+
device const float * yb = y + ix * QK8_0 + 16*il; // src1 values matching block ix, half il
473+
474+
// each thread in a SIMD group deals with half a block.
475+
for (int ib = ix; ib < nb; ib += nw/2) {
476+
for (int i = 0; i < 16; ++i) {
477+
yl[i] = yb[i]; // load once, reuse for all nr rows
478+
}
479+
480+
for (int row = 0; row < nr; row++) {
481+
device const int8_t * qs = x[ib+row*nb].qs + 16*il; // this thread's half of block ib in row first_row+row
482+
float sumq = 0.f;
483+
for (int iq = 0; iq < 16; ++iq) {
484+
sumq += qs[iq] * yl[iq];
485+
}
486+
sumf[row] += sumq*x[ib+row*nb].d; // apply the block scale once per half block
487+
}
488+
489+
yb += QK8_0 * 16; // skip 16 blocks of src1; NOTE(review): equals the nw/2 loop stride only if N_SIMDWIDTH is 32 - confirm
490+
}
491+
492+
for (int row = 0; row < nr; ++row) {
493+
const float tot = simd_sum(sumf[row]); // combine the partial sums of all threads in the SIMD group
494+
if (tiisg == 0 && first_row + row < ne01) { // one thread writes; rows at or beyond ne01 are skipped
495+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
496+
}
497+
}
498+
}
499+
432500
kernel void kernel_mul_mat_f16_f32(
433501
device const char * src0,
434502
device const char * src1,
@@ -480,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
480548
}
481549
}
482550

483-
484551
kernel void kernel_alibi_f32(
485552
device const float * src0,
486553
device float * dst,
@@ -1621,12 +1688,12 @@ template <typename type4x4>
16211688
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
16221689
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
16231690
const half d = il ? (xb->d / 16.h) : xb->d;
1624-
const half m = il ? (-8.h * 16.h) : -8.h;
1691+
const half m = il ? ( -8.h * 16.h) : -8.h;
16251692
const ushort mask0 = il ? 0x00F0 : 0x000F;
16261693
const ushort mask1 = il ? 0xF000 : 0x0F00;
16271694

16281695
for (int i=0;i<8;i++) {
1629-
reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
1696+
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
16301697
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
16311698
}
16321699
}
@@ -1640,11 +1707,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
16401707
const ushort mask1 = il ? 0xF000 : 0x0F00;
16411708

16421709
for (int i=0;i<8;i++) {
1643-
reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
1710+
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
16441711
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
16451712
}
16461713
}
16471714

1715+
template <typename type4x4> // type4x4 is indexed below as reg[row][col] with row, col in 0..3
1716+
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { // writes 16 dequantized values of block xb into reg
1717+
device const int8_t * qs = ((device const int8_t *)xb->qs);
1718+
const half d = xb->d; // per-block scale
1719+
1720+
for (int i=0;i<16;i++) {
1721+
reg[i/4][i%4] = (qs[i + 16*il] * d); // il selects the 16-quant half; it must be 0 or 1 to stay inside qs[QK8_0]
1722+
}
1723+
}
1724+
16481725
template <typename type4x4>
16491726
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
16501727
const half d = xb->d;
@@ -1947,9 +2024,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
19472024
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
19482025
constant uint64_t &, constant uint64_t &, uint, uint, uint);
19492026

1950-
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2027+
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
19512028
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
19522029
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
2030+
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
19532031
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
19542032
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
19552033
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1960,9 +2038,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
19602038
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
19612039
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
19622040

1963-
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2041+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
19642042
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
19652043
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2044+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
19662045
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
19672046
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
19682047
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;

0 commit comments

Comments
 (0)