Skip to content

Commit d882d1c

Browse files
Performance no longer terrible
1 parent 4b12881 commit d882d1c

File tree

1 file changed

+30
-2
lines changed

1 file changed

+30
-2
lines changed

ggml-cuda.cu

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,34 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
488488
}
489489
}
490490

491+
static cudaError_t ggml_cuda_h2d_tensor_2d_hack(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream, void * wdata) {
492+
const uint64_t ne0 = src->ne[0];
493+
const uint64_t ne1 = src->ne[1];
494+
const uint64_t nb0 = src->nb[0];
495+
const uint64_t nb1 = src->nb[1];
496+
const uint64_t nb2 = src->nb[2];
497+
const uint64_t nb3 = src->nb[3];
498+
const enum ggml_type type = src->type;
499+
const size_t ts = ggml_type_size(type);
500+
const size_t bs = ggml_blck_size(type);
501+
502+
const void * x = (const void *) ((const char *) wdata + i2*nb2 + i3*nb3);
503+
if (nb0 == ts && nb1 == ts*ne0/bs) {
504+
return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
505+
} else if (nb0 == ts) {
506+
return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
507+
} else {
508+
for (uint64_t i1 = 0; i1 < ne1; i1++) {
509+
const void * rx = (const void *) ((const char *) x + i1*nb1);
510+
void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
511+
// pretend the row is a matrix with cols=1
512+
cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
513+
if (r != cudaSuccess) return r;
514+
}
515+
return cudaSuccess;
516+
}
517+
}
518+
491519
static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
492520
const int64_t ne00 = src0->ne[0];
493521
const int64_t ne01 = src0->ne[1];
@@ -695,13 +723,13 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
695723
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
696724

697725
// copy src1 to device
698-
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
726+
CUDA_CHECK(ggml_cuda_h2d_tensor_2d_hack(c_Y, src1, i03, i02, cudaStream, wdata));
699727

700728
// wait for data
701729
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
702730

703731
// compute
704-
dequantize_mul_mat_q4_0_cuda(c_Q, wdata + i * QK8_0, c_D, ne00, ne01, cudaStream);
732+
dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
705733
CUDA_CHECK(cudaGetLastError());
706734

707735
} else {

0 commit comments

Comments
 (0)