Skip to content

Commit 46a0881

Browse files
committed
metal : add dequantize_q8_0 kernel
1 parent c3e53b4 commit 46a0881

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

ggml-metal.m

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
GGML_METAL_DECL_KERNEL(get_rows_f16);
6464
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
6565
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
66+
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
6667
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
6768
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
6869
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -188,6 +189,7 @@ @implementation GGMLMetalClass
188189
GGML_METAL_ADD_KERNEL(get_rows_f16);
189190
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
190191
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
192+
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
191193
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
192194
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
193195
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -896,9 +898,10 @@ void ggml_metal_graph_compute(
896898
case GGML_OP_GET_ROWS:
897899
{
898900
switch (src0->type) {
899-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
901+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
900902
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
901903
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
904+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
902905
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
903906
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
904907
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;

ggml-metal.metal

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ typedef struct {
1818
uint8_t qs[QK4_1 / 2]; // nibbles / quants
1919
} block_q4_1;
2020

21+
#define QK8_0 32 // number of quants stored in one block_q8_0
22+
typedef struct {
23+
half d; // delta: per-block scale each quant is multiplied by when dequantizing
24+
int8_t qs[QK8_0]; // quants: QK8_0 signed 8-bit values
25+
} block_q8_0;
26+
2127
kernel void kernel_add(
2228
device const float * src0,
2329
device const float * src1,
@@ -1621,12 +1627,12 @@ template <typename type4x4>
16211627
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
16221628
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
16231629
const half d = il ? (xb->d / 16.h) : xb->d;
1624-
const half m = il ? (-8.h * 16.h) : -8.h;
1630+
const half m = il ? ( -8.h * 16.h) : -8.h;
16251631
const ushort mask0 = il ? 0x00F0 : 0x000F;
16261632
const ushort mask1 = il ? 0xF000 : 0x0F00;
16271633

16281634
for (int i=0;i<8;i++) {
1629-
reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
1635+
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
16301636
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
16311637
}
16321638
}
@@ -1640,11 +1646,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
16401646
const ushort mask1 = il ? 0xF000 : 0x0F00;
16411647

16421648
for (int i=0;i<8;i++) {
1643-
reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
1649+
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
16441650
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
16451651
}
16461652
}
16471653

1654+
template <typename type4x4> // type4x4: 4x4 register type that receives the dequantized values
1655+
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { // fills reg with 16 values of block xb; il selects quants [16*il, 16*il + 15]
1656+
device const int8_t * qs = ((device const int8_t *)xb->qs); // quants are signed (int8_t in block_q8_0); a uint8_t view would read negative quants as 128..255
1657+
const half d = xb->d; // per-block scale
1658+
1659+
for (int i=0;i<16;i++) {
1660+
reg[i/4][i%4] = (qs[i + 16*il] * d); // value = signed quant * scale; 4 consecutive quants go into each reg[i/4]
1661+
}
1662+
}
1663+
16481664
template <typename type4x4>
16491665
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
16501666
const half d = xb->d;
@@ -1947,9 +1963,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
19471963
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
19481964
constant uint64_t &, constant uint64_t &, uint, uint, uint);
19491965

1950-
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
1966+
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
19511967
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
19521968
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
1969+
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
19531970
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
19541971
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
19551972
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1960,7 +1977,7 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
19601977
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
19611978
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
19621979

1963-
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
1980+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
19641981
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
19651982
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
19661983
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;

0 commit comments

Comments
 (0)