@@ -88,6 +88,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
     }
 }

+static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
+    const block_q8_0 * xi = (const block_q8_0 *) cxi;
+    float * dsti = (float *) cdsti;
+
+    const float d = (float) xi->d;
+
+    for (int j = 0; j < QK8_0; j++) {
+        dsti[j] = xi->qs[j] * d;
+    }
+}
+
 static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     block_q4_0 * dsti = (block_q4_0 *) cdsti;
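For reference, the new `cpy_blck_q8_0_f32` is the inverse of the existing `cpy_blck_f32_q8_0`: a Q8_0 block stores one scale `d` plus `QK8_0` signed 8-bit quants, and dequantization is a single multiply per element. A minimal host-side sketch of the round trip, using a simplified stand-in struct with a `float` scale (ggml's real `block_q8_0` stores `d` as a half-precision `ggml_half`):

// Standalone host-side sketch of the Q8_0 round trip. Assumption: simplified
// block layout with a float scale; ggml stores the scale as a half.
#include <cmath>
#include <cstdint>
#include <cstdio>

#define QK8_0 32

struct block_q8_0_ref {
    float  d;          // per-block scale
    int8_t qs[QK8_0];  // quantized values in [-127, 127]
};

static void quantize_ref(const float * x, block_q8_0_ref * b) {
    float amax = 0.0f;  // absolute maximum of the block
    for (int j = 0; j < QK8_0; j++) amax = fmaxf(amax, fabsf(x[j]));
    b->d = amax / 127.0f;
    const float id = b->d != 0.0f ? 1.0f/b->d : 0.0f;
    for (int j = 0; j < QK8_0; j++) b->qs[j] = (int8_t) roundf(x[j]*id);
}

static void dequantize_ref(const block_q8_0_ref * b, float * y) {
    for (int j = 0; j < QK8_0; j++) y[j] = b->qs[j] * b->d;  // same math as cpy_blck_q8_0_f32
}

int main() {
    float x[QK8_0], y[QK8_0];
    for (int j = 0; j < QK8_0; j++) x[j] = sinf((float) j);
    block_q8_0_ref b;
    quantize_ref(x, &b);
    dequantize_ref(&b, y);
    for (int j = 0; j < QK8_0; j++) printf("%8.5f -> %8.5f\n", x[j], y[j]);
    return 0;
}

Since the quantizer scales by `amax/127`, the only loss in the round trip is the 8-bit rounding step.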
@@ -337,6 +348,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }

+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
 static void ggml_cpy_f16_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
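`cpy_q_f32` mirrors the existing `cpy_f32_q` with the source and destination roles swapped: each thread owns one block of `qk` source elements, so the source byte offset advances by `nb00` once per quantized block (`i00/qk`), while the destination offset uses the element index `i10` directly against the float-sized stride `nb10`. The flat-to-4D decomposition is standard mixed-radix arithmetic and can be sanity-checked on the host; a small standalone check with made-up extents:

// Host-side sanity check of the flat-index decomposition used in cpy_q_f32.
// The extents ne00..ne03 are hypothetical values chosen just for the test.
#include <cassert>

int main() {
    const int ne00 = 64, ne01 = 3, ne02 = 2, ne03 = 2;
    const int ne = ne00*ne01*ne02*ne03;
    for (int i = 0; i < ne; i++) {
        // same decomposition as the kernel
        const int i03 = i/(ne00*ne01*ne02);
        const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
        const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
        const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
        // recomposing the coordinates must reproduce the flat index
        assert(i == ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00);
    }
    return 0;
}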
@@ -388,6 +425,16 @@ static void ggml_cpy_f32_q8_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }

+static void ggml_cpy_q8_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
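One thing worth noting about the wrapper: `num_blocks = ne` launches one single-thread CUDA block per *element*, even though each thread that survives the `i >= ne` guard dequantizes a whole Q8_0 block (with `blockDim.x == 1`, the kernel computes `i = blockIdx.x*QK8_0`, so only the first `ne/QK8_0` blocks do any work). That is correct, just oversized. Purely as a hypothetical alternative, not part of the patch, a tighter grid would look like:

// Hypothetical tighter launch: one CUDA block per Q8_0 block instead of one
// per element. The kernel's i >= ne guard makes both configurations safe.
static void ggml_cpy_q8_0_f32_cuda_tight(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
    const int nb12, const int nb13, cudaStream_t stream) {

    const int num_blocks = (ne + QK8_0 - 1) / QK8_0;  // ne is a multiple of QK8_0 for Q8_0 tensors
    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}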
@@ -509,6 +556,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -547,6 +596,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
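To validate the new path end to end, the GPU result can be compared against a CPU reference on a contiguous tensor. A self-contained smoke test under two stated assumptions: a simplified block with a `float` scale (the real `block_q8_0` uses a half-precision scale), and the per-block copy logic inlined rather than included from cpy.cu:

// Standalone smoke test for the Q8_0 -> F32 copy path (contiguous 1-D case).
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

#define QK8_0 32

struct block_q8_0_t {  // simplified stand-in: float scale instead of half
    float  d;
    int8_t qs[QK8_0];
};

__global__ void dequant_q8_0(const block_q8_0_t * x, float * dst, int ne) {
    const int i = (blockIdx.x*blockDim.x + threadIdx.x)*QK8_0;
    if (i >= ne) return;
    const block_q8_0_t * b = &x[i/QK8_0];
    for (int j = 0; j < QK8_0; j++) dst[i + j] = b->qs[j]*b->d;  // same math as cpy_blck_q8_0_f32
}

int main() {
    const int nblk = 8, ne = nblk*QK8_0;
    block_q8_0_t h_x[nblk];
    float h_ref[ne], h_out[ne];
    for (int b = 0; b < nblk; b++) {
        h_x[b].d = 0.01f*(b + 1);
        for (int j = 0; j < QK8_0; j++) {
            h_x[b].qs[j] = (int8_t)(j - 16);
            h_ref[b*QK8_0 + j] = h_x[b].qs[j]*h_x[b].d;  // CPU reference
        }
    }
    block_q8_0_t * d_x; float * d_out;
    cudaMalloc(&d_x, sizeof(h_x));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
    dequant_q8_0<<<nblk, 1>>>(d_x, d_out, ne);  // one single-thread block per Q8_0 block
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    int bad = 0;
    for (int i = 0; i < ne; i++) bad += h_out[i] != h_ref[i];
    printf("%s (%d mismatches)\n", bad == 0 ? "OK" : "FAIL", bad);
    cudaFree(d_x); cudaFree(d_out);
    return bad != 0;
}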