@@ -74,14 +74,18 @@ static ur_result_t getNodesFromSyncPoints(
74
74
/// Set parameters for a general 1D memory copy.
75
75
/// If the source and/or destination is on the device, SrcPtr and/or DstPtr
76
76
/// must be a pointer to a CUdeviceptr.
77
- static void setCopyParams (const void *SrcPtr, const CUmemorytype_enum SrcType,
78
- void *DstPtr, const CUmemorytype_enum DstType,
77
+ static void setCopyParams (const void *SrcPtr, size_t SrcOffset,
78
+ const CUmemorytype_enum SrcType, void *DstPtr,
79
+ size_t DstOffset, const CUmemorytype_enum DstType,
79
80
size_t Size, CUDA_MEMCPY3D &Params) {
80
81
81
82
Params.srcMemoryType = SrcType;
82
- Params.srcDevice = SrcType == CU_MEMORYTYPE_DEVICE
83
- ? *static_cast <const CUdeviceptr *>(SrcPtr)
84
- : 0 ;
83
+ Params.srcDevice =
84
+ SrcType == CU_MEMORYTYPE_DEVICE
85
+ ? (*static_cast <const CUdeviceptr *>(SrcPtr)) + SrcOffset
86
+ : 0 ;
87
+ // UR entry point function definitions imply that offsets can be set only on
88
+ // device side.
85
89
Params.srcHost = SrcType == CU_MEMORYTYPE_HOST ? SrcPtr : nullptr ;
86
90
Params.srcXInBytes = 0 ;
87
91
Params.srcY = 0 ;
@@ -91,8 +95,9 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
91
95
Params.srcHeight = 0 ;
92
96
93
97
Params.dstMemoryType = DstType;
94
- Params.dstDevice =
95
- DstType == CU_MEMORYTYPE_DEVICE ? *static_cast <CUdeviceptr *>(DstPtr) : 0 ;
98
+ Params.dstDevice = DstType == CU_MEMORYTYPE_DEVICE
99
+ ? (*static_cast <CUdeviceptr *>(DstPtr)) + DstOffset
100
+ : 0 ;
96
101
Params.dstHost = DstType == CU_MEMORYTYPE_HOST ? DstPtr : nullptr ;
97
102
Params.dstXInBytes = 0 ;
98
103
Params.dstY = 0 ;
@@ -252,8 +257,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemcpyUSMExp(
252
257
253
258
try {
254
259
CUDA_MEMCPY3D NodeParams = {};
255
- setCopyParams (pSrc, CU_MEMORYTYPE_HOST, pDst, CU_MEMORYTYPE_HOST, size ,
256
- NodeParams);
260
+ setCopyParams (pSrc, 0 , CU_MEMORYTYPE_HOST, pDst, 0 , CU_MEMORYTYPE_HOST ,
261
+ size, NodeParams);
257
262
258
263
Result = UR_CHECK_ERROR (cuGraphAddMemcpyNode (
259
264
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
@@ -281,12 +286,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMembufferCopyExp(
281
286
pSyncPointWaitList, DepsList));
282
287
283
288
try {
284
- auto Src = hSrcMem->Mem .BufferMem .get () + srcOffset ;
285
- auto Dst = hDstMem->Mem .BufferMem .get () + dstOffset ;
289
+ auto Src = hSrcMem->Mem .BufferMem .get ();
290
+ auto Dst = hDstMem->Mem .BufferMem .get ();
286
291
287
292
CUDA_MEMCPY3D NodeParams = {};
288
- setCopyParams (&Src, CU_MEMORYTYPE_DEVICE, &Dst, CU_MEMORYTYPE_DEVICE, size ,
289
- NodeParams);
293
+ setCopyParams (&Src, srcOffset, CU_MEMORYTYPE_DEVICE, &Dst, dstOffset ,
294
+ CU_MEMORYTYPE_DEVICE, size, NodeParams);
290
295
291
296
Result = UR_CHECK_ERROR (cuGraphAddMemcpyNode (
292
297
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
@@ -351,11 +356,11 @@ ur_result_t UR_APICALL urCommandBufferAppendMembufferWriteExp(
351
356
pSyncPointWaitList, DepsList));
352
357
353
358
try {
354
- auto Dst = hBuffer->Mem .BufferMem .get () + offset ;
359
+ auto Dst = hBuffer->Mem .BufferMem .get ();
355
360
356
361
CUDA_MEMCPY3D NodeParams = {};
357
- setCopyParams (pSrc, CU_MEMORYTYPE_HOST, &Dst, CU_MEMORYTYPE_DEVICE, size ,
358
- NodeParams);
362
+ setCopyParams (pSrc, 0 , CU_MEMORYTYPE_HOST, &Dst, offset ,
363
+ CU_MEMORYTYPE_DEVICE, size, NodeParams);
359
364
360
365
Result = UR_CHECK_ERROR (cuGraphAddMemcpyNode (
361
366
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
@@ -383,11 +388,11 @@ ur_result_t UR_APICALL urCommandBufferAppendMembufferReadExp(
383
388
pSyncPointWaitList, DepsList));
384
389
385
390
try {
386
- auto Src = hBuffer->Mem .BufferMem .get () + offset ;
391
+ auto Src = hBuffer->Mem .BufferMem .get ();
387
392
388
393
CUDA_MEMCPY3D NodeParams = {};
389
- setCopyParams (&Src, CU_MEMORYTYPE_DEVICE, pDst, CU_MEMORYTYPE_HOST, size ,
390
- NodeParams);
394
+ setCopyParams (&Src, offset, CU_MEMORYTYPE_DEVICE, pDst, 0 ,
395
+ CU_MEMORYTYPE_HOST, size, NodeParams);
391
396
392
397
Result = UR_CHECK_ERROR (cuGraphAddMemcpyNode (
393
398
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
0 commit comments