@@ -4189,8 +4189,7 @@ pi_result piMemRetain(pi_mem Mem) {
4189
4189
// If indirect access tracking is not enabled then this functions just performs
4190
4190
// zeMemFree. If indirect access tracking is enabled then reference counting is
4191
4191
// performed.
4192
- static pi_result ZeMemFreeHelper (pi_context Context, void *Ptr,
4193
- bool OwnZeMemHandle = true ) {
4192
+ static pi_result ZeMemFreeHelper (pi_context Context, void *Ptr) {
4194
4193
pi_platform Plt = Context->getPlatform ();
4195
4194
std::unique_lock<pi_shared_mutex> ContextsLock (Plt->ContextsMutex ,
4196
4195
std::defer_lock);
@@ -4210,8 +4209,7 @@ static pi_result ZeMemFreeHelper(pi_context Context, void *Ptr,
4210
4209
Context->MemAllocs .erase (It);
4211
4210
}
4212
4211
4213
- if (OwnZeMemHandle)
4214
- ZE_CALL (zeMemFree, (Context->ZeContext , Ptr));
4212
+ ZE_CALL (zeMemFree, (Context->ZeContext , Ptr));
4215
4213
4216
4214
if (IndirectAccessTrackingEnabled)
4217
4215
PI_CALL (ContextReleaseHelper (Context));
@@ -4220,7 +4218,7 @@ static pi_result ZeMemFreeHelper(pi_context Context, void *Ptr,
4220
4218
}
4221
4219
4222
4220
static pi_result USMFreeHelper (pi_context Context, void *Ptr,
4223
- bool OwnZeMemHandle);
4221
+ bool OwnZeMemHandle = true );
4224
4222
4225
4223
pi_result piMemRelease (pi_mem Mem) {
4226
4224
PI_ASSERT (Mem, PI_ERROR_INVALID_MEM_OBJECT);
@@ -8085,10 +8083,8 @@ static pi_result USMHostAllocImpl(void **ResultPtr, pi_context Context,
8085
8083
return PI_SUCCESS;
8086
8084
}
8087
8085
8088
- static pi_result USMFreeImpl (pi_context Context, void *Ptr,
8089
- bool OwnZeMemHandle) {
8090
- if (OwnZeMemHandle)
8091
- ZE_CALL (zeMemFree, (Context->ZeContext , Ptr));
8086
+ static pi_result USMFreeImpl (pi_context Context, void *Ptr) {
8087
+ ZE_CALL (zeMemFree, (Context->ZeContext , Ptr));
8092
8088
return PI_SUCCESS;
8093
8089
}
8094
8090
@@ -8147,8 +8143,8 @@ void *USMMemoryAllocBase::allocate(size_t Size, size_t Alignment) {
8147
8143
return Ptr;
8148
8144
}
8149
8145
8150
- void USMMemoryAllocBase::deallocate (void *Ptr, bool OwnZeMemHandle ) {
8151
- auto Res = USMFreeImpl (Context, Ptr, OwnZeMemHandle );
8146
+ void USMMemoryAllocBase::deallocate (void *Ptr) {
8147
+ auto Res = USMFreeImpl (Context, Ptr);
8152
8148
if (Res != PI_SUCCESS) {
8153
8149
throw UsmAllocationException (Res);
8154
8150
}
@@ -8396,8 +8392,13 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
8396
8392
Context->MemAllocs .erase (It);
8397
8393
}
8398
8394
8395
+ if (!OwnZeMemHandle) {
8396
+ // Memory should not be freed
8397
+ return PI_SUCCESS;
8398
+ }
8399
+
8399
8400
if (!UseUSMAllocator) {
8400
- pi_result Res = USMFreeImpl (Context, Ptr, OwnZeMemHandle );
8401
+ pi_result Res = USMFreeImpl (Context, Ptr);
8401
8402
if (IndirectAccessTrackingEnabled)
8402
8403
PI_CALL (ContextReleaseHelper (Context));
8403
8404
return Res;
@@ -8416,7 +8417,7 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
8416
8417
// If memory type is host release from host pool
8417
8418
if (ZeMemoryAllocationProperties.type == ZE_MEMORY_TYPE_HOST) {
8418
8419
try {
8419
- Context->HostMemAllocContext ->deallocate (Ptr, OwnZeMemHandle );
8420
+ Context->HostMemAllocContext ->deallocate (Ptr);
8420
8421
} catch (const UsmAllocationException &Ex) {
8421
8422
return Ex.getError ();
8422
8423
} catch (...) {
@@ -8444,16 +8445,16 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
8444
8445
PI_ASSERT (Device, PI_ERROR_INVALID_DEVICE);
8445
8446
8446
8447
auto DeallocationHelper =
8447
- [Context, Device, Ptr,
8448
- OwnZeMemHandle ](std::unordered_map<ze_device_handle_t , USMAllocContext>
8449
- &AllocContextMap) {
8448
+ [Context, Device,
8449
+ Ptr ](std::unordered_map<ze_device_handle_t , USMAllocContext>
8450
+ &AllocContextMap) {
8450
8451
try {
8451
8452
auto It = AllocContextMap.find (Device->ZeDevice );
8452
8453
if (It == AllocContextMap.end ())
8453
8454
return PI_ERROR_INVALID_VALUE;
8454
8455
8455
8456
// The right context is found, deallocate the pointer
8456
- It->second .deallocate (Ptr, OwnZeMemHandle );
8457
+ It->second .deallocate (Ptr);
8457
8458
} catch (const UsmAllocationException &Ex) {
8458
8459
return Ex.getError ();
8459
8460
}
@@ -8479,7 +8480,7 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
8479
8480
}
8480
8481
}
8481
8482
8482
- pi_result Res = USMFreeImpl (Context, Ptr, OwnZeMemHandle );
8483
+ pi_result Res = USMFreeImpl (Context, Ptr);
8483
8484
if (SharedReadOnlyAllocsIterator != Context->SharedReadOnlyAllocs .end ()) {
8484
8485
Context->SharedReadOnlyAllocs .erase (SharedReadOnlyAllocsIterator);
8485
8486
}
@@ -8494,7 +8495,7 @@ pi_result piextUSMFree(pi_context Context, void *Ptr) {
8494
8495
std::scoped_lock<pi_shared_mutex> Lock (
8495
8496
IndirectAccessTrackingEnabled ? Plt->ContextsMutex : Context->Mutex );
8496
8497
8497
- return USMFreeHelper (Context, Ptr, true /* OwnZeMemHandle */ );
8498
+ return USMFreeHelper (Context, Ptr);
8498
8499
}
8499
8500
8500
8501
pi_result piextKernelSetArgPointer (pi_kernel Kernel, pi_uint32 ArgIndex,
@@ -9410,11 +9411,11 @@ pi_result _pi_buffer::free() {
9410
9411
std::scoped_lock<pi_shared_mutex> Lock (
9411
9412
IndirectAccessTrackingEnabled ? Plt->ContextsMutex : Context->Mutex );
9412
9413
9413
- PI_CALL (USMFreeHelper (Context, ZeHandle, true ));
9414
+ PI_CALL (USMFreeHelper (Context, ZeHandle));
9414
9415
break ;
9415
9416
}
9416
9417
case allocation_t ::free_native:
9417
- PI_CALL (ZeMemFreeHelper (Context, ZeHandle, true ));
9418
+ PI_CALL (ZeMemFreeHelper (Context, ZeHandle));
9418
9419
break ;
9419
9420
case allocation_t ::unimport:
9420
9421
ZeUSMImport.doZeUSMRelease (Context->getPlatform ()->ZeDriver , ZeHandle);
0 commit comments