@@ -488,6 +488,34 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
488
488
}
489
489
}
490
490
491
// Asynchronously copy one 2D slice (indices i3, i2) of a tensor to contiguous
// device memory, reading the host-side bytes from `wdata` rather than from
// src->data.
//
// NOTE(review): "hack" variant of ggml_cuda_h2d_tensor_2d — the strides
// (nb0..nb3) are taken from `src` while the actual bytes come from `wdata`;
// this assumes the caller laid out `wdata` with the same strides as `src` —
// TODO confirm at the call site.
//
// Returns the cudaError_t of the first failing copy, or cudaSuccess.
static cudaError_t ggml_cuda_h2d_tensor_2d_hack(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream, void * wdata) {
    const enum ggml_type type = src->type;
    const size_t ts = ggml_type_size(type); // bytes per block of `type`
    const size_t bs = ggml_blck_size(type); // elements per block of `type`

    const uint64_t ne0 = src->ne[0];
    const uint64_t ne1 = src->ne[1];
    const uint64_t nb0 = src->nb[0];
    const uint64_t nb1 = src->nb[1];
    const uint64_t nb2 = src->nb[2];
    const uint64_t nb3 = src->nb[3];

    // base address of the (i3, i2) slice inside wdata, using src's strides
    const char * base = (const char *) wdata + i2*nb2 + i3*nb3;

    if (nb0 == ts && nb1 == ts*ne0/bs) {
        // fully contiguous slice: a single flat copy of all ne1 rows
        return cudaMemcpyAsync(dst, base, ne1*nb1, cudaMemcpyHostToDevice, stream);
    }

    if (nb0 == ts) {
        // rows are internally contiguous but padded: a strided 2D copy
        // packs them tightly on the device
        return cudaMemcpy2DAsync(dst, ts*ne0/bs, base, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
    }

    // elements within a row are strided too: copy row by row,
    // pretending each row is a matrix with cols=1
    // NOTE(review): ts/bs is integer division — it truncates for
    // block-quantized types (bs > 1); presumably this path is only reached
    // for non-blocked element types — confirm.
    for (uint64_t row = 0; row < ne1; row++) {
        const void * rx = (const void *) (base + row*nb1);
        void       * rd = (void *) ((char *) dst + row*ts*ne0/bs);
        cudaError_t err = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
        if (err != cudaSuccess) {
            return err;
        }
    }
    return cudaSuccess;
}
518
+
491
519
static void ggml_cuda_mul_mat_f32 (const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
492
520
const int64_t ne00 = src0->ne [0 ];
493
521
const int64_t ne01 = src0->ne [1 ];
@@ -695,13 +723,13 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
695
723
CUDA_CHECK (cudaEventRecord (cudaEvent, cudaStream2));
696
724
697
725
// copy src1 to device
698
- CUDA_CHECK (ggml_cuda_h2d_tensor_2d (c_Y, src1, i03, i02, cudaStream));
726
+ CUDA_CHECK (ggml_cuda_h2d_tensor_2d_hack (c_Y, src1, i03, i02, cudaStream, wdata ));
699
727
700
728
// wait for data
701
729
CUDA_CHECK (cudaStreamWaitEvent (cudaStream, cudaEvent, 0 ));
702
730
703
731
// compute
704
- dequantize_mul_mat_q4_0_cuda (c_Q, wdata + i * QK8_0 , c_D, ne00, ne01, cudaStream);
732
+ dequantize_mul_mat_q4_0_cuda (c_Q, c_Y , c_D, ne00, ne01, cudaStream);
705
733
CUDA_CHECK (cudaGetLastError ());
706
734
707
735
} else {
0 commit comments