@@ -5985,26 +5985,26 @@ static void ggml_cuda_op_mul_mat(
5985
5985
const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0 ;
5986
5986
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
5987
5987
5988
- for (int64_t i0 = 0 ; i0 < ne13*ne12; ++i0) {
5989
- const int64_t i03 = i0 / ne12;
5990
- const int64_t i02 = i0 % ne12;
5988
+ for (int64_t id = 0 ; id < g_device_count; ++id) {
5989
+ if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
5990
+ continue ;
5991
+ }
5991
5992
5992
- for (int64_t id = 0 ; id < g_device_count; ++id) {
5993
- if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
5994
- continue ;
5995
- }
5993
+ const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
5994
+ const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
5995
+ const int64_t row_diff = row_high[id] - row_low[id];
5996
5996
5997
- const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
5998
- const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
5999
- const int64_t row_diff = row_high[id] - row_low[id];
5997
+ cudaSetDevice (id);
5998
+ const cudaStream_t stream = g_cudaStreams[id][is];
6000
5999
6001
- cudaSetDevice (id);
6002
- const cudaStream_t stream = g_cudaStreams[id][is];
6000
+ // wait for main GPU data if necessary
6001
+ if (split && (id != g_main_device || is != 0 )) {
6002
+ CUDA_CHECK (cudaStreamWaitEvent (stream, src0_extra->events [g_main_device][0 ]));
6003
+ }
6003
6004
6004
- // wait for main GPU data if necessary
6005
- if (split && (id != g_main_device || is != 0 )) {
6006
- CUDA_CHECK (cudaStreamWaitEvent (stream, src0_extra->events [g_main_device][0 ]));
6007
- }
6005
+ for (int64_t i0 = 0 ; i0 < ne13*ne12; ++i0) {
6006
+ const int64_t i03 = i0 / ne12;
6007
+ const int64_t i02 = i0 % ne12;
6008
6008
6009
6009
const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
6010
6010
0 commit comments