@@ -8781,16 +8781,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
8781
8781
8782
8782
GGML_ASSERT (dst->backend == GGML_BACKEND_GPU);
8783
8783
8784
+ const int64_t nb11 = src1->nb [1 ];
8785
+ const int64_t nb1 = dst->nb [1 ];
8786
+
8784
8787
const struct ggml_tensor * ids = src0;
8785
8788
const int32_t id = ((int32_t *) dst->op_params )[0 ];
8786
8789
const int32_t n_as = ((int32_t *) dst->op_params )[1 ];
8787
8790
8788
8791
std::vector<char > ids_host (ggml_nbytes (ids));
8789
8792
8793
+ const cudaStream_t stream = g_cudaStreams[g_main_device][0 ];
8794
+
8790
8795
if (ids->backend == GGML_BACKEND_GPU) {
8791
8796
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra )->data_device [g_main_device];
8792
- CUDA_CHECK (cudaMemcpyAsync (ids_host.data (), ids_dev, ggml_nbytes (ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][ 0 ] ));
8793
- CUDA_CHECK (cudaStreamSynchronize (g_cudaStreams[g_main_device][ 0 ] ));
8797
+ CUDA_CHECK (cudaMemcpyAsync (ids_host.data (), ids_dev, ggml_nbytes (ids), cudaMemcpyDeviceToHost, stream ));
8798
+ CUDA_CHECK (cudaStreamSynchronize (stream ));
8794
8799
} else {
8795
8800
memcpy (ids_host.data (), ids->data , ggml_nbytes (ids));
8796
8801
}
@@ -8804,37 +8809,93 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
8804
8809
ggml_tensor src1_row = *src1;
8805
8810
ggml_tensor dst_row = *dst;
8806
8811
8807
- src1_row.ne [ 1 ] = 1 ;
8808
- dst_row.ne [ 1 ] = 1 ;
8812
+ src1_row.extra = &src1_row_extra ;
8813
+ dst_row.extra = &dst_row_extra ;
8809
8814
8810
- src1_row. nb [ 2 ] = src1_row. nb [ 1 ];
8811
- dst_row. nb [ 2 ] = dst_row. nb [ 1 ];
8815
+ char * src1_original = ( char *) src1_extra-> data_device [g_main_device ];
8816
+ char * dst_original = ( char *) dst_extra-> data_device [g_main_device ];
8812
8817
8813
- src1_row.nb [3 ] = src1_row.nb [1 ];
8814
- dst_row.nb [3 ] = dst_row.nb [1 ];
8818
+ if (src1->ne [1 ] == 1 ) {
8819
+ for (int64_t i01 = 0 ; i01 < ids->ne [1 ]; i01++) {
8820
+ // int32_t row_id;
8821
+ // CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
8822
+ // CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
8815
8823
8816
- src1_row.extra = &src1_row_extra;
8817
- dst_row.extra = &dst_row_extra;
8824
+ const int32_t row_id = *(const int32_t *) (ids_host.data () + i01*ids->nb [1 ] + id*ids->nb [0 ]);
8818
8825
8826
+ GGML_ASSERT (row_id >= 0 && row_id < n_as);
8819
8827
8820
- for (int64_t i01 = 0 ; i01 < ids->ne [1 ]; i01++) {
8821
- // int32_t row_id;
8822
- // CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
8823
- // CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
8828
+ const struct ggml_tensor * src0_row = dst->src [row_id + 2 ];
8824
8829
8825
- const int32_t row_id = *(const int32_t *) (ids_host.data () + i01*ids->nb [1 ] + id*ids->nb [0 ]);
8830
+ src1_row_extra.data_device [g_main_device] = src1_original + i01*src1->nb [1 ];
8831
+ src1_row.data = (char *) src1->data + i01*src1->nb [1 ]; // for consistency, not currently used.
8826
8832
8827
- GGML_ASSERT (row_id >= 0 && row_id < n_as);
8833
+ dst_row_extra.data_device [g_main_device] = dst_original + i01*dst->nb [1 ];
8834
+ dst_row.data = (char *) dst->data + i01*dst->nb [1 ]; // for consistency, not currently used.
8828
8835
8829
- const struct ggml_tensor * src0_row = dst->src [row_id + 2 ];
8836
+ ggml_cuda_mul_mat (src0_row, &src1_row, &dst_row);
8837
+ }
8838
+ } else {
8839
+ size_t as_src1, as_dst;
8840
+ char * src1_contiguous = (char *) ggml_cuda_pool_malloc (sizeof (float )*ggml_nelements (src1), &as_src1);
8841
+ char * dst_contiguous = (char *) ggml_cuda_pool_malloc (sizeof (float )*ggml_nelements (dst), &as_dst);
8842
+
8843
+ src1_row_extra.data_device [g_main_device] = src1_contiguous;
8844
+ dst_row_extra.data_device [g_main_device] = dst_contiguous;
8830
8845
8831
- src1_row_extra. data_device [g_main_device] = ( char *) src1_extra-> data_device [g_main_device] + i01*src1-> nb [ 1 ];
8832
- src1_row. data = ( char *) src1-> data + i01*src1-> nb [ 1 ];
8846
+ for ( int32_t row_id = 0 ; row_id < 8 ; ++row_id) {
8847
+ const struct ggml_tensor * src0_row = dst-> src [row_id + 2 ];
8833
8848
8834
- dst_row_extra.data_device [g_main_device] = (char *) dst_extra->data_device [g_main_device] + i01*dst->nb [1 ];
8835
- dst_row.data = (char *) dst->data + i01*dst->nb [1 ];
8849
+ int64_t num_src1_rows = 0 ;
8850
+ for (int64_t i01 = 0 ; i01 < ids->ne [1 ]; i01++) {
8851
+ const int32_t row_id_i = *(const int32_t *) (ids_host.data () + i01*ids->nb [1 ] + id*ids->nb [0 ]);
8852
+
8853
+ if (row_id_i != row_id) {
8854
+ continue ;
8855
+ }
8856
+
8857
+ GGML_ASSERT (row_id >= 0 && row_id < n_as);
8858
+
8859
+ CUDA_CHECK (cudaMemcpyAsync (src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
8860
+ nb11, cudaMemcpyDeviceToDevice, stream));
8861
+ num_src1_rows++;
8862
+ }
8863
+
8864
+ if (num_src1_rows == 0 ) {
8865
+ continue ;
8866
+ }
8867
+
8868
+ src1_row.ne [1 ] = num_src1_rows;
8869
+ dst_row.ne [1 ] = num_src1_rows;
8870
+
8871
+ src1_row.nb [1 ] = nb11;
8872
+ src1_row.nb [2 ] = num_src1_rows*nb11;
8873
+ src1_row.nb [3 ] = num_src1_rows*nb11;
8874
+
8875
+ dst_row.nb [1 ] = nb1;
8876
+ dst_row.nb [2 ] = num_src1_rows*nb1;
8877
+ dst_row.nb [3 ] = num_src1_rows*nb1;
8878
+
8879
+ ggml_cuda_mul_mat (src0_row, &src1_row, &dst_row);
8880
+
8881
+ num_src1_rows = 0 ;
8882
+ for (int64_t i01 = 0 ; i01 < ids->ne [1 ]; i01++) {
8883
+ const int32_t row_id_i = *(const int32_t *) (ids_host.data () + i01*ids->nb [1 ] + id*ids->nb [0 ]);
8884
+
8885
+ if (row_id_i != row_id) {
8886
+ continue ;
8887
+ }
8888
+
8889
+ GGML_ASSERT (row_id >= 0 && row_id < n_as);
8890
+
8891
+ CUDA_CHECK (cudaMemcpyAsync (dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
8892
+ nb1, cudaMemcpyDeviceToDevice, stream));
8893
+ num_src1_rows++;
8894
+ }
8895
+ }
8836
8896
8837
- ggml_cuda_mul_mat (src0_row, &src1_row, &dst_row);
8897
+ ggml_cuda_pool_free (src1_contiguous, as_src1);
8898
+ ggml_cuda_pool_free (dst_contiguous, as_dst);
8838
8899
}
8839
8900
}
8840
8901
0 commit comments