@@ -3032,6 +3032,74 @@ pi_result piextQueueCreateWithNativeHandle(pi_native_handle NativeHandle,
   return PI_SUCCESS;
 }
 
+// If indirect access tracking is enabled then performs reference counting,
+// otherwise just calls zeMemAllocDevice.
+static pi_result ZeDeviceMemAllocHelper(void **ResultPtr, pi_context Context,
+                                        pi_device Device, size_t Size) {
+  pi_platform Plt = Device->Platform;
+  std::unique_lock<std::mutex> ContextsLock(Plt->ContextsMutex,
+                                            std::defer_lock);
+  if (IndirectAccessTrackingEnabled) {
+    // Lock the mutex guarding the contexts container in the platform.
+    // This prevents new kernels from being submitted in any context while
+    // we are in the process of allocating memory; this is needed to
+    // properly capture allocations by kernels with indirect access.
+    ContextsLock.lock();
+    // We are going to defer memory release if there are kernels with
+    // indirect access; that is why we explicitly retain the context, to be
+    // sure that it is released only after all memory allocations in this
+    // context are released.
+    PI_CALL(piContextRetain(Context));
+  }
+
+  ze_device_mem_alloc_desc_t ZeDesc = {};
+  ZeDesc.flags = 0;
+  ZeDesc.ordinal = 0;
+  ZE_CALL(zeMemAllocDevice,
+          (Context->ZeContext, &ZeDesc, Size, 1, Device->ZeDevice, ResultPtr));
+
+  if (IndirectAccessTrackingEnabled) {
+    // Keep track of all memory allocations in the context.
+    Context->MemAllocs.emplace(std::piecewise_construct,
+                               std::forward_as_tuple(*ResultPtr),
+                               std::forward_as_tuple(Context));
+  }
+  return PI_SUCCESS;
+}
+
+// If indirect access tracking is enabled then performs reference counting,
+// otherwise just calls zeMemAllocHost.
+static pi_result ZeHostMemAllocHelper(void **ResultPtr, pi_context Context,
+                                      size_t Size) {
+  pi_platform Plt = Context->Devices[0]->Platform;
+  std::unique_lock<std::mutex> ContextsLock(Plt->ContextsMutex,
+                                            std::defer_lock);
+  if (IndirectAccessTrackingEnabled) {
+    // Lock the mutex guarding the contexts container in the platform.
+    // This prevents new kernels from being submitted in any context while
+    // we are in the process of allocating memory; this is needed to
+    // properly capture allocations by kernels with indirect access.
+    ContextsLock.lock();
+    // We are going to defer memory release if there are kernels with
+    // indirect access; that is why we explicitly retain the context, to be
+    // sure that it is released only after all memory allocations in this
+    // context are released.
+    PI_CALL(piContextRetain(Context));
+  }
+
+  ze_host_mem_alloc_desc_t ZeDesc = {};
+  ZeDesc.flags = 0;
+  ZE_CALL(zeMemAllocHost, (Context->ZeContext, &ZeDesc, Size, 1, ResultPtr));
+
+  if (IndirectAccessTrackingEnabled) {
+    // Keep track of all memory allocations in the context.
+    Context->MemAllocs.emplace(std::piecewise_construct,
+                               std::forward_as_tuple(*ResultPtr),
+                               std::forward_as_tuple(Context));
+  }
+  return PI_SUCCESS;
+}
+
 pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
                             void *HostPtr, pi_mem *RetMem,
                             const pi_mem_properties *properties) {
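
Note: both helpers funnel every tracked allocation through the same `Context->MemAllocs.emplace(...)` call. As a rough standalone model of that allocation-side bookkeeping — the record type below is a hypothetical stand-in for whatever `MemAllocs` actually stores, which is defined elsewhere in the plugin — the insertion looks roughly like this:

```cpp
#include <tuple>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for the record kept in Context->MemAllocs: at minimum
// the owning context and a reference count (the real type may differ).
struct MemAllocRecordModel {
  void *OwningContext;
  int RefCount = 1;
  explicit MemAllocRecordModel(void *Ctx) : OwningContext(Ctx) {}
};

int main() {
  std::unordered_map<void *, MemAllocRecordModel> MemAllocs;
  void *Ctx = nullptr;  // placeholder for a pi_context
  int Payload = 0;
  void *Ptr = &Payload; // placeholder for the Level Zero allocation

  // Same piecewise emplace shape as in the helpers above: the key is the raw
  // pointer, the mapped value is constructed from the owning context.
  MemAllocs.emplace(std::piecewise_construct, std::forward_as_tuple(Ptr),
                    std::forward_as_tuple(Ctx));
  return 0;
}
```
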
@@ -3094,9 +3162,7 @@ pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
     if (enableBufferPooling())
       Result = piextUSMHostAlloc(&Ptr, Context, nullptr, Size, Alignment);
     else {
-      ze_host_mem_alloc_desc_t ZeDesc = {};
-      ZeDesc.flags = 0;
-      ZE_CALL(zeMemAllocHost, (Context->ZeContext, &ZeDesc, Size, 1, &Ptr));
+      ZeHostMemAllocHelper(&Ptr, Context, Size);
     }
   } else if (Context->SingleRootDevice) {
     // If we have a single discrete device or all devices in the context are
@@ -3105,11 +3171,7 @@ pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
       Result = piextUSMDeviceAlloc(&Ptr, Context, Context->SingleRootDevice,
                                    nullptr, Size, Alignment);
     else {
-      ze_device_mem_alloc_desc_t ZeDesc = {};
-      ZeDesc.flags = 0;
-      ZeDesc.ordinal = 0;
-      ZE_CALL(zeMemAllocDevice, (Context->ZeContext, &ZeDesc, Size, 1,
-                                 Context->SingleRootDevice->ZeDevice, &Ptr));
+      ZeDeviceMemAllocHelper(&Ptr, Context, Context->SingleRootDevice, Size);
     }
   } else {
     // Context with several gpu cards. Temporarily use host allocation because
@@ -3121,9 +3183,7 @@ pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
     if (enableBufferPooling())
       Result = piextUSMHostAlloc(&Ptr, Context, nullptr, Size, Alignment);
     else {
-      ze_host_mem_alloc_desc_t ZeDesc = {};
-      ZeDesc.flags = 0;
-      ZE_CALL(zeMemAllocHost, (Context->ZeContext, &ZeDesc, Size, 1, &Ptr));
+      ZeHostMemAllocHelper(&Ptr, Context, Size);
     }
   }
@@ -3190,6 +3250,37 @@ pi_result piMemRetain(pi_mem Mem) {
   return PI_SUCCESS;
 }
 
+// If indirect access tracking is not enabled then this function just performs
+// zeMemFree. If indirect access tracking is enabled then reference counting is
+// performed.
+static pi_result ZeMemFreeHelper(pi_context Context, void *Ptr) {
+  pi_platform Plt = Context->Devices[0]->Platform;
+  std::unique_lock<std::mutex> ContextsLock(Plt->ContextsMutex,
+                                            std::defer_lock);
+  if (IndirectAccessTrackingEnabled) {
+    ContextsLock.lock();
+    auto It = Context->MemAllocs.find(Ptr);
+    if (It == std::end(Context->MemAllocs)) {
+      die("All memory allocations must be tracked!");
+    }
+    if (--(It->second.RefCount) != 0) {
+      // Memory can't be deallocated yet.
+      return PI_SUCCESS;
+    }
+
+    // Reference count is zero, so it is OK to free the memory.
+    // We don't need to track this allocation anymore.
+    Context->MemAllocs.erase(It);
+  }
+
+  ZE_CALL(zeMemFree, (Context->ZeContext, Ptr));
+
+  if (IndirectAccessTrackingEnabled)
+    PI_CALL(ContextReleaseHelper(Context));
+
+  return PI_SUCCESS;
+}
+
 pi_result piMemRelease(pi_mem Mem) {
   PI_ASSERT(Mem, PI_INVALID_MEM_OBJECT);
 
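
Note: `ZeMemFreeHelper` is the release-side counterpart of the allocation helpers: with tracking enabled it only calls `zeMemFree` once the recorded reference count drops to zero. Below is a small standalone sketch of that decrement-and-erase logic, assuming the tracked record behaves like a plain reference count and with `std::free` standing in for `zeMemFree`:

```cpp
#include <cassert>
#include <cstdlib>
#include <unordered_map>

// Simplified stand-in for the tracked allocation record; only the reference
// count matters for this sketch.
struct AllocRecord {
  int RefCount = 1;
};

// Frees Ptr only when its reference count reaches zero, mirroring the
// find/decrement/erase/zeMemFree sequence in ZeMemFreeHelper above.
static void freeTracked(std::unordered_map<void *, AllocRecord> &MemAllocs,
                        void *Ptr) {
  auto It = MemAllocs.find(Ptr);
  assert(It != MemAllocs.end() && "All memory allocations must be tracked!");
  if (--It->second.RefCount != 0)
    return;            // still referenced, e.g. by a kernel with indirect access
  MemAllocs.erase(It); // stop tracking first, then release the real memory
  std::free(Ptr);      // the plugin would call zeMemFree here
}

int main() {
  std::unordered_map<void *, AllocRecord> MemAllocs;
  void *Ptr = std::malloc(32);
  MemAllocs[Ptr] = AllocRecord{2}; // one extra reference, e.g. from a kernel
  freeTracked(MemAllocs, Ptr);     // first release: allocation survives
  freeTracked(MemAllocs, Ptr);     // second release: actually freed
  return 0;
}
```
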
@@ -3202,7 +3293,7 @@ pi_result piMemRelease(pi_mem Mem) {
     if (enableBufferPooling()) {
       PI_CALL(piextUSMFree(Mem->Context, Mem->getZeHandle()));
     } else {
-      ZE_CALL(zeMemFree, (Mem->Context->ZeContext, Mem->getZeHandle()));
+      ZeMemFreeHelper(Mem->Context, Mem->getZeHandle());
     }
   }
@@ -5016,13 +5107,7 @@ static pi_result EventRelease(pi_event Event, pi_queue LockedQueue) {
     if (Event->CommandType == PI_COMMAND_TYPE_MEM_BUFFER_UNMAP &&
         Event->CommandData) {
       // Free the memory allocated in the piEnqueueMemBufferMap.
-      // TODO: always use piextUSMFree
-      if (IndirectAccessTrackingEnabled) {
-        // Use the version with reference counting
-        PI_CALL(piextUSMFree(Event->Context, Event->CommandData));
-      } else {
-        ZE_CALL(zeMemFree, (Event->Context->ZeContext, Event->CommandData));
-      }
+      ZeMemFreeHelper(Event->Context, Event->CommandData);
       Event->CommandData = nullptr;
     }
     if (Event->OwnZeEvent) {
@@ -5813,17 +5898,7 @@ pi_result piEnqueueMemBufferMap(pi_queue Queue, pi_mem Buffer,
   if (Buffer->MapHostPtr) {
     *RetMap = Buffer->MapHostPtr + Offset;
   } else {
-    // TODO: always use piextUSMHostAlloc
-    if (IndirectAccessTrackingEnabled) {
-      // Use the version with reference counting
-      PI_CALL(piextUSMHostAlloc(RetMap, Queue->Context, nullptr, Size, 1));
-    } else {
-      ZeStruct<ze_host_mem_alloc_desc_t> ZeDesc;
-      ZeDesc.flags = 0;
-
-      ZE_CALL(zeMemAllocHost,
-              (Queue->Context->ZeContext, &ZeDesc, Size, 1, RetMap));
-    }
+    ZeHostMemAllocHelper(RetMap, Queue->Context, Size);
   }
   const auto &ZeCommandList = CommandList->first;
   const auto &WaitList = (*Event)->WaitList;
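
Note: the last two hunks sit on the buffer map/unmap path. A SYCL program shaped like the one below could plausibly reach `piEnqueueMemBufferMap` (and later, when the unmap event is released, `ZeMemFreeHelper`); whether the new helper branch is actually taken depends on how the buffer was created, whether buffer pooling is enabled, and on the scheduler, so treat this only as an illustration.

```cpp
#include <CL/sycl.hpp>
using namespace cl::sycl;

int main() {
  queue Q;
  // No host pointer, so the runtime owns the backing allocation.
  buffer<int, 1> Buf{range<1>(1024)};

  // Touch the buffer on the device so that later host access needs a map.
  Q.submit([&](handler &CGH) {
    auto DevAcc = Buf.get_access<access::mode::write>(CGH);
    CGH.parallel_for<class fill>(range<1>(1024),
                                 [=](id<1> I) { DevAcc[I] = 42; });
  });

  // A host accessor asks the runtime to map the buffer; on the Level Zero
  // backend this is expected to go through piEnqueueMemBufferMap, which now
  // uses ZeHostMemAllocHelper for the staging host memory when buffer pooling
  // is disabled. Releasing the unmap event later frees that memory through
  // ZeMemFreeHelper.
  auto HostAcc = Buf.get_access<access::mode::read>();
  return HostAcc[0] == 42 ? 0 : 1;
}
```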