Skip to content

Commit 1659cd1

Browse files
committed
fix mixtral
1 parent 6f35a4a commit 1659cd1

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

ggml-cuda.cu

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8258,8 +8258,18 @@ static void ggml_cuda_op_mul_mat(
82588258
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
82598259
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
82608260
dhf_dst_i += src1_col_0*ne0 + row_low[id];
8261-
CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float),
8262-
row_diff*sizeof(float), src1_ncols, kind, stream));
8261+
8262+
if (kind == cudaMemcpyDeviceToDevice && id != g_main_device) {
8263+
// there is no cudaMemcpy2DPeerAsync so we need to copy each row separately
8264+
for (int64_t i = 0; i < src1_ncols; ++i) {
8265+
CUDA_CHECK(cudaMemcpyPeerAsync(dhf_dst_i + i*ne0, g_main_device,
8266+
dst_dd_i + i*row_diff, id,
8267+
row_diff*sizeof(float), stream));
8268+
}
8269+
} else {
8270+
CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float),
8271+
row_diff*sizeof(float), src1_ncols, kind, stream));
8272+
}
82638273
} else {
82648274
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
82658275
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));

0 commit comments

Comments
 (0)