Skip to content

Commit 844f4fd

Browse files
retorder loops
1 parent 54da1a2 commit 844f4fd

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

ggml-cuda.cu

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5985,26 +5985,26 @@ static void ggml_cuda_op_mul_mat(
59855985
const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
59865986
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
59875987

5988-
for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
5989-
const int64_t i03 = i0 / ne12;
5990-
const int64_t i02 = i0 % ne12;
5988+
for (int64_t id = 0; id < g_device_count; ++id) {
5989+
if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
5990+
continue;
5991+
}
59915992

5992-
for (int64_t id = 0; id < g_device_count; ++id) {
5993-
if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
5994-
continue;
5995-
}
5993+
const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
5994+
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
5995+
const int64_t row_diff = row_high[id] - row_low[id];
59965996

5997-
const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
5998-
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
5999-
const int64_t row_diff = row_high[id] - row_low[id];
5997+
cudaSetDevice(id);
5998+
const cudaStream_t stream = g_cudaStreams[id][is];
60005999

6001-
cudaSetDevice(id);
6002-
const cudaStream_t stream = g_cudaStreams[id][is];
6000+
// wait for main GPU data if necessary
6001+
if (split && (id != g_main_device || is != 0)) {
6002+
CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0]));
6003+
}
60036004

6004-
// wait for main GPU data if necessary
6005-
if (split && (id != g_main_device || is != 0)) {
6006-
CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0]));
6007-
}
6005+
for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
6006+
const int64_t i03 = i0 / ne12;
6007+
const int64_t i02 = i0 % ne12;
60086008

60096009
const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
60106010

0 commit comments

Comments
 (0)