@@ -75,6 +75,14 @@ umf_result_t umfLevelZeroMemoryProviderParamsSetResidentDevices(
75
75
return UMF_RESULT_ERROR_NOT_SUPPORTED ;
76
76
}
77
77
78
+ umf_result_t umfLevelZeroMemoryProviderParamsSetFreePolicy (
79
+ umf_level_zero_memory_provider_params_handle_t hParams ,
80
+ umf_level_zero_memory_provider_free_policy_t policy ) {
81
+ (void )hParams ;
82
+ (void )policy ;
83
+ return UMF_RESULT_ERROR_NOT_SUPPORTED ;
84
+ }
85
+
78
86
umf_memory_provider_ops_t * umfLevelZeroMemoryProviderOps (void ) {
79
87
// not supported
80
88
LOG_ERR ("L0 memory provider is disabled! (UMF_BUILD_LEVEL_ZERO_PROVIDER is "
@@ -107,6 +115,9 @@ typedef struct umf_level_zero_memory_provider_params_t {
107
115
resident_device_handles ; ///< Array of devices for which the memory should be made resident
108
116
uint32_t
109
117
resident_device_count ; ///< Number of devices for which the memory should be made resident
118
+
119
+ umf_level_zero_memory_provider_free_policy_t
120
+ freePolicy ; ///< Memory free policy
110
121
} umf_level_zero_memory_provider_params_t ;
111
122
112
123
typedef struct ze_memory_provider_t {
@@ -118,6 +129,8 @@ typedef struct ze_memory_provider_t {
118
129
uint32_t resident_device_count ;
119
130
120
131
ze_device_properties_t device_properties ;
132
+
133
+ ze_driver_memory_free_policy_ext_flags_t freePolicyFlags ;
121
134
} ze_memory_provider_t ;
122
135
123
136
typedef struct ze_ops_t {
@@ -144,6 +157,8 @@ typedef struct ze_ops_t {
144
157
size_t );
145
158
ze_result_t (* zeDeviceGetProperties )(ze_device_handle_t ,
146
159
ze_device_properties_t * );
160
+ ze_result_t (* zeMemFreeExt )(ze_context_handle_t ,
161
+ ze_memory_free_ext_desc_t * , void * );
147
162
} ze_ops_t ;
148
163
149
164
static ze_ops_t g_ze_ops ;
@@ -197,6 +212,8 @@ static void init_ze_global_state(void) {
197
212
utils_get_symbol_addr (0 , "zeContextMakeMemoryResident" , lib_name );
198
213
* (void * * )& g_ze_ops .zeDeviceGetProperties =
199
214
utils_get_symbol_addr (0 , "zeDeviceGetProperties" , lib_name );
215
+ * (void * * )& g_ze_ops .zeMemFreeExt =
216
+ utils_get_symbol_addr (0 , "zeMemFreeExt" , lib_name );
200
217
201
218
if (!g_ze_ops .zeMemAllocHost || !g_ze_ops .zeMemAllocDevice ||
202
219
!g_ze_ops .zeMemAllocShared || !g_ze_ops .zeMemFree ||
@@ -232,6 +249,7 @@ umf_result_t umfLevelZeroMemoryProviderParamsCreate(
232
249
params -> memory_type = UMF_MEMORY_TYPE_UNKNOWN ;
233
250
params -> resident_device_handles = NULL ;
234
251
params -> resident_device_count = 0 ;
252
+ params -> freePolicy = UMF_LEVEL_ZERO_MEMORY_PROVIDER_FREE_POLICY_DEFAULT ;
235
253
236
254
* hParams = params ;
237
255
@@ -308,6 +326,32 @@ umf_result_t umfLevelZeroMemoryProviderParamsSetResidentDevices(
308
326
return UMF_RESULT_SUCCESS ;
309
327
}
310
328
329
+ umf_result_t umfLevelZeroMemoryProviderParamsSetFreePolicy (
330
+ umf_level_zero_memory_provider_params_handle_t hParams ,
331
+ umf_level_zero_memory_provider_free_policy_t policy ) {
332
+ if (!hParams ) {
333
+ LOG_ERR ("Level zero memory provider params handle is NULL" );
334
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
335
+ }
336
+
337
+ hParams -> freePolicy = policy ;
338
+ return UMF_RESULT_SUCCESS ;
339
+ }
340
+
341
+ static ze_driver_memory_free_policy_ext_flags_t
342
+ umfFreePolicyToZePolicy (umf_level_zero_memory_provider_free_policy_t policy ) {
343
+ switch (policy ) {
344
+ case UMF_LEVEL_ZERO_MEMORY_PROVIDER_FREE_POLICY_DEFAULT :
345
+ return 0 ;
346
+ case UMF_LEVEL_ZERO_MEMORY_PROVIDER_FREE_POLICY_BLOCKING_FREE :
347
+ return ZE_DRIVER_MEMORY_FREE_POLICY_EXT_FLAG_BLOCKING_FREE ;
348
+ case UMF_LEVEL_ZERO_MEMORY_PROVIDER_FREE_POLICY_DEFER_FREE :
349
+ return ZE_DRIVER_MEMORY_FREE_POLICY_EXT_FLAG_DEFER_FREE ;
350
+ default :
351
+ return 0 ;
352
+ }
353
+ }
354
+
311
355
static umf_result_t ze_memory_provider_initialize (void * params ,
312
356
void * * provider ) {
313
357
if (params == NULL ) {
@@ -351,6 +395,8 @@ static umf_result_t ze_memory_provider_initialize(void *params,
351
395
ze_provider -> context = ze_params -> level_zero_context_handle ;
352
396
ze_provider -> device = ze_params -> level_zero_device_handle ;
353
397
ze_provider -> memory_type = (ze_memory_type_t )ze_params -> memory_type ;
398
+ ze_provider -> freePolicyFlags =
399
+ umfFreePolicyToZePolicy (ze_params -> freePolicy );
354
400
355
401
memset (& ze_provider -> device_properties , 0 ,
356
402
sizeof (ze_provider -> device_properties ));
@@ -493,8 +539,18 @@ static umf_result_t ze_memory_provider_free(void *provider, void *ptr,
493
539
}
494
540
495
541
ze_memory_provider_t * ze_provider = (ze_memory_provider_t * )provider ;
496
- ze_result_t ze_result = g_ze_ops .zeMemFree (ze_provider -> context , ptr );
497
- return ze2umf_result (ze_result );
542
+
543
+ if (ze_provider -> freePolicyFlags == 0 ) {
544
+ return ze2umf_result (g_ze_ops .zeMemFree (ze_provider -> context , ptr ));
545
+ }
546
+
547
+ ze_memory_free_ext_desc_t desc = {
548
+ .stype = ZE_STRUCTURE_TYPE_MEMORY_FREE_EXT_DESC ,
549
+ .pNext = NULL ,
550
+ .freePolicy = ze_provider -> freePolicyFlags };
551
+
552
+ return ze2umf_result (
553
+ g_ze_ops .zeMemFreeExt (ze_provider -> context , & desc , ptr ));
498
554
}
499
555
500
556
static void ze_memory_provider_get_last_native_error (void * provider ,
0 commit comments