Skip to content

[SYCL][UR][HIP] s/ScopedContext/ScopedDevice #10672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 1 addition & 46 deletions sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
/// UR API context are objects that are passed to functions, and not bound
/// to threads.
/// The ur_context_handle_t_ object doesn't implement this behavior. It only
/// holds the HIP context data. The RAII object \ref ScopedContext implements
/// holds the HIP context data. The RAII object \ref ScopedDevice implements
/// the active context behavior.
///
/// <b> Primary vs UserDefined context </b>
Expand Down Expand Up @@ -151,48 +151,3 @@ struct ur_context_handle_t_ {
std::vector<deleter_data> ExtendedDeleters;
std::unordered_map<const void *, size_t> USMMappings;
};

namespace {
/// Guard object that makes a device's native HIP context current on the
/// calling thread for as long as the guard is alive.
///
/// The UR HIP adapter creates one of these before issuing HIP calls, matching
/// the HIP driver semantics where the context used for the HIP Driver API is
/// the one active on the thread.
///
/// To avoid needless context switches, the previous context is restored on
/// destruction only when the thread already had a different context
/// installed. If the thread had no context at all, the desired context is
/// simply left in place.
class ScopedContext {
  hipCtx_t Original = nullptr;
  bool NeedToRecover = false;

public:
  /// Activates the native context of \p hDevice on the current thread.
  ///
  /// \param hDevice device whose native context should become current.
  /// \throws UR_RESULT_ERROR_INVALID_DEVICE (thrown by value) when
  ///         \p hDevice is null.
  explicit ScopedContext(ur_device_handle_t hDevice) {
    if (!hDevice) {
      throw UR_RESULT_ERROR_INVALID_DEVICE;
    }

    // FIXME when multi device context are supported in HIP adapter
    hipCtx_t Desired = hDevice->getNativeContext();
    UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
    if (Original != Desired) {
      // Sets the desired context as the active one for the thread.
      UR_CHECK_ERROR(hipCtxSetCurrent(Desired));
      // When no context was installed on the thread (the most common case)
      // the desired context is left active after this guard dies. This
      // emulates the behaviour of the HIP runtime API and avoids costly
      // context switches. Only a pre-existing, different context has to be
      // put back.
      NeedToRecover = (Original != nullptr);
    }
  }

  // The guard records per-thread state that must be restored exactly once;
  // a copy or a moved-from instance would restore it a second time.
  ScopedContext(const ScopedContext &) = delete;
  ScopedContext &operator=(const ScopedContext &) = delete;
  ScopedContext(ScopedContext &&) = delete;
  ScopedContext &operator=(ScopedContext &&) = delete;

  /// Restores the previously active context if the constructor replaced one.
  ~ScopedContext() {
    if (NeedToRecover) {
      // NOTE(review): destructors are implicitly noexcept. If UR_CHECK_ERROR
      // reports failure by throwing (its definition is not visible here --
      // confirm), a failed restore terminates the process.
      UR_CHECK_ERROR(hipCtxSetCurrent(Original));
    }
  }
};
} // namespace
2 changes: 1 addition & 1 deletion sycl/plugins/unified_runtime/ur/adapters/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
return UR_RESULT_SUCCESS;

ur_event_handle_t_::native_type Event;
ScopedContext Active(hDevice);
ScopedDevice Active(hDevice);

if (pDeviceTimestamp) {
UR_CHECK_ERROR(hipEventCreateWithFlags(&Event, hipEventDefault));
Expand Down
45 changes: 45 additions & 0 deletions sycl/plugins/unified_runtime/ur/adapters/hip/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,48 @@ struct ur_device_handle_t_ {
};

int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);

namespace {
/// Guard object that makes a device's native HIP context current on the
/// calling thread for as long as the guard is alive.
///
/// The UR HIP adapter creates one of these before issuing HIP calls, matching
/// the HIP driver semantics where the context used for the HIP Driver API is
/// the one active on the thread.
///
/// To avoid needless context switches, the previous context is restored on
/// destruction only when the thread already had a different context
/// installed. If the thread had no context at all, the desired context is
/// simply left in place.
class ScopedDevice {
  hipCtx_t Original = nullptr;
  bool NeedToRecover = false;

public:
  /// Activates the native context of \p hDevice on the current thread.
  ///
  /// \param hDevice device whose native context should become current.
  /// \throws UR_RESULT_ERROR_INVALID_DEVICE (thrown by value) when
  ///         \p hDevice is null.
  explicit ScopedDevice(ur_device_handle_t hDevice) {
    if (!hDevice) {
      throw UR_RESULT_ERROR_INVALID_DEVICE;
    }

    // FIXME when multi device context are supported in HIP adapter
    hipCtx_t Desired = hDevice->getNativeContext();
    UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
    if (Original != Desired) {
      // Sets the desired context as the active one for the thread.
      UR_CHECK_ERROR(hipCtxSetCurrent(Desired));
      // When no context was installed on the thread (the most common case)
      // the desired context is left active after this guard dies. This
      // emulates the behaviour of the HIP runtime API and avoids costly
      // context switches. Only a pre-existing, different context has to be
      // put back.
      NeedToRecover = (Original != nullptr);
    }
  }

  // The guard records per-thread state that must be restored exactly once;
  // a copy or a moved-from instance would restore it a second time.
  ScopedDevice(const ScopedDevice &) = delete;
  ScopedDevice &operator=(const ScopedDevice &) = delete;
  ScopedDevice(ScopedDevice &&) = delete;
  ScopedDevice &operator=(ScopedDevice &&) = delete;

  /// Restores the previously active context if the constructor replaced one.
  ~ScopedDevice() {
    if (NeedToRecover) {
      // NOTE(review): destructors are implicitly noexcept. If UR_CHECK_ERROR
      // reports failure by throwing (its definition is not visible here --
      // confirm), a failed restore terminates the process.
      UR_CHECK_ERROR(hipCtxSetCurrent(Original));
    }
  }
};
} // namespace
38 changes: 19 additions & 19 deletions sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t CommandQueue,
return UR_RESULT_SUCCESS;
}
try {
ScopedContext Active(CommandQueue->getDevice());
ScopedDevice Active(CommandQueue->getDevice());

auto Result = forLatestEvents(
EventWaitList, NumEventsInWaitList,
Expand Down Expand Up @@ -97,7 +97,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -143,7 +143,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -253,7 +253,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(

try {
ur_device_handle_t Dev = hQueue->getDevice();
ScopedContext Active(Dev);
ScopedDevice Active(Dev);
ur_context_handle_t Ctx = hQueue->getContext();

uint32_t StreamToken;
Expand Down Expand Up @@ -373,7 +373,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
ur_result_t Result;

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
uint32_t StreamToken;
ur_stream_quard Guard;
hipStream_t HIPStream = hQueue->getNextComputeStream(
Expand Down Expand Up @@ -523,7 +523,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
Expand Down Expand Up @@ -571,7 +571,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -619,7 +619,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
ur_result_t Result;
auto Stream = hQueue->getNextTransferStream();

Expand Down Expand Up @@ -666,7 +666,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -761,7 +761,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());

auto Stream = hQueue->getNextTransferStream();
ur_result_t Result;
Expand Down Expand Up @@ -902,7 +902,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

if (phEventWaitList) {
Expand Down Expand Up @@ -970,7 +970,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

if (phEventWaitList) {
Expand Down Expand Up @@ -1042,7 +1042,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
if (phEventWaitList) {
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
Expand Down Expand Up @@ -1144,7 +1144,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
HostPtr, numEventsInWaitList,
phEventWaitList, phEvent);
} else {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());

if (IsPinned) {
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
Expand Down Expand Up @@ -1195,7 +1195,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
hMem->Mem.BufferMem.getMapSize(), pMappedPtr, numEventsInWaitList,
phEventWaitList, phEvent);
} else {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());

if (IsPinned) {
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
Expand Down Expand Up @@ -1226,7 +1226,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
uint32_t StreamToken;
ur_stream_quard Guard;
hipStream_t HIPStream = hQueue->getNextComputeStream(
Expand Down Expand Up @@ -1284,7 +1284,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -1330,7 +1330,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -1403,7 +1403,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getDevice());
ScopedDevice Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down
4 changes: 2 additions & 2 deletions sycl/plugins/unified_runtime/ur/adapters/hip/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
try {

auto Context = phEventWaitList[0]->getContext();
ScopedContext Active(Context->getDevice());
ScopedDevice Active(Context->getDevice());

auto WaitFunc = [Context](ur_event_handle_t Event) -> ur_result_t {
UR_ASSERT(Event, UR_RESULT_ERROR_INVALID_EVENT);
Expand Down Expand Up @@ -289,7 +289,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
try {
ScopedContext Active(hEvent->getContext()->getDevice());
ScopedDevice Active(hEvent->getContext()->getDevice());
Result = hEvent->release();
} catch (...) {
Result = UR_RESULT_ERROR_OUT_OF_RESOURCES;
Expand Down
2 changes: 1 addition & 1 deletion sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
std::unique_ptr<ur_kernel_handle_t_> RetKernel{nullptr};

try {
ScopedContext Active(hProgram->getContext()->getDevice());
ScopedDevice Active(hProgram->getContext()->getDevice());

hipFunction_t HIPFunc;
Result = UR_CHECK_ERROR(
Expand Down
10 changes: 5 additions & 5 deletions sycl/plugins/unified_runtime/ur/adapters/hip/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
return UR_RESULT_SUCCESS;
}

ScopedContext Active(uniqueMemObj->getContext()->getDevice());
ScopedDevice Active(uniqueMemObj->getContext()->getDevice());

if (hMem->MemType == ur_mem_handle_t_::Type::Buffer) {
switch (uniqueMemObj->Mem.BufferMem.MemAllocMode) {
Expand Down Expand Up @@ -101,7 +101,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
ur_mem_handle_t RetMemObj = nullptr;

try {
ScopedContext Active(hContext->getDevice());
ScopedDevice Active(hContext->getDevice());
void *Ptr;
auto pHost = pProperties ? pProperties->pHost : nullptr;
ur_mem_handle_t_::MemImpl::BufferMem::AllocMode AllocMode =
Expand Down Expand Up @@ -218,7 +218,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(

std::unique_ptr<ur_mem_handle_t_> RetMemObj{nullptr};
try {
ScopedContext Active(Context->getDevice());
ScopedDevice Active(Context->getDevice());

RetMemObj = std::unique_ptr<ur_mem_handle_t_>{new ur_mem_handle_t_{
Context, hBuffer, flags, AllocMode, Ptr, HostPtr, pRegion->size}};
Expand Down Expand Up @@ -247,7 +247,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory,

UrReturnHelper ReturnValue(propSize, pMemInfo, pPropSizeRet);

ScopedContext Active(hMemory->getContext()->getDevice());
ScopedDevice Active(hMemory->getContext()->getDevice());

switch (MemInfoType) {
case UR_MEM_INFO_SIZE: {
Expand Down Expand Up @@ -425,7 +425,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate(
size_t ImageSizeBytes = PixelSizeBytes * pImageDesc->width *
pImageDesc->height * pImageDesc->depth;

ScopedContext Active(hContext->getDevice());
ScopedDevice Active(hContext->getDevice());
hipArray *ImageArray;
Result = UR_CHECK_ERROR(hipArray3DCreate(
reinterpret_cast<hipCUarray *>(&ImageArray), &ArrayDesc));
Expand Down
4 changes: 2 additions & 2 deletions sycl/plugins/unified_runtime/ur/adapters/hip/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t,
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hProgram->getContext()->getDevice());
ScopedDevice Active(hProgram->getContext()->getDevice());

hProgram->buildProgram(pOptions);

Expand Down Expand Up @@ -209,7 +209,7 @@ urProgramRelease(ur_program_handle_t hProgram) {
ur_result_t Result = UR_RESULT_ERROR_INVALID_PROGRAM;

try {
ScopedContext Active(hProgram->getContext()->getDevice());
ScopedDevice Active(hProgram->getContext()->getDevice());
auto HIPModule = hProgram->get();
if (HIPModule) {
Result = UR_CHECK_ERROR(hipModuleUnload(HIPModule));
Expand Down
Loading