@@ -8782,8 +8782,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     // TODO: mmq/mmv support
 #endif
 
-    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
-
     const int64_t nb11 = src1->nb[1];
     const int64_t nb1  =  dst->nb[1];
 
@@ -8812,13 +8810,24 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
 
+    ggml_backend_type src1_original_backend = src1_row.backend;
+    ggml_backend_type dst_original_backend  = dst_row.backend;
+
+    src1_row.backend = GGML_BACKEND_GPU;
+    dst_row.backend  = GGML_BACKEND_GPU;
+
     src1_row.extra = &src1_row_extra;
     dst_row.extra = &dst_row_extra;
 
-    char * src1_original = (char *) src1_extra->data_device[g_main_device];
-    char * dst_original  = (char *) dst_extra->data_device[g_main_device];
+    char * src1_original = src1_original_backend == GGML_BACKEND_CPU ?
+        (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
+    char * dst_original  = dst_original_backend == GGML_BACKEND_CPU ?
+        (char *) dst->data : (char *) dst_extra->data_device[g_main_device];
 
     if (src1->ne[1] == 1) {
+        GGML_ASSERT(src1_original_backend == GGML_BACKEND_GPU);
+        GGML_ASSERT(dst_original_backend  == GGML_BACKEND_GPU);
+
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
             // int32_t row_id;
             // CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
@@ -8846,6 +8855,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
         src1_row_extra.data_device[g_main_device] = src1_contiguous;
         dst_row_extra.data_device[g_main_device]  = dst_contiguous;
 
+        const cudaMemcpyKind src1_kind = src1_original_backend == GGML_BACKEND_CPU ?
+            cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
+        const cudaMemcpyKind dst_kind  = dst_original_backend == GGML_BACKEND_CPU ?
+            cudaMemcpyDeviceToHost : cudaMemcpyDeviceToDevice;
+
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
             const struct ggml_tensor * src0_row = dst->src[row_id + 2];
 
@@ -8860,7 +8874,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
                 GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
                 CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
-                                        nb11, cudaMemcpyDeviceToDevice, stream));
+                                        nb11, src1_kind, stream));
                 num_src1_rows++;
             }
 
@@ -8892,14 +8906,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
                 GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
                 CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
-                                        nb1, cudaMemcpyDeviceToDevice, stream));
+                                        nb1, dst_kind, stream));
                 num_src1_rows++;
             }
         }
 
         ggml_cuda_pool_free(src1_contiguous, as_src1);
         ggml_cuda_pool_free(dst_contiguous, as_dst);
     }
+
+    if (dst_original_backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaStreamSynchronize(stream));
+    }
+
+    src1_row.backend = src1_original_backend;
+    dst_row.backend  = dst_original_backend;
 }
 
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9298,7 +9319,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
         || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
 
-    if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
+    if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
         return false;
     }
 