@@ -428,7 +428,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
+        CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
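
With this hunk, ggml_cuda_cpy short-circuits the general strided-copy kernels: when source and destination share a type and are both contiguous, the whole tensor is moved with a single cudaMemcpyAsync on the compute stream, and the GGML_ASSERT guarantees the two buffers really span the same number of bytes. Below is a minimal standalone sketch of the same fast-path pattern; the helper names and the toy strided kernel are illustrative, not ggml code:

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    // Minimal stand-in for ggml's CUDA_CHECK: abort on any CUDA error.
    #define CHECK(call) do {                                            \
            cudaError_t err_ = (call);                                  \
            if (err_ != cudaSuccess) {                                  \
                fprintf(stderr, "CUDA error: %s at %s:%d\n",            \
                        cudaGetErrorString(err_), __FILE__, __LINE__);  \
                exit(EXIT_FAILURE);                                     \
            }                                                           \
        } while (0)

    // Fallback path: an element-wise kernel that can handle a strided source.
    __global__ void copy_strided(const float * src, float * dst, int n, int src_stride) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            dst[i] = src[(size_t) i * src_stride];
        }
    }

    // Copy n floats on `stream`, preferring a single DMA transfer when the
    // source is contiguous (stride 1), as the patch above does for ggml tensors.
    static void copy_f32(const float * src, float * dst, int n, int src_stride, cudaStream_t stream) {
        if (src_stride == 1) {
            // Contiguous same-type copy: one async memcpy, no kernel launch.
            CHECK(cudaMemcpyAsync(dst, src, (size_t) n * sizeof(float),
                                  cudaMemcpyDeviceToDevice, stream));
        } else {
            copy_strided<<<(n + 255) / 256, 256, 0, stream>>>(src, dst, n, src_stride);
            CHECK(cudaGetLastError());
        }
    }

    int main() {
        const int n = 1024;
        float * a = nullptr;
        float * b = nullptr;
        CHECK(cudaMalloc(&a, n * sizeof(float)));
        CHECK(cudaMalloc(&b, n * sizeof(float)));
        copy_f32(a, b, n, /*src_stride=*/1, /*stream=*/0);   // takes the memcpy fast path
        CHECK(cudaStreamSynchronize(0));
        CHECK(cudaFree(a));
        CHECK(cudaFree(b));
        return 0;
    }

The assert-then-memcpy ordering in the patch matters: equal ggml_nbytes on a same-type, contiguous pair is exactly the precondition under which a flat byte copy is equivalent to the element-wise copy kernels.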
@@ -461,26 +464,28 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 }
 
 void * ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void *) cpy_f32_f16<cpy_1_f32_f32>;
+    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        return nullptr;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        return (void *) cpy_f32_f16<cpy_1_f32_f32>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void *) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void *) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        return (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+        return (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        return (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+        return (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        return (void *) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+        return (void *) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        return (void *) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+        return (void *) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        return (void *) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+        return (void *) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        return (void *) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+        return (void *) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void *) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void *) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        return (void *) cpy_f32_f16<cpy_1_f16_f32>;
+        return (void *) cpy_f32_f16<cpy_1_f16_f32>;
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
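
ggml_cuda_cpy_fn reports which device function a given copy will be performed with, so returning nullptr for the new fast path tells callers that no kernel is involved at all; any caller must null-check the result before comparing or dereferencing it. A hypothetical caller-side sketch (the dispatch shown is illustrative, not the actual ggml-cuda code):

    // Illustrative only: how a caller might consume ggml_cuda_cpy_fn's result.
    void * fn = ggml_cuda_cpy_fn(src0, src1);
    if (fn == nullptr) {
        // Contiguous same-type copy: performed via cudaMemcpyAsync, so there
        // is no kernel whose function pointer or launch parameters exist.
    } else {
        // fn identifies the templated copy kernel that ggml_cuda_cpy will
        // launch, e.g. (void *) cpy_f32_f16<cpy_1_f32_f32> for F32 -> F32.
    }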