Skip to content

Commit fcd0c2c

Browse files
CUDA: mul_mat_id always on GPU for batches >= 32
1 parent 799fc22 commit fcd0c2c

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

ggml-cuda.cu

Lines changed: 22 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8782,8 +8782,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
87828782
// TODO: mmq/mmv support
87838783
#endif
87848784

8785-
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
8786-
87878785
const int64_t nb11 = src1->nb[1];
87888786
const int64_t nb1 = dst->nb[1];
87898787

@@ -8812,13 +8810,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88128810
ggml_tensor src1_row = *src1;
88138811
ggml_tensor dst_row = *dst;
88148812

8813+
src1_row.backend = GGML_BACKEND_GPU;
8814+
dst_row.backend = GGML_BACKEND_GPU;
8815+
88158816
src1_row.extra = &src1_row_extra;
88168817
dst_row.extra = &dst_row_extra;
88178818

8818-
char * src1_original = (char *) src1_extra->data_device[g_main_device];
8819-
char * dst_original = (char *) dst_extra->data_device[g_main_device];
8819+
char * src1_original = src1->backend == GGML_BACKEND_CPU ?
8820+
(char *) src1->data : (char *) src1_extra->data_device[g_main_device];
8821+
char * dst_original = dst->backend == GGML_BACKEND_CPU ?
8822+
(char *) dst->data : (char *) dst_extra->data_device[g_main_device];
88208823

88218824
if (src1->ne[1] == 1) {
8825+
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
8826+
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
8827+
88228828
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
88238829
//int32_t row_id;
88248830
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
@@ -8846,6 +8852,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88468852
src1_row_extra.data_device[g_main_device] = src1_contiguous;
88478853
dst_row_extra.data_device[g_main_device] = dst_contiguous;
88488854

8855+
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
8856+
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
8857+
const cudaMemcpyKind dst_kind = dst->backend == GGML_BACKEND_CPU ?
8858+
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
8859+
88498860
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
88508861
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
88518862

@@ -8860,7 +8871,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88608871
GGML_ASSERT(row_id >= 0 && row_id < n_as);
88618872

88628873
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
8863-
nb11, cudaMemcpyDeviceToDevice, stream));
8874+
nb11, src1_kind, stream));
88648875
num_src1_rows++;
88658876
}
88668877

@@ -8892,14 +8903,18 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88928903
GGML_ASSERT(row_id >= 0 && row_id < n_as);
88938904

88948905
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
8895-
nb1, cudaMemcpyDeviceToDevice, stream));
8906+
nb1, dst_kind, stream));
88968907
num_src1_rows++;
88978908
}
88988909
}
88998910

89008911
ggml_cuda_pool_free(src1_contiguous, as_src1);
89018912
ggml_cuda_pool_free(dst_contiguous, as_dst);
89028913
}
8914+
8915+
if (dst->backend == GGML_BACKEND_CPU) {
8916+
CUDA_CHECK(cudaStreamSynchronize(stream));
8917+
}
89038918
}
89048919

89058920
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9298,7 +9313,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
92989313
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
92999314
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
93009315

9301-
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
9316+
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
93029317
return false;
93039318
}
93049319

0 commit comments

Comments (0)