Skip to content

Commit cc178c8

Browse files
gcpmostlyuseful
authored and committed
cuda: Add Q5_1, Q5_0, Q4_1 and Q4_0 to F32 conversion support. (ggml-org#12000)
1 parent b712c96 commit cc178c8

File tree

2 files changed

+104
-7
lines changed

2 files changed

+104
-7
lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "cpy.cuh"
2+
#include "dequantize.cuh"
23

34
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
45

@@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
8283
}
8384

8485
// Dequantize one Q8_0 block (QK8_0 values) from `cxi` into consecutive
// F32 values at `cdsti`. dequantize_q8_0 yields two adjacent quantized
// values per call (q[j] and q[j+1]), so the loop advances by 2.
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    float * dst_f32 = (float *) cdsti;

#pragma unroll
    for (int j = 0; j < QK8_0; j += 2) {
        dfloat2 v;
        dequantize_q8_0(cxi, 0, j, v);
        dst_f32[j + 0] = v.x;
        dst_f32[j + 1] = v.y;
    }
}
9496

@@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
225227
memcpy(dsti->qh, &qh, sizeof(qh));
226228
}
227229

230+
// Generic block dequantizer: expand one quantized block of `qk` values
// into F32 at `cdsti` using the given `dequant` kernel.
// NOTE(review): the supported dequantizers (Q4_0/Q4_1/Q5_0/Q5_1) return
// the value pair (q[j], q[j + qk/2]) for index j, which is why dq.y is
// written into the second half of the output rather than at j+1.
template<dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
    float * dst_f32 = (float *) cdsti;

#pragma unroll
    for (int j = 0; j < qk/2; j++) {
        dfloat2 v;
        dequant(cxi, 0, j, v);
        dst_f32[j]        = v.x;
        dst_f32[j + qk/2] = v.y;
    }
}
228242

229243
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
230244
if (x <= val[0]) return 0;
@@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
387401
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
388402
}
389403

404+
// Host launcher: copy/dequantize a Q4_0 tensor into an F32 tensor on
// `stream`. Launches `ne` single-thread blocks, mirroring the other
// ggml_cpy_*_cuda launchers in this file.
static void ggml_cpy_q4_0_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<ne, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
416+
390417
static void ggml_cpy_f32_q4_1_cuda(
391418
const char * cx, char * cdst, const int ne,
392419
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
398425
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
399426
}
400427

428+
// Host launcher: copy/dequantize a Q4_1 tensor into an F32 tensor on
// `stream`. Launches `ne` single-thread blocks, mirroring the other
// ggml_cpy_*_cuda launchers in this file.
static void ggml_cpy_q4_1_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<ne, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
440+
401441
static void ggml_cpy_f32_q5_0_cuda(
402442
const char * cx, char * cdst, const int ne,
403443
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
409449
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
410450
}
411451

452+
// Host launcher: copy/dequantize a Q5_0 tensor into an F32 tensor on
// `stream`. Launches `ne` single-thread blocks, mirroring the other
// ggml_cpy_*_cuda launchers in this file.
static void ggml_cpy_q5_0_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<ne, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
464+
412465
static void ggml_cpy_f32_q5_1_cuda(
413466
const char * cx, char * cdst, const int ne,
414467
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
420473
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
421474
}
422475

476+
// Host launcher: copy/dequantize a Q5_1 tensor into an F32 tensor on
// `stream`. Launches `ne` single-thread blocks, mirroring the other
// ggml_cpy_*_cuda launchers in this file.
static void ggml_cpy_q5_1_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02,
    const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12,
    const int nb10, const int nb11, const int nb12, const int nb13,
    cudaStream_t stream) {
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<ne, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
488+
423489
static void ggml_cpy_f32_iq4_nl_cuda(
424490
const char * cx, char * cdst, const int ne,
425491
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
488554
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
489555
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
490556
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
557+
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
558+
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
559+
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
491560
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
492561
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
562+
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
563+
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
564+
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
493565
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
494566
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
567+
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
568+
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
569+
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
495570
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
496571
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
497572
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
498573
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
574+
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
575+
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
499576
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
500577
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
501578
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
524601
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
525602
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
526603
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
604+
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
605+
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
527606
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
528607
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
608+
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
609+
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
529610
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
530611
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
612+
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
613+
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
531614
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
532615
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
533616
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
534617
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
618+
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
619+
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
535620
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
536621
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
537622
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3075,15 +3075,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30753075
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
30763076
return true;
30773077
}
3078+
if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
3079+
return true;
3080+
}
30783081
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
30793082
return true;
30803083
}
3084+
if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
3085+
return true;
3086+
}
30813087
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
30823088
return true;
30833089
}
3090+
if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
3091+
return true;
3092+
}
30843093
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
30853094
return true;
30863095
}
3096+
if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
3097+
return true;
3098+
}
30873099
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
30883100
return true;
30893101
}

0 commit comments

Comments
 (0)