Skip to content

Commit d1c92cb

Browse files
authored
[HIP][UR] Use primary context in HIP adapter (#10514)
The primary context has been the default in the CUDA PI/adapter for a while (see #8197). This PR brings the HIP adapter up to speed. It also changes the scoped context to take only a `ur_device_handle_t`, since each device is now coupled with a native primary context in HIP.
1 parent 49cf82e commit d1c92cb

File tree

12 files changed

+75
-112
lines changed

12 files changed

+75
-112
lines changed

sycl/plugins/unified_runtime/ur/adapters/hip/context.cpp

Lines changed: 7 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
2222

2323
std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
2424
try {
25-
hipCtx_t Current = nullptr;
26-
2725
// Create a scoped context.
28-
hipCtx_t NewContext;
29-
UR_CHECK_ERROR(hipCtxGetCurrent(&Current));
30-
RetErr = UR_CHECK_ERROR(
31-
hipCtxCreate(&NewContext, hipDeviceMapHost, phDevices[0]->get()));
32-
ContextPtr = std::unique_ptr<ur_context_handle_t_>(new ur_context_handle_t_{
33-
ur_context_handle_t_::kind::UserDefined, NewContext, *phDevices});
26+
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
27+
new ur_context_handle_t_{*phDevices});
3428

3529
static std::once_flag InitFlag;
3630
std::call_once(
@@ -43,14 +37,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
4337
},
4438
RetErr);
4539

46-
// For non-primary scoped contexts keep the last active on top of the stack
47-
// as `hipCtxCreate` replaces it implicitly otherwise.
48-
// Primary contexts are kept on top of the stack, so the previous context
49-
// is not queried and therefore not recovered.
50-
if (Current != nullptr) {
51-
UR_CHECK_ERROR(hipCtxSetCurrent(Current));
52-
}
53-
5440
*phContext = ContextPtr.release();
5541
} catch (ur_result_t Err) {
5642
RetErr = Err;
@@ -97,40 +83,10 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
9783

9884
UR_APIEXPORT ur_result_t UR_APICALL
9985
urContextRelease(ur_context_handle_t hContext) {
100-
if (hContext->decrementReferenceCount() > 0) {
101-
return UR_RESULT_SUCCESS;
102-
}
103-
hContext->invokeExtendedDeleters();
104-
105-
std::unique_ptr<ur_context_handle_t_> context{hContext};
106-
107-
if (!hContext->isPrimary()) {
108-
hipCtx_t HIPCtxt = hContext->get();
109-
// hipCtxSynchronize is not supported for AMD platform so we can just
110-
// destroy the context, for NVIDIA make sure it's synchronized.
111-
#if defined(__HIP_PLATFORM_NVIDIA__)
112-
hipCtx_t Current = nullptr;
113-
UR_CHECK_ERROR(hipCtxGetCurrent(&Current));
114-
if (HIPCtxt != Current) {
115-
UR_CHECK_ERROR(hipCtxPushCurrent(HIPCtxt));
116-
}
117-
UR_CHECK_ERROR(hipCtxSynchronize());
118-
UR_CHECK_ERROR(hipCtxGetCurrent(&Current));
119-
if (HIPCtxt == Current) {
120-
UR_CHECK_ERROR(hipCtxPopCurrent(&Current));
121-
}
122-
#endif
123-
return UR_CHECK_ERROR(hipCtxDestroy(HIPCtxt));
124-
} else {
125-
// Primary context is not destroyed, but released
126-
hipDevice_t HIPDev = hContext->getDevice()->get();
127-
hipCtx_t Current;
128-
UR_CHECK_ERROR(hipCtxPopCurrent(&Current));
129-
return UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDev));
86+
if (hContext->decrementReferenceCount() == 0) {
87+
delete hContext;
13088
}
131-
132-
hipCtx_t HIPCtxt = hContext->get();
133-
return UR_CHECK_ERROR(hipCtxDestroy(HIPCtxt));
89+
return UR_RESULT_SUCCESS;
13490
}
13591

13692
UR_APIEXPORT ur_result_t UR_APICALL
@@ -143,7 +99,8 @@ urContextRetain(ur_context_handle_t hContext) {
14399

144100
UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
145101
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
146-
*phNativeContext = reinterpret_cast<ur_native_handle_t>(hContext->get());
102+
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
103+
hContext->getDevice()->getNativeContext());
147104
return UR_RESULT_SUCCESS;
148105
}
149106

sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,11 @@ struct ur_context_handle_t_ {
6262

6363
using native_type = hipCtx_t;
6464

65-
enum class kind { Primary, UserDefined } Kind;
66-
native_type HIPContext;
6765
ur_device_handle_t DeviceId;
6866
std::atomic_uint32_t RefCount;
6967

70-
ur_context_handle_t_(kind K, hipCtx_t Ctxt, ur_device_handle_t DevId)
71-
: Kind{K}, HIPContext{Ctxt}, DeviceId{DevId}, RefCount{1} {
72-
DeviceId->setContext(this);
68+
ur_context_handle_t_(ur_device_handle_t DevId)
69+
: DeviceId{DevId}, RefCount{1} {
7370
urDeviceRetain(DeviceId);
7471
};
7572

@@ -90,10 +87,6 @@ struct ur_context_handle_t_ {
9087

9188
ur_device_handle_t getDevice() const noexcept { return DeviceId; }
9289

93-
native_type get() const noexcept { return HIPContext; }
94-
95-
bool isPrimary() const noexcept { return Kind == kind::Primary; }
96-
9790
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
9891

9992
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
@@ -113,19 +106,18 @@ namespace {
113106
/// API is the one active on the thread.
114107
/// The implementation tries to avoid replacing the hipCtx_t if it can.
115108
class ScopedContext {
116-
ur_context_handle_t PlacedContext;
117109
hipCtx_t Original;
118110
bool NeedToRecover;
119111

120112
public:
121-
ScopedContext(ur_context_handle_t Ctxt)
122-
: PlacedContext{Ctxt}, NeedToRecover{false} {
113+
ScopedContext(ur_device_handle_t hDevice) : NeedToRecover{false} {
123114

124-
if (!PlacedContext) {
125-
throw UR_RESULT_ERROR_INVALID_CONTEXT;
115+
if (!hDevice) {
116+
throw UR_RESULT_ERROR_INVALID_DEVICE;
126117
}
127118

128-
hipCtx_t Desired = PlacedContext->get();
119+
// FIXME: revisit when multi-device contexts are supported in the HIP adapter
120+
hipCtx_t Desired = hDevice->getNativeContext();
129121
UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
130122
if (Original != Desired) {
131123
// Sets the desired context as the active one for the thread

sycl/plugins/unified_runtime/ur/adapters/hip/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
968968
return UR_RESULT_SUCCESS;
969969

970970
ur_event_handle_t_::native_type Event;
971-
ScopedContext Active(hDevice->getContext());
971+
ScopedContext Active(hDevice);
972972

973973
if (pDeviceTimestamp) {
974974
UR_CHECK_ERROR(hipEventCreateWithFlags(&Event, hipEventDefault));

sycl/plugins/unified_runtime/ur/adapters/hip/device.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,25 @@ struct ur_device_handle_t_ {
2222
native_type HIPDevice;
2323
std::atomic_uint32_t RefCount;
2424
ur_platform_handle_t Platform;
25-
ur_context_handle_t Context;
25+
hipCtx_t HIPContext;
2626

2727
public:
28-
ur_device_handle_t_(native_type HipDevice, ur_platform_handle_t Platform)
29-
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform) {}
28+
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
29+
ur_platform_handle_t Platform)
30+
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
31+
HIPContext(Context) {}
32+
33+
~ur_device_handle_t_() {
34+
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
35+
}
3036

3137
native_type get() const noexcept { return HIPDevice; };
3238

3339
uint32_t getReferenceCount() const noexcept { return RefCount; }
3440

3541
ur_platform_handle_t getPlatform() const noexcept { return Platform; };
3642

37-
void setContext(ur_context_handle_t Ctxt) { Context = Ctxt; };
38-
39-
ur_context_handle_t getContext() { return Context; };
43+
hipCtx_t getNativeContext() { return HIPContext; };
4044
};
4145

4246
int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);

sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t CommandQueue,
4141
return UR_RESULT_SUCCESS;
4242
}
4343
try {
44-
ScopedContext Active(CommandQueue->getContext());
44+
ScopedContext Active(CommandQueue->getDevice());
4545

4646
auto Result = forLatestEvents(
4747
EventWaitList, NumEventsInWaitList,
@@ -97,7 +97,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
9797
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
9898

9999
try {
100-
ScopedContext Active(hQueue->getContext());
100+
ScopedContext Active(hQueue->getDevice());
101101
hipStream_t HIPStream = hQueue->getNextTransferStream();
102102
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
103103
phEventWaitList);
@@ -143,7 +143,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
143143
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
144144

145145
try {
146-
ScopedContext Active(hQueue->getContext());
146+
ScopedContext Active(hQueue->getDevice());
147147
hipStream_t HIPStream = hQueue->getNextTransferStream();
148148
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
149149
phEventWaitList);
@@ -252,7 +252,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
252252
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
253253

254254
try {
255-
ScopedContext Active(hQueue->getContext());
255+
ScopedContext Active(hQueue->getDevice());
256256

257257
uint32_t StreamToken;
258258
ur_stream_quard Guard;
@@ -363,7 +363,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
363363
ur_result_t Result;
364364

365365
try {
366-
ScopedContext Active(hQueue->getContext());
366+
ScopedContext Active(hQueue->getDevice());
367367
uint32_t StreamToken;
368368
ur_stream_quard Guard;
369369
hipStream_t HIPStream = hQueue->getNextComputeStream(
@@ -513,7 +513,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
513513
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
514514

515515
try {
516-
ScopedContext Active(hQueue->getContext());
516+
ScopedContext Active(hQueue->getDevice());
517517
hipStream_t HIPStream = hQueue->getNextTransferStream();
518518

519519
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
@@ -561,7 +561,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
561561
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
562562

563563
try {
564-
ScopedContext Active(hQueue->getContext());
564+
ScopedContext Active(hQueue->getDevice());
565565
hipStream_t HIPStream = hQueue->getNextTransferStream();
566566
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
567567
phEventWaitList);
@@ -609,7 +609,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
609609
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
610610

611611
try {
612-
ScopedContext Active(hQueue->getContext());
612+
ScopedContext Active(hQueue->getDevice());
613613
ur_result_t Result;
614614
auto Stream = hQueue->getNextTransferStream();
615615

@@ -656,7 +656,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
656656
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
657657

658658
try {
659-
ScopedContext Active(hQueue->getContext());
659+
ScopedContext Active(hQueue->getDevice());
660660
hipStream_t HIPStream = hQueue->getNextTransferStream();
661661
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
662662
phEventWaitList);
@@ -751,7 +751,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
751751
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
752752

753753
try {
754-
ScopedContext Active(hQueue->getContext());
754+
ScopedContext Active(hQueue->getDevice());
755755

756756
auto Stream = hQueue->getNextTransferStream();
757757
ur_result_t Result;
@@ -892,7 +892,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
892892
ur_result_t Result = UR_RESULT_SUCCESS;
893893

894894
try {
895-
ScopedContext Active(hQueue->getContext());
895+
ScopedContext Active(hQueue->getDevice());
896896
hipStream_t HIPStream = hQueue->getNextTransferStream();
897897

898898
if (phEventWaitList) {
@@ -954,7 +954,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
954954
ur_result_t Result = UR_RESULT_SUCCESS;
955955

956956
try {
957-
ScopedContext Active(hQueue->getContext());
957+
ScopedContext Active(hQueue->getDevice());
958958
hipStream_t HIPStream = hQueue->getNextTransferStream();
959959

960960
if (phEventWaitList) {
@@ -1020,7 +1020,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
10201020
ur_result_t Result = UR_RESULT_SUCCESS;
10211021

10221022
try {
1023-
ScopedContext Active(hQueue->getContext());
1023+
ScopedContext Active(hQueue->getDevice());
10241024
hipStream_t HIPStream = hQueue->getNextTransferStream();
10251025
if (phEventWaitList) {
10261026
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
@@ -1116,7 +1116,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
11161116
HostPtr, numEventsInWaitList,
11171117
phEventWaitList, phEvent);
11181118
} else {
1119-
ScopedContext Active(hQueue->getContext());
1119+
ScopedContext Active(hQueue->getDevice());
11201120

11211121
if (IsPinned) {
11221122
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
@@ -1167,7 +1167,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
11671167
hMem->Mem.BufferMem.getMapSize(), pMappedPtr, numEventsInWaitList,
11681168
phEventWaitList, phEvent);
11691169
} else {
1170-
ScopedContext Active(hQueue->getContext());
1170+
ScopedContext Active(hQueue->getDevice());
11711171

11721172
if (IsPinned) {
11731173
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
@@ -1198,7 +1198,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
11981198
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
11991199

12001200
try {
1201-
ScopedContext Active(hQueue->getContext());
1201+
ScopedContext Active(hQueue->getDevice());
12021202
uint32_t StreamToken;
12031203
ur_stream_quard Guard;
12041204
hipStream_t HIPStream = hQueue->getNextComputeStream(
@@ -1256,7 +1256,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
12561256
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
12571257

12581258
try {
1259-
ScopedContext Active(hQueue->getContext());
1259+
ScopedContext Active(hQueue->getDevice());
12601260
hipStream_t HIPStream = hQueue->getNextTransferStream();
12611261
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
12621262
phEventWaitList);
@@ -1287,6 +1287,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
12871287
ur_queue_handle_t hQueue, const void *pMem, size_t size,
12881288
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
12891289
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1290+
#if HIP_VERSION_MAJOR >= 5
12901291
void *HIPDevicePtr = const_cast<void *>(pMem);
12911292
unsigned int PointerRangeSize = 0;
12921293
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
@@ -1301,7 +1302,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13011302
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
13021303

13031304
try {
1304-
ScopedContext Active(hQueue->getContext());
1305+
ScopedContext Active(hQueue->getDevice());
13051306
hipStream_t HIPStream = hQueue->getNextTransferStream();
13061307
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
13071308
phEventWaitList);
@@ -1311,8 +1312,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13111312
UR_COMMAND_USM_PREFETCH, hQueue, HIPStream));
13121313
EventPtr->start();
13131314
}
1314-
Result = UR_CHECK_ERROR(hipMemPrefetchAsync(
1315-
pMem, size, hQueue->getContext()->getDevice()->get(), HIPStream));
1315+
Result = UR_CHECK_ERROR(
1316+
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
13161317
if (phEvent) {
13171318
Result = EventPtr->record();
13181319
*phEvent = EventPtr.release();
@@ -1322,11 +1323,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13221323
}
13231324

13241325
return Result;
1326+
#else
1327+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1328+
#endif
13251329
}
13261330

13271331
UR_APIEXPORT ur_result_t UR_APICALL
13281332
urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
13291333
ur_usm_advice_flags_t, ur_event_handle_t *phEvent) {
1334+
#if HIP_VERSION_MAJOR >= 5
13301335
void *HIPDevicePtr = const_cast<void *>(pMem);
13311336
unsigned int PointerRangeSize = 0;
13321337
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
@@ -1337,6 +1342,9 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
13371342
// TODO implement a mapping to hipMemAdvise once the expected behaviour
13381343
// of urEnqueueUSMAdvise is detailed in the USM extension
13391344
return urEnqueueEventsWait(hQueue, 0, nullptr, phEvent);
1345+
#else
1346+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1347+
#endif
13401348
}
13411349

13421350
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(
@@ -1367,7 +1375,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
13671375
ur_result_t Result = UR_RESULT_SUCCESS;
13681376

13691377
try {
1370-
ScopedContext Active(hQueue->getContext());
1378+
ScopedContext Active(hQueue->getDevice());
13711379
hipStream_t HIPStream = hQueue->getNextTransferStream();
13721380
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
13731381
phEventWaitList);

0 commit comments

Comments
 (0)