Skip to content

Commit 7b57951

Browse files
committed
[SYCL][UR][HIP] s/ScopedContext/ScopedDevice
d1c92cb changed ScopedContext to take a device instead of a context thus sematically changing its meaning. This rename makes it clear what the intented usages are.
1 parent 7897517 commit 7b57951

File tree

10 files changed

+84
-84
lines changed

10 files changed

+84
-84
lines changed

sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
2929
/// UR API context are objects that are passed to functions, and not bound
3030
/// to threads.
3131
/// The ur_context_handle_t_ object doesn't implement this behavior. It only
32-
/// holds the HIP context data. The RAII object \ref ScopedContext implements
32+
/// holds the HIP context data. The RAII object \ref ScopedDevice implements
3333
/// the active context behavior.
3434
///
3535
/// <b> Primary vs UserDefined context </b>
@@ -151,48 +151,3 @@ struct ur_context_handle_t_ {
151151
std::vector<deleter_data> ExtendedDeleters;
152152
std::unordered_map<const void *, size_t> USMMappings;
153153
};
154-
155-
namespace {
156-
/// RAII type to guarantee recovering original HIP context
157-
/// Scoped context is used across all UR HIP plugin implementation
158-
/// to activate the UR Context on the current thread, matching the
159-
/// HIP driver semantics where the context used for the HIP Driver
160-
/// API is the one active on the thread.
161-
/// The implementation tries to avoid replacing the hipCtx_t if it cans
162-
class ScopedContext {
163-
hipCtx_t Original;
164-
bool NeedToRecover;
165-
166-
public:
167-
ScopedContext(ur_device_handle_t hDevice) : NeedToRecover{false} {
168-
169-
if (!hDevice) {
170-
throw UR_RESULT_ERROR_INVALID_DEVICE;
171-
}
172-
173-
// FIXME when multi device context are supported in HIP adapter
174-
hipCtx_t Desired = hDevice->getNativeContext();
175-
UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
176-
if (Original != Desired) {
177-
// Sets the desired context as the active one for the thread
178-
UR_CHECK_ERROR(hipCtxSetCurrent(Desired));
179-
if (Original == nullptr) {
180-
// No context is installed on the current thread
181-
// This is the most common case. We can activate the context in the
182-
// thread and leave it there until all the UR context referring to the
183-
// same underlying HIP context are destroyed. This emulates
184-
// the behaviour of the HIP runtime api, and avoids costly context
185-
// switches. No action is required on this side of the if.
186-
} else {
187-
NeedToRecover = true;
188-
}
189-
}
190-
}
191-
192-
~ScopedContext() {
193-
if (NeedToRecover) {
194-
UR_CHECK_ERROR(hipCtxSetCurrent(Original));
195-
}
196-
}
197-
};
198-
} // namespace

sycl/plugins/unified_runtime/ur/adapters/hip/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
968968
return UR_RESULT_SUCCESS;
969969

970970
ur_event_handle_t_::native_type Event;
971-
ScopedContext Active(hDevice);
971+
ScopedDevice Active(hDevice);
972972

973973
if (pDeviceTimestamp) {
974974
UR_CHECK_ERROR(hipEventCreateWithFlags(&Event, hipEventDefault));

sycl/plugins/unified_runtime/ur/adapters/hip/device.hpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,48 @@ struct ur_device_handle_t_ {
4444
};
4545

4646
int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);
47+
48+
namespace {
49+
/// RAII type to guarantee recovering original HIP context
50+
/// Scoped context is used across all UR HIP plugin implementation
51+
/// to activate the UR Context on the current thread, matching the
52+
/// HIP driver semantics where the context used for the HIP Driver
53+
/// API is the one active on the thread.
54+
/// The implementation tries to avoid replacing the hipCtx_t if it cans
55+
class ScopedDevice {
56+
hipCtx_t Original;
57+
bool NeedToRecover;
58+
59+
public:
60+
ScopedDevice(ur_device_handle_t hDevice) : NeedToRecover{false} {
61+
62+
if (!hDevice) {
63+
throw UR_RESULT_ERROR_INVALID_DEVICE;
64+
}
65+
66+
// FIXME when multi device context are supported in HIP adapter
67+
hipCtx_t Desired = hDevice->getNativeContext();
68+
UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
69+
if (Original != Desired) {
70+
// Sets the desired context as the active one for the thread
71+
UR_CHECK_ERROR(hipCtxSetCurrent(Desired));
72+
if (Original == nullptr) {
73+
// No context is installed on the current thread
74+
// This is the most common case. We can activate the context in the
75+
// thread and leave it there until all the UR context referring to the
76+
// same underlying HIP context are destroyed. This emulates
77+
// the behaviour of the HIP runtime api, and avoids costly context
78+
// switches. No action is required on this side of the if.
79+
} else {
80+
NeedToRecover = true;
81+
}
82+
}
83+
}
84+
85+
~ScopedDevice() {
86+
if (NeedToRecover) {
87+
UR_CHECK_ERROR(hipCtxSetCurrent(Original));
88+
}
89+
}
90+
};
91+
} // namespace

sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t CommandQueue,
4141
return UR_RESULT_SUCCESS;
4242
}
4343
try {
44-
ScopedContext Active(CommandQueue->getDevice());
44+
ScopedDevice Active(CommandQueue->getDevice());
4545

4646
auto Result = forLatestEvents(
4747
EventWaitList, NumEventsInWaitList,
@@ -97,7 +97,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
9797
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
9898

9999
try {
100-
ScopedContext Active(hQueue->getDevice());
100+
ScopedDevice Active(hQueue->getDevice());
101101
hipStream_t HIPStream = hQueue->getNextTransferStream();
102102
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
103103
phEventWaitList);
@@ -143,7 +143,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
143143
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
144144

145145
try {
146-
ScopedContext Active(hQueue->getDevice());
146+
ScopedDevice Active(hQueue->getDevice());
147147
hipStream_t HIPStream = hQueue->getNextTransferStream();
148148
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
149149
phEventWaitList);
@@ -253,7 +253,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
253253

254254
try {
255255
ur_device_handle_t Dev = hQueue->getDevice();
256-
ScopedContext Active(Dev);
256+
ScopedDevice Active(Dev);
257257
ur_context_handle_t Ctx = hQueue->getContext();
258258

259259
uint32_t StreamToken;
@@ -373,7 +373,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
373373
ur_result_t Result;
374374

375375
try {
376-
ScopedContext Active(hQueue->getDevice());
376+
ScopedDevice Active(hQueue->getDevice());
377377
uint32_t StreamToken;
378378
ur_stream_quard Guard;
379379
hipStream_t HIPStream = hQueue->getNextComputeStream(
@@ -523,7 +523,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
523523
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
524524

525525
try {
526-
ScopedContext Active(hQueue->getDevice());
526+
ScopedDevice Active(hQueue->getDevice());
527527
hipStream_t HIPStream = hQueue->getNextTransferStream();
528528

529529
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
@@ -571,7 +571,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
571571
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
572572

573573
try {
574-
ScopedContext Active(hQueue->getDevice());
574+
ScopedDevice Active(hQueue->getDevice());
575575
hipStream_t HIPStream = hQueue->getNextTransferStream();
576576
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
577577
phEventWaitList);
@@ -619,7 +619,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
619619
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
620620

621621
try {
622-
ScopedContext Active(hQueue->getDevice());
622+
ScopedDevice Active(hQueue->getDevice());
623623
ur_result_t Result;
624624
auto Stream = hQueue->getNextTransferStream();
625625

@@ -666,7 +666,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
666666
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
667667

668668
try {
669-
ScopedContext Active(hQueue->getDevice());
669+
ScopedDevice Active(hQueue->getDevice());
670670
hipStream_t HIPStream = hQueue->getNextTransferStream();
671671
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
672672
phEventWaitList);
@@ -761,7 +761,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
761761
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
762762

763763
try {
764-
ScopedContext Active(hQueue->getDevice());
764+
ScopedDevice Active(hQueue->getDevice());
765765

766766
auto Stream = hQueue->getNextTransferStream();
767767
ur_result_t Result;
@@ -902,7 +902,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
902902
ur_result_t Result = UR_RESULT_SUCCESS;
903903

904904
try {
905-
ScopedContext Active(hQueue->getDevice());
905+
ScopedDevice Active(hQueue->getDevice());
906906
hipStream_t HIPStream = hQueue->getNextTransferStream();
907907

908908
if (phEventWaitList) {
@@ -970,7 +970,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
970970
ur_result_t Result = UR_RESULT_SUCCESS;
971971

972972
try {
973-
ScopedContext Active(hQueue->getDevice());
973+
ScopedDevice Active(hQueue->getDevice());
974974
hipStream_t HIPStream = hQueue->getNextTransferStream();
975975

976976
if (phEventWaitList) {
@@ -1042,7 +1042,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
10421042
ur_result_t Result = UR_RESULT_SUCCESS;
10431043

10441044
try {
1045-
ScopedContext Active(hQueue->getDevice());
1045+
ScopedDevice Active(hQueue->getDevice());
10461046
hipStream_t HIPStream = hQueue->getNextTransferStream();
10471047
if (phEventWaitList) {
10481048
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
@@ -1144,7 +1144,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
11441144
HostPtr, numEventsInWaitList,
11451145
phEventWaitList, phEvent);
11461146
} else {
1147-
ScopedContext Active(hQueue->getDevice());
1147+
ScopedDevice Active(hQueue->getDevice());
11481148

11491149
if (IsPinned) {
11501150
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
@@ -1195,7 +1195,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
11951195
hMem->Mem.BufferMem.getMapSize(), pMappedPtr, numEventsInWaitList,
11961196
phEventWaitList, phEvent);
11971197
} else {
1198-
ScopedContext Active(hQueue->getDevice());
1198+
ScopedDevice Active(hQueue->getDevice());
11991199

12001200
if (IsPinned) {
12011201
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
@@ -1226,7 +1226,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12261226
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
12271227

12281228
try {
1229-
ScopedContext Active(hQueue->getDevice());
1229+
ScopedDevice Active(hQueue->getDevice());
12301230
uint32_t StreamToken;
12311231
ur_stream_quard Guard;
12321232
hipStream_t HIPStream = hQueue->getNextComputeStream(
@@ -1284,7 +1284,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
12841284
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
12851285

12861286
try {
1287-
ScopedContext Active(hQueue->getDevice());
1287+
ScopedDevice Active(hQueue->getDevice());
12881288
hipStream_t HIPStream = hQueue->getNextTransferStream();
12891289
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
12901290
phEventWaitList);
@@ -1330,7 +1330,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13301330
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
13311331

13321332
try {
1333-
ScopedContext Active(hQueue->getDevice());
1333+
ScopedDevice Active(hQueue->getDevice());
13341334
hipStream_t HIPStream = hQueue->getNextTransferStream();
13351335
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
13361336
phEventWaitList);
@@ -1403,7 +1403,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
14031403
ur_result_t Result = UR_RESULT_SUCCESS;
14041404

14051405
try {
1406-
ScopedContext Active(hQueue->getDevice());
1406+
ScopedDevice Active(hQueue->getDevice());
14071407
hipStream_t HIPStream = hQueue->getNextTransferStream();
14081408
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
14091409
phEventWaitList);

sycl/plugins/unified_runtime/ur/adapters/hip/event.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
190190
try {
191191

192192
auto Context = phEventWaitList[0]->getContext();
193-
ScopedContext Active(Context->getDevice());
193+
ScopedDevice Active(Context->getDevice());
194194

195195
auto WaitFunc = [Context](ur_event_handle_t Event) -> ur_result_t {
196196
UR_ASSERT(Event, UR_RESULT_ERROR_INVALID_EVENT);
@@ -289,7 +289,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
289289
std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
290290
ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
291291
try {
292-
ScopedContext Active(hEvent->getContext()->getDevice());
292+
ScopedDevice Active(hEvent->getContext()->getDevice());
293293
Result = hEvent->release();
294294
} catch (...) {
295295
Result = UR_RESULT_ERROR_OUT_OF_RESOURCES;

sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
1717
std::unique_ptr<ur_kernel_handle_t_> RetKernel{nullptr};
1818

1919
try {
20-
ScopedContext Active(hProgram->getContext()->getDevice());
20+
ScopedDevice Active(hProgram->getContext()->getDevice());
2121

2222
hipFunction_t HIPFunc;
2323
Result = UR_CHECK_ERROR(

sycl/plugins/unified_runtime/ur/adapters/hip/memory.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
3030
return UR_RESULT_SUCCESS;
3131
}
3232

33-
ScopedContext Active(uniqueMemObj->getContext()->getDevice());
33+
ScopedDevice Active(uniqueMemObj->getContext()->getDevice());
3434

3535
if (hMem->MemType == ur_mem_handle_t_::Type::Buffer) {
3636
switch (uniqueMemObj->Mem.BufferMem.MemAllocMode) {
@@ -101,7 +101,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
101101
ur_mem_handle_t RetMemObj = nullptr;
102102

103103
try {
104-
ScopedContext Active(hContext->getDevice());
104+
ScopedDevice Active(hContext->getDevice());
105105
void *Ptr;
106106
auto pHost = pProperties ? pProperties->pHost : nullptr;
107107
ur_mem_handle_t_::MemImpl::BufferMem::AllocMode AllocMode =
@@ -218,7 +218,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
218218

219219
std::unique_ptr<ur_mem_handle_t_> RetMemObj{nullptr};
220220
try {
221-
ScopedContext Active(Context->getDevice());
221+
ScopedDevice Active(Context->getDevice());
222222

223223
RetMemObj = std::unique_ptr<ur_mem_handle_t_>{new ur_mem_handle_t_{
224224
Context, hBuffer, flags, AllocMode, Ptr, HostPtr, pRegion->size}};
@@ -247,7 +247,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory,
247247

248248
UrReturnHelper ReturnValue(propSize, pMemInfo, pPropSizeRet);
249249

250-
ScopedContext Active(hMemory->getContext()->getDevice());
250+
ScopedDevice Active(hMemory->getContext()->getDevice());
251251

252252
switch (MemInfoType) {
253253
case UR_MEM_INFO_SIZE: {
@@ -425,7 +425,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate(
425425
size_t ImageSizeBytes = PixelSizeBytes * pImageDesc->width *
426426
pImageDesc->height * pImageDesc->depth;
427427

428-
ScopedContext Active(hContext->getDevice());
428+
ScopedDevice Active(hContext->getDevice());
429429
hipArray *ImageArray;
430430
Result = UR_CHECK_ERROR(hipArray3DCreate(
431431
reinterpret_cast<hipCUarray *>(&ImageArray), &ArrayDesc));

sycl/plugins/unified_runtime/ur/adapters/hip/program.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t,
103103
ur_result_t Result = UR_RESULT_SUCCESS;
104104

105105
try {
106-
ScopedContext Active(hProgram->getContext()->getDevice());
106+
ScopedDevice Active(hProgram->getContext()->getDevice());
107107

108108
hProgram->buildProgram(pOptions);
109109

@@ -209,7 +209,7 @@ urProgramRelease(ur_program_handle_t hProgram) {
209209
ur_result_t Result = UR_RESULT_ERROR_INVALID_PROGRAM;
210210

211211
try {
212-
ScopedContext Active(hProgram->getContext()->getDevice());
212+
ScopedDevice Active(hProgram->getContext()->getDevice());
213213
auto HIPModule = hProgram->get();
214214
if (HIPModule) {
215215
Result = UR_CHECK_ERROR(hipModuleUnload(HIPModule));

0 commit comments

Comments
 (0)