@@ -142,12 +142,33 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
142
142
ur_buffer_ *Buffer = ur_cast<ur_buffer_ *>(hBuffer);
143
143
144
144
ur_result_t Result = UR_RESULT_SUCCESS;
145
+
146
+ ur_lock MemoryMigrationLock (hBuffer->MemoryMigrationMutex );
147
+
148
+ // Note that this entry point may be called on a specific queue that may not
149
+ // be the last queue to write to the MemBuffer
150
+ auto DeviceToCopyFrom = Buffer->LastEventWritingToMemObj == nullptr
151
+ ? hQueue->getDevice ()
152
+ : Buffer->LastEventWritingToMemObj ->getDevice ();
153
+
145
154
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
146
155
147
156
try {
148
- ScopedDevice Active (hQueue->getDevice ());
149
- hipStream_t HIPStream = hQueue->getNextTransferStream ();
150
- Result = enqueueEventsWait (HIPStream, numEventsInWaitList, phEventWaitList);
157
+ ScopedDevice Active (DeviceToCopyFrom);
158
+ // Use the default stream if copying from another device
159
+ hipStream_t HIPStream = DeviceToCopyFrom == hQueue->getDevice ()
160
+ ? hQueue->getNextTransferStream ()
161
+ : hipStream_t{0 };
162
+
163
+ UR_CHECK_ERROR (
164
+ enqueueEventsWait (HIPStream, numEventsInWaitList, phEventWaitList));
165
+ if (Buffer->LastEventWritingToMemObj != nullptr &&
166
+ hQueue->getDevice () != DeviceToCopyFrom) {
167
+ // We may have to wait for an event on another queue if it is the last
168
+ // event writing to mem obj
169
+ UR_CHECK_ERROR (
170
+ enqueueEventsWait (HIPStream, 1 , &Buffer->LastEventWritingToMemObj ));
171
+ }
151
172
152
173
if (phEvent) {
153
174
RetImplEvent =
@@ -156,7 +177,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
156
177
UR_CHECK_ERROR (RetImplEvent->start ());
157
178
}
158
179
159
- if (auto SrcPtr = Buffer->getWithOffset (offset, hQueue-> getDevice () )) {
180
+ if (auto SrcPtr = Buffer->getWithOffset (offset, DeviceToCopyFrom )) {
160
181
UR_CHECK_ERROR (hipMemcpyDtoHAsync (pDst, SrcPtr, size, HIPStream));
161
182
}
162
183
@@ -349,8 +370,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
349
370
// if it has been written to
350
371
if (phEvent && (MemArg.AccessFlags &
351
372
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY))) {
352
- ur_cast<ur_buffer_ *>(MemArg.Mem )
353
- ->setLastEventWritingToMemObj (RetImplEvent.get ());
373
+ MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.get ());
354
374
}
355
375
}
356
376
// We can release the MemoryMigrationMutexes now
@@ -592,13 +612,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
592
612
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
593
613
594
614
try {
595
- ScopedDevice Active (hQueue->getDevice ());
596
- hipStream_t HIPStream = hQueue->getNextTransferStream ();
597
-
598
- Result = enqueueEventsWait (HIPStream, numEventsInWaitList, phEventWaitList);
599
- if (Buffer->LastEventWritingToMemObj != nullptr ) {
600
- Result =
601
- enqueueEventsWait (HIPStream, 1 , &Buffer->LastEventWritingToMemObj );
615
+ ScopedDevice Active (DeviceToCopyFrom);
616
+ // Use the default stream if copying from another device
617
+ hipStream_t HIPStream = DeviceToCopyFrom == hQueue->getDevice ()
618
+ ? hQueue->getNextTransferStream ()
619
+ : hipStream_t{0 };
620
+
621
+ UR_CHECK_ERROR (
622
+ enqueueEventsWait (HIPStream, numEventsInWaitList, phEventWaitList));
623
+ if (Buffer->LastEventWritingToMemObj != nullptr &&
624
+ hQueue->getDevice () != DeviceToCopyFrom) {
625
+ // We may have to wait for an event on another queue if it is the last
626
+ // event writing to mem obj
627
+ UR_CHECK_ERROR (
628
+ enqueueEventsWait (HIPStream, 1 , &Buffer->LastEventWritingToMemObj ));
602
629
}
603
630
604
631
if (phEvent) {
@@ -608,7 +635,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
608
635
UR_CHECK_ERROR (RetImplEvent->start ());
609
636
}
610
637
611
- if (auto SrcPtr = Buffer->getNativePtr (DeviceToCopyFrom)) {
638
+ if (auto SrcPtr = Buffer->getPtr (DeviceToCopyFrom)) {
612
639
UR_CHECK_ERROR (commonEnqueueMemBufferCopyRect (
613
640
HIPStream, region, &SrcPtr, hipMemoryTypeDevice, bufferOrigin,
614
641
bufferRowPitch, bufferSlicePitch, pDst, hipMemoryTypeHost, hostOrigin,
@@ -645,7 +672,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
645
672
646
673
ur_buffer_ *Buffer = ur_cast<ur_buffer_ *>(hBuffer);
647
674
Buffer->allocateMemObjOnDeviceIfNeeded (hQueue->getDevice ());
648
- hipDeviceptr_t DevPtr = Buffer->getNativePtr (hQueue->getDevice ());
675
+ hipDeviceptr_t DevPtr = Buffer->getPtr (hQueue->getDevice ());
649
676
650
677
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
651
678
@@ -991,16 +1018,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
991
1018
992
1019
ur_result_t Result = UR_RESULT_SUCCESS;
993
1020
1021
+ ur_lock MemoryMigrationLock (hImage->MemoryMigrationMutex );
1022
+
1023
+ // Note that this entry point may be called on a specific queue that may not
1024
+ // be the last queue to write to the image
1025
+ auto DeviceToCopyFrom = Image->LastEventWritingToMemObj == nullptr
1026
+ ? hQueue->getDevice ()
1027
+ : Image->LastEventWritingToMemObj ->getDevice ();
1028
+
994
1029
try {
995
- ScopedDevice Active (hQueue->getDevice ());
996
- hipStream_t HIPStream = hQueue->getNextTransferStream ();
1030
+ ScopedDevice Active (DeviceToCopyFrom);
1031
+ // Use the default stream if copying from another device
1032
+ hipStream_t HIPStream = DeviceToCopyFrom == hQueue->getDevice ()
1033
+ ? hQueue->getNextTransferStream ()
1034
+ : hipStream_t{0 };
997
1035
998
1036
if (phEventWaitList) {
999
- Result =
1000
- enqueueEventsWait (HIPStream, numEventsInWaitList, phEventWaitList);
1037
+ UR_CHECK_ERROR (
1038
+ enqueueEventsWait (HIPStream, numEventsInWaitList, phEventWaitList)) ;
1001
1039
}
1002
1040
1003
- hipArray *Array = Image->getArray (hQueue-> getDevice () );
1041
+ hipArray *Array = Image->getArray (DeviceToCopyFrom );
1004
1042
1005
1043
hipArray_Format Format;
1006
1044
size_t NumChannels;
@@ -1024,9 +1062,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
1024
1062
UR_CHECK_ERROR (RetImplEvent->start ());
1025
1063
}
1026
1064
1027
- Result = commonEnqueueMemImageNDCopy (HIPStream, ImgType, AdjustedRegion,
1028
- Array, hipMemoryTypeArray, SrcOffset,
1029
- pDst, hipMemoryTypeHost, nullptr );
1065
+ if (Array != nullptr ) {
1066
+ UR_CHECK_ERROR (commonEnqueueMemImageNDCopy (
1067
+ HIPStream, ImgType, AdjustedRegion, Array, hipMemoryTypeArray,
1068
+ SrcOffset, pDst, hipMemoryTypeHost, nullptr ));
1069
+ }
1030
1070
1031
1071
if (Result != UR_RESULT_SUCCESS) {
1032
1072
return Result;
@@ -1121,14 +1161,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
1121
1161
UR_ASSERT (ImageSrc->getImageType () == ImageDst->getImageType (),
1122
1162
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1123
1163
1124
- ur_result_t Result = UR_RESULT_SUCCESS ;
1164
+ ImageDst-> allocateMemObjOnDeviceIfNeeded (hQueue-> getDevice ()) ;
1125
1165
1126
1166
try {
1127
1167
ScopedDevice Active (hQueue->getDevice ());
1128
1168
hipStream_t HIPStream = hQueue->getNextTransferStream ();
1129
1169
if (phEventWaitList) {
1130
- Result =
1131
- enqueueEventsWait (HIPStream, numEventsInWaitList, phEventWaitList);
1170
+ UR_CHECK_ERROR (
1171
+ enqueueEventsWait (HIPStream, numEventsInWaitList, phEventWaitList)) ;
1132
1172
}
1133
1173
1134
1174
hipArray *SrcArray = ImageSrc->getArray (hQueue->getDevice ());
@@ -1166,13 +1206,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
1166
1206
UR_CHECK_ERROR (RetImplEvent->start ());
1167
1207
}
1168
1208
1169
- Result = commonEnqueueMemImageNDCopy (
1209
+ UR_CHECK_ERROR ( commonEnqueueMemImageNDCopy (
1170
1210
HIPStream, ImgType, AdjustedRegion, SrcArray, hipMemoryTypeArray,
1171
- SrcOffset, DstArray, hipMemoryTypeArray, DstOffset);
1172
-
1173
- if (Result != UR_RESULT_SUCCESS) {
1174
- return Result;
1175
- }
1211
+ SrcOffset, DstArray, hipMemoryTypeArray, DstOffset));
1176
1212
1177
1213
if (phEvent) {
1178
1214
UR_CHECK_ERROR (RetImplEvent->record ());
@@ -1220,15 +1256,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
1220
1256
if (!IsPinned &&
1221
1257
((mapFlags & UR_MAP_FLAG_READ) || (mapFlags & UR_MAP_FLAG_WRITE))) {
1222
1258
// Pinned host memory is already on host so it doesn't need to be read.
1223
- Result = urEnqueueMemBufferRead (hQueue, hBuffer, blockingMap, offset, size ,
1224
- HostPtr, numEventsInWaitList,
1225
- phEventWaitList, phEvent);
1259
+ UR_CHECK_ERROR ( urEnqueueMemBufferRead (hQueue, hBuffer, blockingMap, offset,
1260
+ size, HostPtr, numEventsInWaitList,
1261
+ phEventWaitList, phEvent) );
1226
1262
} else {
1227
1263
ScopedDevice Active (hQueue->getDevice ());
1228
1264
1229
1265
if (IsPinned) {
1230
- Result = urEnqueueEventsWait (hQueue, numEventsInWaitList, phEventWaitList ,
1231
- nullptr );
1266
+ UR_CHECK_ERROR ( urEnqueueEventsWait (hQueue, numEventsInWaitList,
1267
+ phEventWaitList, nullptr ) );
1232
1268
}
1233
1269
1234
1270
if (phEvent) {
@@ -1254,7 +1290,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
1254
1290
ur_queue_handle_t hQueue, ur_mem_handle_t hMem, void *pMappedPtr,
1255
1291
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1256
1292
ur_event_handle_t *phEvent) {
1257
- ur_result_t Result = UR_RESULT_SUCCESS;
1258
1293
UR_ASSERT (hMem->isBuffer (), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1259
1294
ur_buffer_ *Mem = ur_cast<ur_buffer_ *>(hMem);
1260
1295
UR_ASSERT (Mem->getMapPtr () != nullptr , UR_RESULT_ERROR_INVALID_MEM_OBJECT);
@@ -1267,15 +1302,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
1267
1302
((Mem->getMapFlags () & UR_MAP_FLAG_WRITE) ||
1268
1303
(Mem->getMapFlags () & UR_MAP_FLAG_WRITE_INVALIDATE_REGION))) {
1269
1304
// Pinned host memory is only on host so it doesn't need to be written to.
1270
- Result = urEnqueueMemBufferWrite (
1305
+ UR_CHECK_ERROR ( urEnqueueMemBufferWrite (
1271
1306
hQueue, hMem, true , Mem->getMapOffset (), Mem->getMapSize (), pMappedPtr,
1272
- numEventsInWaitList, phEventWaitList, phEvent);
1307
+ numEventsInWaitList, phEventWaitList, phEvent)) ;
1273
1308
} else {
1274
1309
ScopedDevice Active (hQueue->getDevice ());
1275
1310
1276
1311
if (IsPinned) {
1277
- Result = urEnqueueEventsWait (hQueue, numEventsInWaitList, phEventWaitList ,
1278
- nullptr );
1312
+ UR_CHECK_ERROR ( urEnqueueEventsWait (hQueue, numEventsInWaitList,
1313
+ phEventWaitList, nullptr ) );
1279
1314
}
1280
1315
1281
1316
if (phEvent) {
@@ -1285,13 +1320,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
1285
1320
UR_CHECK_ERROR ((*phEvent)->start ());
1286
1321
UR_CHECK_ERROR ((*phEvent)->record ());
1287
1322
} catch (ur_result_t Error) {
1288
- Result = Error;
1323
+ return Error;
1289
1324
}
1290
1325
}
1291
1326
}
1292
1327
1293
1328
Mem->unmap (pMappedPtr);
1294
- return Result ;
1329
+ return UR_RESULT_SUCCESS ;
1295
1330
}
1296
1331
1297
1332
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill (
0 commit comments