Skip to content

Commit d6ea231

Browse files
author
Hugh Delaney
committed
Add some description and add further checks for Images
1 parent b9913b1 commit d6ea231

File tree

3 files changed

+178
-66
lines changed

3 files changed

+178
-66
lines changed

sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp

Lines changed: 79 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -142,12 +142,33 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
142142
ur_buffer_ *Buffer = ur_cast<ur_buffer_ *>(hBuffer);
143143

144144
ur_result_t Result = UR_RESULT_SUCCESS;
145+
146+
ur_lock MemoryMigrationLock(hBuffer->MemoryMigrationMutex);
147+
148+
// Note that this entry point may be called on a specific queue that may not
149+
// be the last queue to write to the MemBuffer
150+
auto DeviceToCopyFrom = Buffer->LastEventWritingToMemObj == nullptr
151+
? hQueue->getDevice()
152+
: Buffer->LastEventWritingToMemObj->getDevice();
153+
145154
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
146155

147156
try {
148-
ScopedDevice Active(hQueue->getDevice());
149-
hipStream_t HIPStream = hQueue->getNextTransferStream();
150-
Result = enqueueEventsWait(HIPStream, numEventsInWaitList, phEventWaitList);
157+
ScopedDevice Active(DeviceToCopyFrom);
158+
// Use the default stream if copying from another device
159+
hipStream_t HIPStream = DeviceToCopyFrom == hQueue->getDevice()
160+
? hQueue->getNextTransferStream()
161+
: hipStream_t{0};
162+
163+
UR_CHECK_ERROR(
164+
enqueueEventsWait(HIPStream, numEventsInWaitList, phEventWaitList));
165+
if (Buffer->LastEventWritingToMemObj != nullptr &&
166+
hQueue->getDevice() != DeviceToCopyFrom) {
167+
// We may have to wait for an event on another queue if it is the last
168+
// event writing to mem obj
169+
UR_CHECK_ERROR(
170+
enqueueEventsWait(HIPStream, 1, &Buffer->LastEventWritingToMemObj));
171+
}
151172

152173
if (phEvent) {
153174
RetImplEvent =
@@ -156,7 +177,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
156177
UR_CHECK_ERROR(RetImplEvent->start());
157178
}
158179

159-
if (auto SrcPtr = Buffer->getWithOffset(offset, hQueue->getDevice())) {
180+
if (auto SrcPtr = Buffer->getWithOffset(offset, DeviceToCopyFrom)) {
160181
UR_CHECK_ERROR(hipMemcpyDtoHAsync(pDst, SrcPtr, size, HIPStream));
161182
}
162183

@@ -349,8 +370,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
349370
// if it has been written to
350371
if (phEvent && (MemArg.AccessFlags &
351372
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY))) {
352-
ur_cast<ur_buffer_ *>(MemArg.Mem)
353-
->setLastEventWritingToMemObj(RetImplEvent.get());
373+
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
354374
}
355375
}
356376
// We can release the MemoryMigrationMutexes now
@@ -592,13 +612,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
592612
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
593613

594614
try {
595-
ScopedDevice Active(hQueue->getDevice());
596-
hipStream_t HIPStream = hQueue->getNextTransferStream();
597-
598-
Result = enqueueEventsWait(HIPStream, numEventsInWaitList, phEventWaitList);
599-
if (Buffer->LastEventWritingToMemObj != nullptr) {
600-
Result =
601-
enqueueEventsWait(HIPStream, 1, &Buffer->LastEventWritingToMemObj);
615+
ScopedDevice Active(DeviceToCopyFrom);
616+
// Use the default stream if copying from another device
617+
hipStream_t HIPStream = DeviceToCopyFrom == hQueue->getDevice()
618+
? hQueue->getNextTransferStream()
619+
: hipStream_t{0};
620+
621+
UR_CHECK_ERROR(
622+
enqueueEventsWait(HIPStream, numEventsInWaitList, phEventWaitList));
623+
if (Buffer->LastEventWritingToMemObj != nullptr &&
624+
hQueue->getDevice() != DeviceToCopyFrom) {
625+
// We may have to wait for an event on another queue if it is the last
626+
// event writing to mem obj
627+
UR_CHECK_ERROR(
628+
enqueueEventsWait(HIPStream, 1, &Buffer->LastEventWritingToMemObj));
602629
}
603630

604631
if (phEvent) {
@@ -608,7 +635,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
608635
UR_CHECK_ERROR(RetImplEvent->start());
609636
}
610637

611-
if (auto SrcPtr = Buffer->getNativePtr(DeviceToCopyFrom)) {
638+
if (auto SrcPtr = Buffer->getPtr(DeviceToCopyFrom)) {
612639
UR_CHECK_ERROR(commonEnqueueMemBufferCopyRect(
613640
HIPStream, region, &SrcPtr, hipMemoryTypeDevice, bufferOrigin,
614641
bufferRowPitch, bufferSlicePitch, pDst, hipMemoryTypeHost, hostOrigin,
@@ -645,7 +672,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
645672

646673
ur_buffer_ *Buffer = ur_cast<ur_buffer_ *>(hBuffer);
647674
Buffer->allocateMemObjOnDeviceIfNeeded(hQueue->getDevice());
648-
hipDeviceptr_t DevPtr = Buffer->getNativePtr(hQueue->getDevice());
675+
hipDeviceptr_t DevPtr = Buffer->getPtr(hQueue->getDevice());
649676

650677
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
651678

@@ -991,16 +1018,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
9911018

9921019
ur_result_t Result = UR_RESULT_SUCCESS;
9931020

1021+
ur_lock MemoryMigrationLock(hImage->MemoryMigrationMutex);
1022+
1023+
// Note that this entry point may be called on a specific queue that may not
1024+
// be the last queue to write to the MemBuffer
1025+
auto DeviceToCopyFrom = Image->LastEventWritingToMemObj == nullptr
1026+
? hQueue->getDevice()
1027+
: Image->LastEventWritingToMemObj->getDevice();
1028+
9941029
try {
995-
ScopedDevice Active(hQueue->getDevice());
996-
hipStream_t HIPStream = hQueue->getNextTransferStream();
1030+
ScopedDevice Active(DeviceToCopyFrom);
1031+
// Use the default stream if copying from another device
1032+
hipStream_t HIPStream = DeviceToCopyFrom == hQueue->getDevice()
1033+
? hQueue->getNextTransferStream()
1034+
: hipStream_t{0};
9971035

9981036
if (phEventWaitList) {
999-
Result =
1000-
enqueueEventsWait(HIPStream, numEventsInWaitList, phEventWaitList);
1037+
UR_CHECK_ERROR(
1038+
enqueueEventsWait(HIPStream, numEventsInWaitList, phEventWaitList));
10011039
}
10021040

1003-
hipArray *Array = Image->getArray(hQueue->getDevice());
1041+
hipArray *Array = Image->getArray(DeviceToCopyFrom);
10041042

10051043
hipArray_Format Format;
10061044
size_t NumChannels;
@@ -1024,9 +1062,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
10241062
UR_CHECK_ERROR(RetImplEvent->start());
10251063
}
10261064

1027-
Result = commonEnqueueMemImageNDCopy(HIPStream, ImgType, AdjustedRegion,
1028-
Array, hipMemoryTypeArray, SrcOffset,
1029-
pDst, hipMemoryTypeHost, nullptr);
1065+
if (Array != nullptr) {
1066+
UR_CHECK_ERROR(commonEnqueueMemImageNDCopy(
1067+
HIPStream, ImgType, AdjustedRegion, Array, hipMemoryTypeArray,
1068+
SrcOffset, pDst, hipMemoryTypeHost, nullptr));
1069+
}
10301070

10311071
if (Result != UR_RESULT_SUCCESS) {
10321072
return Result;
@@ -1121,14 +1161,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
11211161
UR_ASSERT(ImageSrc->getImageType() == ImageDst->getImageType(),
11221162
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
11231163

1124-
ur_result_t Result = UR_RESULT_SUCCESS;
1164+
ImageDst->allocateMemObjOnDeviceIfNeeded(hQueue->getDevice());
11251165

11261166
try {
11271167
ScopedDevice Active(hQueue->getDevice());
11281168
hipStream_t HIPStream = hQueue->getNextTransferStream();
11291169
if (phEventWaitList) {
1130-
Result =
1131-
enqueueEventsWait(HIPStream, numEventsInWaitList, phEventWaitList);
1170+
UR_CHECK_ERROR(
1171+
enqueueEventsWait(HIPStream, numEventsInWaitList, phEventWaitList));
11321172
}
11331173

11341174
hipArray *SrcArray = ImageSrc->getArray(hQueue->getDevice());
@@ -1166,13 +1206,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
11661206
UR_CHECK_ERROR(RetImplEvent->start());
11671207
}
11681208

1169-
Result = commonEnqueueMemImageNDCopy(
1209+
UR_CHECK_ERROR(commonEnqueueMemImageNDCopy(
11701210
HIPStream, ImgType, AdjustedRegion, SrcArray, hipMemoryTypeArray,
1171-
SrcOffset, DstArray, hipMemoryTypeArray, DstOffset);
1172-
1173-
if (Result != UR_RESULT_SUCCESS) {
1174-
return Result;
1175-
}
1211+
SrcOffset, DstArray, hipMemoryTypeArray, DstOffset));
11761212

11771213
if (phEvent) {
11781214
UR_CHECK_ERROR(RetImplEvent->record());
@@ -1220,15 +1256,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
12201256
if (!IsPinned &&
12211257
((mapFlags & UR_MAP_FLAG_READ) || (mapFlags & UR_MAP_FLAG_WRITE))) {
12221258
// Pinned host memory is already on host so it doesn't need to be read.
1223-
Result = urEnqueueMemBufferRead(hQueue, hBuffer, blockingMap, offset, size,
1224-
HostPtr, numEventsInWaitList,
1225-
phEventWaitList, phEvent);
1259+
UR_CHECK_ERROR(urEnqueueMemBufferRead(hQueue, hBuffer, blockingMap, offset,
1260+
size, HostPtr, numEventsInWaitList,
1261+
phEventWaitList, phEvent));
12261262
} else {
12271263
ScopedDevice Active(hQueue->getDevice());
12281264

12291265
if (IsPinned) {
1230-
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
1231-
nullptr);
1266+
UR_CHECK_ERROR(urEnqueueEventsWait(hQueue, numEventsInWaitList,
1267+
phEventWaitList, nullptr));
12321268
}
12331269

12341270
if (phEvent) {
@@ -1254,7 +1290,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
12541290
ur_queue_handle_t hQueue, ur_mem_handle_t hMem, void *pMappedPtr,
12551291
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
12561292
ur_event_handle_t *phEvent) {
1257-
ur_result_t Result = UR_RESULT_SUCCESS;
12581293
UR_ASSERT(hMem->isBuffer(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
12591294
ur_buffer_ *Mem = ur_cast<ur_buffer_ *>(hMem);
12601295
UR_ASSERT(Mem->getMapPtr() != nullptr, UR_RESULT_ERROR_INVALID_MEM_OBJECT);
@@ -1267,15 +1302,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
12671302
((Mem->getMapFlags() & UR_MAP_FLAG_WRITE) ||
12681303
(Mem->getMapFlags() & UR_MAP_FLAG_WRITE_INVALIDATE_REGION))) {
12691304
// Pinned host memory is only on host so it doesn't need to be written to.
1270-
Result = urEnqueueMemBufferWrite(
1305+
UR_CHECK_ERROR(urEnqueueMemBufferWrite(
12711306
hQueue, hMem, true, Mem->getMapOffset(), Mem->getMapSize(), pMappedPtr,
1272-
numEventsInWaitList, phEventWaitList, phEvent);
1307+
numEventsInWaitList, phEventWaitList, phEvent));
12731308
} else {
12741309
ScopedDevice Active(hQueue->getDevice());
12751310

12761311
if (IsPinned) {
1277-
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
1278-
nullptr);
1312+
UR_CHECK_ERROR(urEnqueueEventsWait(hQueue, numEventsInWaitList,
1313+
phEventWaitList, nullptr));
12791314
}
12801315

12811316
if (phEvent) {
@@ -1285,13 +1320,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
12851320
UR_CHECK_ERROR((*phEvent)->start());
12861321
UR_CHECK_ERROR((*phEvent)->record());
12871322
} catch (ur_result_t Error) {
1288-
Result = Error;
1323+
return Error;
12891324
}
12901325
}
12911326
}
12921327

12931328
Mem->unmap(pMappedPtr);
1294-
return Result;
1329+
return UR_RESULT_SUCCESS;
12951330
}
12961331

12971332
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(

sycl/plugins/unified_runtime/ur/adapters/hip/memory.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -376,7 +376,7 @@ ur_buffer_::allocateMemObjOnDeviceIfNeeded(ur_device_handle_t hDevice) {
376376
ScopedDevice Active(hDevice);
377377
ur_lock LockGuard(MemoryAllocationMutex);
378378

379-
hipDeviceptr_t &DevPtr = getNativePtr(hDevice);
379+
hipDeviceptr_t &DevPtr = getPtr(hDevice);
380380

381381
// Allocation has already been made
382382
if (DevPtr != ur_buffer_::native_type{0}) {

0 commit comments

Comments
 (0)