Commit 751687c
CUDA: mul_mat_id always on GPU for batches >= 32
1 parent 799fc22 commit 751687c

1 file changed

ggml-cuda.cu: 28 additions & 7 deletions
@@ -8782,8 +8782,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     // TODO: mmq/mmv support
 #endif
 
-    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
-
     const int64_t nb11 = src1->nb[1];
     const int64_t nb1 = dst->nb[1];
 
@@ -8812,13 +8810,24 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
 
+    ggml_backend_type src1_original_backend = src1_row.backend;
+    ggml_backend_type dst_original_backend = dst_row.backend;
+
+    src1_row.backend = GGML_BACKEND_GPU;
+    dst_row.backend = GGML_BACKEND_GPU;
+
     src1_row.extra = &src1_row_extra;
     dst_row.extra = &dst_row_extra;
 
-    char * src1_original = (char *) src1_extra->data_device[g_main_device];
-    char * dst_original = (char *) dst_extra->data_device[g_main_device];
+    char * src1_original = src1_original_backend == GGML_BACKEND_CPU ?
+        (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
+    char * dst_original = dst_original_backend == GGML_BACKEND_CPU ?
+        (char *) dst->data : (char *) dst_extra->data_device[g_main_device];
 
     if (src1->ne[1] == 1) {
+        GGML_ASSERT(src1_original_backend == GGML_BACKEND_GPU);
+        GGML_ASSERT(dst_original_backend == GGML_BACKEND_GPU);
+
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
             //int32_t row_id;
             //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
@@ -8846,6 +8855,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
         src1_row_extra.data_device[g_main_device] = src1_contiguous;
         dst_row_extra.data_device[g_main_device] = dst_contiguous;
 
+        const cudaMemcpyKind src1_kind = src1_original_backend == GGML_BACKEND_CPU ?
+            cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
+        const cudaMemcpyKind dst_kind = src1_original_backend == GGML_BACKEND_CPU ?
+            cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
+
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
             const struct ggml_tensor * src0_row = dst->src[row_id + 2];
 
@@ -8860,7 +8874,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
                GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
-                                          nb11, cudaMemcpyDeviceToDevice, stream));
+                                          nb11, src1_kind, stream));
                num_src1_rows++;
            }
 
@@ -8892,14 +8906,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
                GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
-                                          nb1, cudaMemcpyDeviceToDevice, stream));
+                                          nb1, dst_kind, stream));
                num_src1_rows++;
            }
        }
 
        ggml_cuda_pool_free(src1_contiguous, as_src1);
        ggml_cuda_pool_free(dst_contiguous, as_dst);
    }
+
+    if (dst_original_backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaStreamSynchronize(stream));
+    }
+
+    src1_row.backend = src1_original_backend;
+    dst_row.backend = dst_original_backend;
 }
 
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9298,7 +9319,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
        || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
        || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
 
-    if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
+    if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
        return false;
    }
 
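Editor's note on the pattern: the diff lets mul_mat_id run on the GPU even when src1 and dst still live in host memory by saving the original backends, picking the cudaMemcpyKind for the row gather/scatter based on the original backend, and synchronizing the stream before returning when dst remains on the CPU. The following standalone sketch (not ggml code; the buffer names, sizes, and the src_on_host flag are illustrative assumptions) shows the same staging pattern with the plain CUDA runtime API.

// Standalone sketch, not ggml code: stage a host-resident row into a device
// buffer, pick the copy direction from where the source actually lives, and
// synchronize before the host reads the result.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define CUDA_CHECK(call)                                              \
    do {                                                              \
        cudaError_t err_ = (call);                                    \
        if (err_ != cudaSuccess) {                                    \
            fprintf(stderr, "CUDA error: %s at %s:%d\n",              \
                    cudaGetErrorString(err_), __FILE__, __LINE__);    \
            return 1;                                                 \
        }                                                             \
    } while (0)

int main() {
    const size_t row_bytes = 16*sizeof(float);
    std::vector<float> host_row(16, 1.0f); // row that lives on the host (CPU backend)
    const bool src_on_host = true;         // stands in for backend == GGML_BACKEND_CPU

    float * staging = nullptr;             // device-side contiguous buffer
    CUDA_CHECK(cudaMalloc((void **) &staging, row_bytes));

    cudaStream_t stream;
    CUDA_CHECK(cudaStreamCreate(&stream));

    // pick the copy kind from where the source buffer lives; with a
    // device-resident source the same call would use cudaMemcpyDeviceToDevice
    const cudaMemcpyKind src_kind = src_on_host ? cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
    CUDA_CHECK(cudaMemcpyAsync(staging, host_row.data(), row_bytes, src_kind, stream));

    // ... kernels operating on the staged rows would be launched here ...

    // copy the result back to host memory and wait on the stream before using it,
    // mirroring the cudaStreamSynchronize the commit adds when dst stays on the CPU
    std::vector<float> host_out(16, 0.0f);
    CUDA_CHECK(cudaMemcpyAsync(host_out.data(), staging, row_bytes, cudaMemcpyDeviceToHost, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));

    printf("first element after the round trip: %.1f\n", host_out[0]);

    CUDA_CHECK(cudaStreamDestroy(stream));
    CUDA_CHECK(cudaFree(staging));
    return 0;
}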