Commit 76aa30a

ikawrakow and Kawrakow authored
Add ability to use Q5_0, Q5_1, and IQ4_NL for quantized K cache (#6183)
* k_cache: be able to use Q5_0
* k_cache: be able to use Q5_1 on CUDA
* k_cache: be able to use Q5_0 on Metal
* k_cache: be able to use Q5_1 on Metal
* k_cache: be able to use IQ4_NL - just CUDA for now
* k_cache: be able to use IQ4_NL on Metal
* k_cache: add newly added supported types to llama-bench and CUDA supports_op

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent c5b8595 commit 76aa30a
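
As a quick orientation before the diff: a minimal sketch of what each newly supported K-cache type costs per stored value, assuming the standard ggml block layouts (the f16 at 16 bits/value and q8_0 at 8.5 bits/value comparison points are background context, not part of this commit):

// bits-per-value of the newly supported K-cache types (sizes follow the standard
// ggml block layouts; this is an illustrative sketch, not code from this commit)
#include <cstdio>

int main() {
    const struct { const char * name; int block_bytes; int block_values; } types[] = {
        { "q5_0",   22, 32 },  // fp16 d          + 4-byte qh + 16-byte qs
        { "q5_1",   24, 32 },  // fp16 d + fp16 m + 4-byte qh + 16-byte qs
        { "iq4_nl", 18, 32 },  // fp16 d                      + 16-byte qs (non-linear 4-bit codebook)
    };
    for (const auto & t : types) {
        std::printf("%-7s %.2f bits/value\n", t.name, 8.0 * t.block_bytes / t.block_values);
    }
    return 0;
}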

File tree

5 files changed: +424 -15 lines changed


common/common.cpp

Lines changed: 3 additions & 0 deletions
@@ -1590,6 +1590,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
     if (s == "q4_1") {
         return GGML_TYPE_Q4_1;
     }
+    if (s == "iq4_nl") {
+        return GGML_TYPE_IQ4_NL;
+    }
     if (s == "q5_0") {
         return GGML_TYPE_Q5_0;
     }

examples/llama-bench/llama-bench.cpp

Lines changed: 3 additions & 0 deletions
@@ -249,6 +249,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
     if (s == "q5_1") {
         return GGML_TYPE_Q5_1;
     }
+    if (s == "iq4_nl") {
+        return GGML_TYPE_IQ4_NL;
+    }
 
     return GGML_TYPE_COUNT;
 }
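
As the surrounding context shows, ggml_type_from_name falls through to GGML_TYPE_COUNT for names it does not recognize, so callers inside llama-bench.cpp can treat that as the "unknown type" sentinel. A minimal sketch of such a check (the helper name is an assumption, not part of this change):

// Hypothetical helper next to ggml_type_from_name() in llama-bench.cpp:
// accept only cache-type names the bench tool knows how to handle.
static bool is_known_cache_type(const std::string & s) {
    return ggml_type_from_name(s) != GGML_TYPE_COUNT;  // GGML_TYPE_COUNT == "not recognized"
}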

ggml-cuda.cu

Lines changed: 165 additions & 0 deletions
@@ -6757,6 +6757,123 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
     }
 }
 
+static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q5_0 * dsti = (block_q5_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK5_0; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -16;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK5_0/2 + j]*id;
+
+        const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
+        const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
+
+        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+    }
+    memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q5_1 * dsti = (block_q5_1 *) cdsti;
+
+    float min = xi[0];
+    float max = xi[0];
+
+    for (int j = 1; j < QK5_1; ++j) {
+        const float v = xi[j];
+        min = v < min ? v : min;
+        max = v > max ? v : max;
+    }
+
+    const float d  = (max - min) / 31;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->dm.x = d;
+    dsti->dm.y = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1/2; ++j) {
+        const float x0 = (xi[0       + j] - min)*id;
+        const float x1 = (xi[QK5_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+    }
+    memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_NL; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    float d = vmax / kvalues_iq4nl[0];
+    const float id = d ? 1.0f/d : 0.0f;
+
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL/2; ++j) {
+        const float x0 = xi[0        + j]*id;
+        const float x1 = xi[QK4_NL/2 + j]*id;
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
+        dsti->qs[j] = xi0 | (xi1 << 4);
+        const float v0 = kvalues_iq4nl[xi0];
+        const float v1 = kvalues_iq4nl[xi1];
+        const float w0 = xi[0        + j]*xi[0        + j];
+        const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
+        sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
+        sumq2 += w0*v0*v0 + w1*v1*v1;
+    }
+
+    dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
 
 template <cpy_kernel_t cpy_blck, int qk>
 static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
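
A few notes on the new device-side quantizers. Q5_0 and Q5_1 pack the low four bits of each 5-bit code two-per-byte into qs and collect the fifth bits of all 32 values into the 32-bit qh. Q5_0 uses a signed scale d = vmax / -16, so the largest-magnitude value maps to code 0 (dequantized as -16*d), while Q5_1 is affine, storing d = (max - min)/31 together with min. IQ4_NL snaps each scaled value to the nearest entry of the 16-entry non-linear codebook kvalues_iq4nl via the binary search in best_index_int8, then refines the scale with a weighted least-squares fit, d = sum(w_j * v_j * x_j) / sum(w_j * v_j^2) with w_j = x_j^2, which is what the sumqx/sumq2 accumulators compute. Below is a minimal, standalone host-side sketch of the Q5_0 round trip, for illustration only (ggml's real block stores d as fp16; this is not the ggml implementation):

// Host-side round-trip sketch of the Q5_0 packing used by cpy_blck_f32_q5_0 above.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <cstdio>

constexpr int QK5_0 = 32;

struct block_q5_0_ref {
    float   d;            // scale (fp16 in ggml)
    uint8_t qh[4];        // fifth bit of each of the 32 codes
    uint8_t qs[QK5_0/2];  // low 4 bits, two codes per byte
};

static void quantize_q5_0_ref(const float * x, block_q5_0_ref * b) {
    float amax = 0.0f, vmax = 0.0f;
    for (int j = 0; j < QK5_0; ++j) {
        if (amax < std::fabs(x[j])) { amax = std::fabs(x[j]); vmax = x[j]; }
    }
    const float d  = vmax / -16;   // largest-magnitude value maps to code 0 (i.e. -16)
    const float id = d ? 1.0f/d : 0.0f;
    b->d = d;

    uint32_t qh = 0;
    for (int j = 0; j < QK5_0/2; ++j) {
        const int q0 = std::min(31, (int)(x[j]*id + 16.5f));
        const int q1 = std::min(31, (int)(x[QK5_0/2 + j]*id + 16.5f));
        b->qs[j] = (q0 & 0xf) | ((q1 & 0xf) << 4);   // low nibbles
        qh |= ((q0 & 0x10u) >> 4) << (j + 0);        // fifth bits
        qh |= ((q1 & 0x10u) >> 4) << (j + QK5_0/2);
    }
    std::memcpy(b->qh, &qh, sizeof(qh));
}

static float dequantize_q5_0_ref(const block_q5_0_ref * b, int j) {
    uint32_t qh; std::memcpy(&qh, b->qh, sizeof(qh));
    const int lo = j < QK5_0/2 ? (b->qs[j] & 0xf) : (b->qs[j - QK5_0/2] >> 4);
    const int q  = lo | (((qh >> j) & 1) << 4);
    return b->d * (q - 16);
}

int main() {
    float x[QK5_0];
    for (int j = 0; j < QK5_0; ++j) x[j] = 0.1f*(j - 16);  // made-up data
    block_q5_0_ref b;
    quantize_q5_0_ref(x, &b);
    for (int j = 0; j < QK5_0; ++j) {
        std::printf("%2d: %+.3f -> %+.3f\n", j, x[j], dequantize_q5_0_ref(&b, j));
    }
    return 0;
}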
@@ -8490,6 +8607,39 @@ static void ggml_cpy_f32_q4_1_cuda(
     (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_f32_q5_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK5_0 == 0);
+    const int num_blocks = ne / QK5_0;
+    cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q5_1_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK5_1 == 0);
+    const int num_blocks = ne / QK5_1;
+    cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_iq4_nl_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_NL == 0);
+    const int num_blocks = ne / QK4_NL;
+    cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
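
The three launch wrappers follow the same pattern as the existing Q4 ones: assert that the element count is a multiple of the 32-value block size, then launch one single-threaded CUDA block per quantized block. A tiny sketch of that sizing (the ne value below is made up):

// Launch-geometry sketch for the wrappers above: one CUDA block of one thread per
// 32-value quant block. QK5_0 == QK5_1 == QK4_NL == 32 in ggml.
#include <cassert>
#include <cstdio>

int main() {
    const int QK = 32;
    const int ne = 128 * 512;   // hypothetical: head_dim x tokens worth of K-cache floats
    assert(ne % QK == 0);       // mirrors GGML_ASSERT(ne % QK5_0 == 0) in the wrappers
    std::printf("grid: %d blocks x 1 thread\n", ne / QK);
    return 0;
}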
@@ -10888,6 +11038,12 @@ static void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * s
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -11304,6 +11460,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
                 return true;
             }
+            if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
+                return true;
+            }
+            if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
+                return true;
+            }
+            if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
+                return true;
+            }
             if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
                 return true;
             }

ggml-metal.m

Lines changed: 17 additions & 11 deletions
@@ -173,8 +173,9 @@
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
-    //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
-    //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CONCAT,
@@ -598,8 +599,9 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,   cpy_f32_q8_0,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,   cpy_f32_q4_0,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,   cpy_f32_q4_1,   true);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0,   true);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,   cpy_f32_q5_0,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,   cpy_f32_q5_1,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,    cpy_f16_f16,    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,    cpy_f16_f32,    true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,          concat,         true);
@@ -739,6 +741,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
                 case GGML_TYPE_Q8_0:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q5_0:
+                case GGML_TYPE_Q5_1:
+                case GGML_TYPE_IQ4_NL:
                     return true;
                 default:
                     return false;
@@ -2431,13 +2436,14 @@ static enum ggml_status ggml_metal_graph_compute(
                 GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
 
                 switch (dstt) {
-                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;  break;
-                    case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;  break;
-                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
-                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
-                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
-                    //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
-                    //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+                    case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;    break;
+                    case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;    break;
+                    case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline;   break;
+                    case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline;   break;
+                    case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline;   break;
+                    case GGML_TYPE_Q5_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline;   break;
+                    case GGML_TYPE_Q5_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline;   break;
+                    case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
                     default: GGML_ASSERT(false && "not implemented");
                 };
             } break;
