
Commit 4d53d53

IvanNexesenex authored and committed

cuda: add q8_0->f32 cpy operation (ggml-org#9571)

llama: enable K-shift for quantized KV cache. It will fail on unsupported backends or quant types.

1 parent 72ef3a7 commit 4d53d53
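In short: the CUDA backend gains a q8_0 -> f32 copy kernel (the reverse of the existing f32 -> q8_0 path), and llama.cpp builds on it so the RoPE K-shift can run on a quantized K cache by dequantizing to f32, rotating, and quantizing back.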

File tree

3 files changed, +81 -8 lines changed

ggml/src/ggml-cuda.cu

Lines changed: 3 additions & 0 deletions

@@ -2896,6 +2896,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
                     return true;
                 }
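The supports_op addition is what lets ggml's scheduler route such copies to the CUDA backend at all. A minimal probe of the new capability, sketched against the public ggml / ggml-backend / ggml-cuda API (the buffer size and device index are arbitrary assumptions):

    // Build a dummy GGML_OP_CPY node (q8_0 source, f32 destination) and ask
    // the CUDA backend whether it can execute it.
    struct ggml_init_params params = {
        /*.mem_size   =*/ 1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // metadata only, no tensor data needed
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_Q8_0, 32); // one q8_0 block
    struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  32);
    struct ggml_tensor * cpy = ggml_cpy(ctx, src, dst);

    ggml_backend_t backend = ggml_backend_cuda_init(0); // device 0
    bool supported = ggml_backend_supports_op(backend, cpy); // true after this commit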

ggml/src/ggml-cuda/cpy.cu

Lines changed: 51 additions & 0 deletions

@@ -88,6 +88,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
     }
 }
 
+static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
+    const block_q8_0 * xi = (const block_q8_0 *) cxi;
+    float * dsti = (float *) cdsti;
+
+    const float d = (float)xi->d;
+
+    for (int j = 0; j < QK8_0; j++) {
+        dsti[j] = xi->qs[j] * d;
+    }
+}
+
 static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     block_q4_0 * dsti = (block_q4_0 *) cdsti;
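For context, a q8_0 block is QK8_0 = 32 signed 8-bit quants sharing one fp16 scale d, so dequantization is a single multiply per element. A host-side sketch of the same math (the struct is an illustrative stand-in, not ggml's actual block_q8_0 declaration):

    #include <stdint.h>
    #include "ggml.h" // ggml_fp16_t, ggml_fp16_to_fp32()

    // Illustrative stand-in for ggml's block_q8_0 layout.
    typedef struct {
        ggml_fp16_t d;      // per-block fp16 scale
        int8_t      qs[32]; // QK8_0 quantized values
    } q8_0_block_sketch;

    // CPU reference for what cpy_blck_q8_0_f32 computes on the device.
    static void dequantize_q8_0_ref(const q8_0_block_sketch * x, float * y) {
        const float d = ggml_fp16_to_fp32(x->d);
        for (int j = 0; j < 32; ++j) {
            y[j] = x->qs[j] * d; // same per-element math as the kernel
        }
    }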
@@ -337,6 +348,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
 static void ggml_cpy_f16_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
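The indexing mirrors cpy_f32_q with the block arithmetic moved to the source side: each thread owns a flat index i that is a multiple of qk, reads the q8_0 block containing it, and writes qk consecutive floats. A host-side trace of the offset math for the contiguous 1-D case, under assumed byte strides (nb00 = sizeof(block_q8_0) = 34, i.e. a 2-byte fp16 scale plus 32 quants; nb10 = sizeof(float) = 4):

    #include <stdio.h>

    // ne = total number of f32 elements, assumed to be a multiple of 32.
    static void trace_offsets_q8_0_f32(int ne) {
        for (int i = 0; i < ne; i += 32) {        // qk = QK8_0 = 32
            const int x_offset   = (i / 32) * 34; // byte offset of source block i/32
            const int dst_offset = i * 4;         // byte offset of the first output float
            // the kernel then runs cpy_blck_q8_0_f32(cx + x_offset, cdst + dst_offset),
            // expanding the whole block: dst[i .. i+31] = d * qs[0 .. 31]
            printf("i=%5d  src block %4d @ byte %6d -> dst @ byte %6d\n",
                   i, i/32, x_offset, dst_offset);
        }
    }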
@@ -388,6 +425,16 @@ static void ggml_cpy_f32_q8_0_cuda(
     (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q8_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
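A note on the launch configuration above: num_blocks = ne starts one single-thread CUDA block per element, yet each thread that passes the i >= ne guard covers qk = QK8_0 elements, so every block with blockIdx.x*qk >= ne exits immediately. The copy is therefore correct, but it launches roughly qk times more blocks than it strictly needs.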
@@ -509,6 +556,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_f32_bf16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -547,6 +596,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {

src/llama.cpp

Lines changed: 27 additions & 8 deletions

@@ -8910,17 +8910,36 @@ struct llm_build_context {
             const int64_t n_head_kv = hparams.n_head_kv(il);
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             struct ggml_tensor * rope_factors = build_rope_factors(il);
-            struct ggml_tensor * tmp =
+            struct ggml_tensor * k =
+                ggml_view_3d(ctx0, kv_self.k_l[il],
+                    n_embd_head_k, n_head_kv, n_ctx,
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                    0);
+
+            struct ggml_tensor * tmp;
+            if (ggml_is_quantized(k->type)) {
+                // dequantize to f32 -> RoPE -> quantize back
+                tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
+                cb(tmp, "K_f32", il);
+                for (auto * backend : lctx.backends) {
+                    // Figure out which backend KV cache belongs to
+                    if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft)) {
+                        ggml_backend_sched_set_tensor_backend(lctx.sched, tmp, backend);
+                        break;
+                    }
+                }
+                tmp = ggml_rope_ext_inplace(ctx0, tmp,
+                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(tmp, "K_shifted_f32", il);
+                tmp = ggml_cpy(ctx0, tmp, k);
+            } else {
                 // we rotate only the first n_rot dimensions
-                ggml_rope_ext_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k_l[il],
-                            n_embd_head_k, n_head_kv, n_ctx,
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                            0),
+                tmp = ggml_rope_ext_inplace(ctx0, k,
                         lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
-
+            }
             cb(tmp, "K_shifted", il);
             ggml_build_forward_expand(gf, tmp);
         }
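In practice this path is exercised when the K cache is stored quantized (for example via llama.cpp's --cache-type-k q8_0 option) and a context shift triggers the K-shift graph (build_k_shift): the cache view is cast to f32 with ggml_cast (the new q8_0 -> f32 copy op above), rotated in place, then written back with ggml_cpy (the pre-existing f32 -> q8_0 copy). As the commit message warns, backends or cache quant types without the matching copy support will fail rather than silently degrade.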
