Skip to content

Commit 1c3c9df

Browse files
committed
use cudaMemcpy3DPeerAsync
1 parent 1659cd1 commit 1c3c9df

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

ggml-cuda.cu

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,13 @@
6868
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
6969
#endif
7070
#define cudaMemcpy hipMemcpy
71-
#define cudaMemcpy2DAsync hipMemcpy2DAsync
7271
#define cudaMemcpyAsync hipMemcpyAsync
7372
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
73+
#define cudaMemcpy2DAsync hipMemcpy2DAsync
74+
#define cudaMemcpy3DPeerAsync hipMemcpy3DPeerAsync
75+
#define cudaMemcpy3DPeerParms hipMemcpy3DPeerParms
76+
#define make_cudaPitchedPtr make_hipPitchedPtr
77+
#define make_cudaExtent make_hipExtent
7478
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
7579
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
7680
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
@@ -8258,14 +8262,15 @@ static void ggml_cuda_op_mul_mat(
82588262
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
82598263
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
82608264
dhf_dst_i += src1_col_0*ne0 + row_low[id];
8261-
82628265
if (kind == cudaMemcpyDeviceToDevice && id != g_main_device) {
8263-
// there is no cudaMemcpy2DPeerAsync so we need to copy each row separately
8264-
for (int64_t i = 0; i < src1_ncols; ++i) {
8265-
CUDA_CHECK(cudaMemcpyPeerAsync(dhf_dst_i + i*ne0, g_main_device,
8266-
dst_dd_i + i*row_diff, id,
8267-
row_diff*sizeof(float), stream));
8268-
}
8266+
// there is no cudaMemcpy2DPeerAsync, use a 3D copy instead
8267+
cudaMemcpy3DPeerParms p = {};
8268+
p.dstDevice = g_main_device;
8269+
p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), ne0, src1_ncols);
8270+
p.srcDevice = id;
8271+
p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
8272+
p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
8273+
CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
82698274
} else {
82708275
CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float),
82718276
row_diff*sizeof(float), src1_ncols, kind, stream));

0 commit comments

Comments
 (0)