Skip to content

Commit cdc19b1

Browse files
[SYCL][Graph] Improve offset handling in memory copy operations
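Offsets are now passed to setCopyParams and added to the device address after the pointer has been cast to CUdeviceptr, instead of being added to the buffer handle before the call.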
1 parent c60ab8f commit cdc19b1

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

sycl/plugins/unified_runtime/ur/adapters/cuda/command_buffer.cpp

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,18 @@ static ur_result_t getNodesFromSyncPoints(
7474
/// Set parameter for General 1D memory copy.
7575
/// If the source and/or destination is on the device, SrcPtr and/or DstPtr
7676
/// must be a pointer to a CUdeviceptr
77-
static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
78-
void *DstPtr, const CUmemorytype_enum DstType,
77+
static void setCopyParams(const void *SrcPtr, size_t SrcOffset,
78+
const CUmemorytype_enum SrcType, void *DstPtr,
79+
size_t DstOffset, const CUmemorytype_enum DstType,
7980
size_t Size, CUDA_MEMCPY3D &Params) {
8081

8182
Params.srcMemoryType = SrcType;
82-
Params.srcDevice = SrcType == CU_MEMORYTYPE_DEVICE
83-
? *static_cast<const CUdeviceptr *>(SrcPtr)
84-
: 0;
83+
Params.srcDevice =
84+
SrcType == CU_MEMORYTYPE_DEVICE
85+
? (*static_cast<const CUdeviceptr *>(SrcPtr)) + SrcOffset
86+
: 0;
87+
// UR entry point function definitions imply that offsets can be set only on
88+
// device side.
8589
Params.srcHost = SrcType == CU_MEMORYTYPE_HOST ? SrcPtr : nullptr;
8690
Params.srcXInBytes = 0;
8791
Params.srcY = 0;
@@ -91,8 +95,9 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
9195
Params.srcHeight = 0;
9296

9397
Params.dstMemoryType = DstType;
94-
Params.dstDevice =
95-
DstType == CU_MEMORYTYPE_DEVICE ? *static_cast<CUdeviceptr *>(DstPtr) : 0;
98+
Params.dstDevice = DstType == CU_MEMORYTYPE_DEVICE
99+
? (*static_cast<CUdeviceptr *>(DstPtr)) + DstOffset
100+
: 0;
96101
Params.dstHost = DstType == CU_MEMORYTYPE_HOST ? DstPtr : nullptr;
97102
Params.dstXInBytes = 0;
98103
Params.dstY = 0;
@@ -252,8 +257,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemcpyUSMExp(
252257

253258
try {
254259
CUDA_MEMCPY3D NodeParams = {};
255-
setCopyParams(pSrc, CU_MEMORYTYPE_HOST, pDst, CU_MEMORYTYPE_HOST, size,
256-
NodeParams);
260+
setCopyParams(pSrc, 0, CU_MEMORYTYPE_HOST, pDst, 0, CU_MEMORYTYPE_HOST,
261+
size, NodeParams);
257262

258263
Result = UR_CHECK_ERROR(cuGraphAddMemcpyNode(
259264
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
@@ -281,12 +286,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMembufferCopyExp(
281286
pSyncPointWaitList, DepsList));
282287

283288
try {
284-
auto Src = hSrcMem->Mem.BufferMem.get() + srcOffset;
285-
auto Dst = hDstMem->Mem.BufferMem.get() + dstOffset;
289+
auto Src = hSrcMem->Mem.BufferMem.get();
290+
auto Dst = hDstMem->Mem.BufferMem.get();
286291

287292
CUDA_MEMCPY3D NodeParams = {};
288-
setCopyParams(&Src, CU_MEMORYTYPE_DEVICE, &Dst, CU_MEMORYTYPE_DEVICE, size,
289-
NodeParams);
293+
setCopyParams(&Src, srcOffset, CU_MEMORYTYPE_DEVICE, &Dst, dstOffset,
294+
CU_MEMORYTYPE_DEVICE, size, NodeParams);
290295

291296
Result = UR_CHECK_ERROR(cuGraphAddMemcpyNode(
292297
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
@@ -351,11 +356,11 @@ ur_result_t UR_APICALL urCommandBufferAppendMembufferWriteExp(
351356
pSyncPointWaitList, DepsList));
352357

353358
try {
354-
auto Dst = hBuffer->Mem.BufferMem.get() + offset;
359+
auto Dst = hBuffer->Mem.BufferMem.get();
355360

356361
CUDA_MEMCPY3D NodeParams = {};
357-
setCopyParams(pSrc, CU_MEMORYTYPE_HOST, &Dst, CU_MEMORYTYPE_DEVICE, size,
358-
NodeParams);
362+
setCopyParams(pSrc, 0, CU_MEMORYTYPE_HOST, &Dst, offset,
363+
CU_MEMORYTYPE_DEVICE, size, NodeParams);
359364

360365
Result = UR_CHECK_ERROR(cuGraphAddMemcpyNode(
361366
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
@@ -383,11 +388,11 @@ ur_result_t UR_APICALL urCommandBufferAppendMembufferReadExp(
383388
pSyncPointWaitList, DepsList));
384389

385390
try {
386-
auto Src = hBuffer->Mem.BufferMem.get() + offset;
391+
auto Src = hBuffer->Mem.BufferMem.get();
387392

388393
CUDA_MEMCPY3D NodeParams = {};
389-
setCopyParams(&Src, CU_MEMORYTYPE_DEVICE, pDst, CU_MEMORYTYPE_HOST, size,
390-
NodeParams);
394+
setCopyParams(&Src, offset, CU_MEMORYTYPE_DEVICE, pDst, 0,
395+
CU_MEMORYTYPE_HOST, size, NodeParams);
391396

392397
Result = UR_CHECK_ERROR(cuGraphAddMemcpyNode(
393398
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),

0 commit comments

Comments
 (0)