Skip to content

Commit dac0a2a

Browse files
committed
Fix indirect access feature after pooling changes
1 parent 1499c4b commit dac0a2a

File tree

1 file changed

+105
-30
lines changed

1 file changed

+105
-30
lines changed

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 105 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3032,6 +3032,74 @@ pi_result piextQueueCreateWithNativeHandle(pi_native_handle NativeHandle,
30323032
return PI_SUCCESS;
30333033
}
30343034

3035+
// If indirect access tracking is enabled then performs reference counting,
3036+
// otherwise just calls zeMemAllocDevice.
3037+
static pi_result ZeDeviceMemAllocHelper(void **ResultPtr, pi_context Context,
3038+
pi_device Device, size_t Size) {
3039+
pi_platform Plt = Device->Platform;
3040+
std::unique_lock<std::mutex> ContextsLock(Plt->ContextsMutex,
3041+
std::defer_lock);
3042+
if (IndirectAccessTrackingEnabled) {
3043+
// Lock the mutex which is guarding contexts container in the platform.
3044+
// This prevents new kernels from being submitted in any context while
3045+
// we are in the process of allocating a memory, this is needed to
3046+
// properly capture allocations by kernels with indirect access.
3047+
ContextsLock.lock();
3048+
// We are going to defer memory release if there are kernels with
3049+
// indirect access, that is why explicitly retain context to be sure
3050+
// that it is released after all memory allocations in this context are
3051+
// released.
3052+
PI_CALL(piContextRetain(Context));
3053+
}
3054+
3055+
ze_device_mem_alloc_desc_t ZeDesc = {};
3056+
ZeDesc.flags = 0;
3057+
ZeDesc.ordinal = 0;
3058+
ZE_CALL(zeMemAllocDevice,
3059+
(Context->ZeContext, &ZeDesc, Size, 1, Device->ZeDevice, ResultPtr));
3060+
3061+
if (IndirectAccessTrackingEnabled) {
3062+
// Keep track of all memory allocations in the context
3063+
Context->MemAllocs.emplace(std::piecewise_construct,
3064+
std::forward_as_tuple(*ResultPtr),
3065+
std::forward_as_tuple(Context));
3066+
}
3067+
return PI_SUCCESS;
3068+
}
3069+
3070+
// If indirect access tracking is enabled then performs reference counting,
3071+
// otherwise just calls zeMemAllocHost.
3072+
static pi_result ZeHostMemAllocHelper(void **ResultPtr, pi_context Context,
3073+
size_t Size) {
3074+
pi_platform Plt = Context->Devices[0]->Platform;
3075+
std::unique_lock<std::mutex> ContextsLock(Plt->ContextsMutex,
3076+
std::defer_lock);
3077+
if (IndirectAccessTrackingEnabled) {
3078+
// Lock the mutex which is guarding contexts container in the platform.
3079+
// This prevents new kernels from being submitted in any context while
3080+
// we are in the process of allocating a memory, this is needed to
3081+
// properly capture allocations by kernels with indirect access.
3082+
ContextsLock.lock();
3083+
// We are going to defer memory release if there are kernels with
3084+
// indirect access, that is why explicitly retain context to be sure
3085+
// that it is released after all memory allocations in this context are
3086+
// released.
3087+
PI_CALL(piContextRetain(Context));
3088+
}
3089+
3090+
ze_host_mem_alloc_desc_t ZeDesc = {};
3091+
ZeDesc.flags = 0;
3092+
ZE_CALL(zeMemAllocHost, (Context->ZeContext, &ZeDesc, Size, 1, ResultPtr));
3093+
3094+
if (IndirectAccessTrackingEnabled) {
3095+
// Keep track of all memory allocations in the context
3096+
Context->MemAllocs.emplace(std::piecewise_construct,
3097+
std::forward_as_tuple(*ResultPtr),
3098+
std::forward_as_tuple(Context));
3099+
}
3100+
return PI_SUCCESS;
3101+
}
3102+
30353103
pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
30363104
void *HostPtr, pi_mem *RetMem,
30373105
const pi_mem_properties *properties) {
@@ -3094,9 +3162,7 @@ pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
30943162
if (enableBufferPooling())
30953163
Result = piextUSMHostAlloc(&Ptr, Context, nullptr, Size, Alignment);
30963164
else {
3097-
ze_host_mem_alloc_desc_t ZeDesc = {};
3098-
ZeDesc.flags = 0;
3099-
ZE_CALL(zeMemAllocHost, (Context->ZeContext, &ZeDesc, Size, 1, &Ptr));
3165+
ZeHostMemAllocHelper(&Ptr, Context, Size);
31003166
}
31013167
} else if (Context->SingleRootDevice) {
31023168
// If we have a single discrete device or all devices in the context are
@@ -3105,11 +3171,7 @@ pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
31053171
Result = piextUSMDeviceAlloc(&Ptr, Context, Context->SingleRootDevice,
31063172
nullptr, Size, Alignment);
31073173
else {
3108-
ze_device_mem_alloc_desc_t ZeDesc = {};
3109-
ZeDesc.flags = 0;
3110-
ZeDesc.ordinal = 0;
3111-
ZE_CALL(zeMemAllocDevice, (Context->ZeContext, &ZeDesc, Size, 1,
3112-
Context->SingleRootDevice->ZeDevice, &Ptr));
3174+
ZeDeviceMemAllocHelper(&Ptr, Context, Context->SingleRootDevice, Size);
31133175
}
31143176
} else {
31153177
// Context with several gpu cards. Temporarily use host allocation because
@@ -3121,9 +3183,7 @@ pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
31213183
if (enableBufferPooling())
31223184
Result = piextUSMHostAlloc(&Ptr, Context, nullptr, Size, Alignment);
31233185
else {
3124-
ze_host_mem_alloc_desc_t ZeDesc = {};
3125-
ZeDesc.flags = 0;
3126-
ZE_CALL(zeMemAllocHost, (Context->ZeContext, &ZeDesc, Size, 1, &Ptr));
3186+
ZeHostMemAllocHelper(&Ptr, Context, Size);
31273187
}
31283188
}
31293189

@@ -3190,6 +3250,37 @@ pi_result piMemRetain(pi_mem Mem) {
31903250
return PI_SUCCESS;
31913251
}
31923252

3253+
// If indirect access tracking is not enabled then this functions just performs
3254+
// zeMemFree. If indirect access tracking is enabled then reference counting is
3255+
// performed.
3256+
static pi_result ZeMemFreeHelper(pi_context Context, void *Ptr) {
3257+
pi_platform Plt = Context->Devices[0]->Platform;
3258+
std::unique_lock<std::mutex> ContextsLock(Plt->ContextsMutex,
3259+
std::defer_lock);
3260+
if (IndirectAccessTrackingEnabled) {
3261+
ContextsLock.lock();
3262+
auto It = Context->MemAllocs.find(Ptr);
3263+
if (It == std::end(Context->MemAllocs)) {
3264+
die("All memory allocations must be tracked!");
3265+
}
3266+
if (--(It->second.RefCount) != 0) {
3267+
// Memory can't be deallocated yet.
3268+
return PI_SUCCESS;
3269+
}
3270+
3271+
// Reference count is zero, it is ok to free memory.
3272+
// We don't need to track this allocation anymore.
3273+
Context->MemAllocs.erase(It);
3274+
}
3275+
3276+
ZE_CALL(zeMemFree, (Context->ZeContext, Ptr));
3277+
3278+
if (IndirectAccessTrackingEnabled)
3279+
PI_CALL(ContextReleaseHelper(Context));
3280+
3281+
return PI_SUCCESS;
3282+
}
3283+
31933284
pi_result piMemRelease(pi_mem Mem) {
31943285
PI_ASSERT(Mem, PI_INVALID_MEM_OBJECT);
31953286

@@ -3202,7 +3293,7 @@ pi_result piMemRelease(pi_mem Mem) {
32023293
if (enableBufferPooling()) {
32033294
PI_CALL(piextUSMFree(Mem->Context, Mem->getZeHandle()));
32043295
} else {
3205-
ZE_CALL(zeMemFree, (Mem->Context->ZeContext, Mem->getZeHandle()));
3296+
ZeMemFreeHelper(Mem->Context, Mem->getZeHandle());
32063297
}
32073298
}
32083299
}
@@ -5016,13 +5107,7 @@ static pi_result EventRelease(pi_event Event, pi_queue LockedQueue) {
50165107
if (Event->CommandType == PI_COMMAND_TYPE_MEM_BUFFER_UNMAP &&
50175108
Event->CommandData) {
50185109
// Free the memory allocated in the piEnqueueMemBufferMap.
5019-
// TODO: always use piextUSMFree
5020-
if (IndirectAccessTrackingEnabled) {
5021-
// Use the version with reference counting
5022-
PI_CALL(piextUSMFree(Event->Context, Event->CommandData));
5023-
} else {
5024-
ZE_CALL(zeMemFree, (Event->Context->ZeContext, Event->CommandData));
5025-
}
5110+
ZeMemFreeHelper(Event->Context, Event->CommandData);
50265111
Event->CommandData = nullptr;
50275112
}
50285113
if (Event->OwnZeEvent) {
@@ -5813,17 +5898,7 @@ pi_result piEnqueueMemBufferMap(pi_queue Queue, pi_mem Buffer,
58135898
if (Buffer->MapHostPtr) {
58145899
*RetMap = Buffer->MapHostPtr + Offset;
58155900
} else {
5816-
// TODO: always use piextUSMHostAlloc
5817-
if (IndirectAccessTrackingEnabled) {
5818-
// Use the version with reference counting
5819-
PI_CALL(piextUSMHostAlloc(RetMap, Queue->Context, nullptr, Size, 1));
5820-
} else {
5821-
ZeStruct<ze_host_mem_alloc_desc_t> ZeDesc;
5822-
ZeDesc.flags = 0;
5823-
5824-
ZE_CALL(zeMemAllocHost,
5825-
(Queue->Context->ZeContext, &ZeDesc, Size, 1, RetMap));
5826-
}
5901+
ZeHostMemAllocHelper(RetMap, Queue->Context, Size);
58275902
}
58285903
const auto &ZeCommandList = CommandList->first;
58295904
const auto &WaitList = (*Event)->WaitList;

0 commit comments

Comments
 (0)