 #include "cpy.cuh"
+#include "dequantize.cuh"
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
@@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
 }
 
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
-    const block_q8_0 * xi = (const block_q8_0 *) cxi;
-    float * dsti = (float *) cdsti;
-
-    const float d = (float)xi->d;
-
-    for (int j = 0; j < QK8_0; j++) {
-        dsti[j] = xi->qs[j] * d;
+    float * cdstf = (float *)(cdsti);
+
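+    // dequantize_q8_0 yields two consecutive values per call (dq.x = d*qs[j],
+    // dq.y = d*qs[j+1]), so the loop advances two elements at a time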
+#pragma unroll
+    for (int j = 0; j < QK8_0; j += 2) {
+        dfloat2 dq;
+        dequantize_q8_0(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + 1) = dq.y;
     }
 }
 
@@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
     memcpy(dsti->qh, &qh, sizeof(qh));
 }
 
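+// generic quantized block -> F32 copy: dequantizes one qk-element block via the
+// given dequantize kernel (shared by the new Q4_0/Q4_1/Q5_0/Q5_1 -> F32 paths)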
+template<dequantize_kernel_t dequant, int qk>
+static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
+    float * cdstf = (float *)(cdsti);
+
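+    // these dequantize kernels return a pair split across the two halves of a
+    // block (e.g. low/high nibble), so dq.x lands at j and dq.y at j + qk/2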
+#pragma unroll
+    for (int j = 0; j < qk/2; j++) {
+        dfloat2 dq;
+        dequant(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + qk/2) = dq.y;
+    }
+}
 
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
@@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
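+    // same launch configuration as the existing q8_0 -> f32 wrapper:
+    // ne single-thread blocks, consumed by cpy_q_f32 in qk-element chunks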
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {