@@ -61,6 +61,14 @@ umf_result_t umfLevelZeroMemoryProviderParamsSetResidentDevices(
61
61
return UMF_RESULT_ERROR_NOT_SUPPORTED ;
62
62
}
63
63
64
+ umf_result_t umfLevelZeroMemoryProviderParamsSetFreePolicy (
65
+ umf_level_zero_memory_provider_params_handle_t hParams ,
66
+ umf_level_zero_memory_provider_free_policy_t policy ) {
67
+ (void )hParams ;
68
+ (void )policy ;
69
+ return UMF_RESULT_ERROR_NOT_SUPPORTED ;
70
+ }
71
+
64
72
umf_memory_provider_ops_t * umfLevelZeroMemoryProviderOps (void ) {
65
73
// not supported
66
74
return NULL ;
@@ -91,6 +99,9 @@ typedef struct umf_level_zero_memory_provider_params_t {
91
99
resident_device_handles ; ///< Array of devices for which the memory should be made resident
92
100
uint32_t
93
101
resident_device_count ; ///< Number of devices for which the memory should be made resident
102
+
103
+ umf_level_zero_memory_provider_free_policy_t
104
+ freePolicy ; ///< Memory free policy
94
105
} umf_level_zero_memory_provider_params_t ;
95
106
96
107
typedef struct ze_memory_provider_t {
@@ -102,6 +113,8 @@ typedef struct ze_memory_provider_t {
102
113
uint32_t resident_device_count ;
103
114
104
115
ze_device_properties_t device_properties ;
116
+
117
+ ze_driver_memory_free_policy_ext_flags_t freePolicyFlags ;
105
118
} ze_memory_provider_t ;
106
119
107
120
typedef struct ze_ops_t {
@@ -128,6 +141,8 @@ typedef struct ze_ops_t {
128
141
size_t );
129
142
ze_result_t (* zeDeviceGetProperties )(ze_device_handle_t ,
130
143
ze_device_properties_t * );
144
+ ze_result_t (* zeMemFreeExt )(ze_context_handle_t ,
145
+ ze_memory_free_ext_desc_t * , void * );
131
146
} ze_ops_t ;
132
147
133
148
static ze_ops_t g_ze_ops ;
@@ -181,6 +196,8 @@ static void init_ze_global_state(void) {
181
196
utils_get_symbol_addr (0 , "zeContextMakeMemoryResident" , lib_name );
182
197
* (void * * )& g_ze_ops .zeDeviceGetProperties =
183
198
utils_get_symbol_addr (0 , "zeDeviceGetProperties" , lib_name );
199
+ * (void * * )& g_ze_ops .zeMemFreeExt =
200
+ utils_get_symbol_addr (0 , "zeMemFreeExt" , lib_name );
184
201
185
202
if (!g_ze_ops .zeMemAllocHost || !g_ze_ops .zeMemAllocDevice ||
186
203
!g_ze_ops .zeMemAllocShared || !g_ze_ops .zeMemFree ||
@@ -216,6 +233,7 @@ umf_result_t umfLevelZeroMemoryProviderParamsCreate(
216
233
params -> memory_type = UMF_MEMORY_TYPE_UNKNOWN ;
217
234
params -> resident_device_handles = NULL ;
218
235
params -> resident_device_count = 0 ;
236
+ params -> freePolicy = UMF_LEVEL_ZERO_MEMORY_PROVIDER_FREE_POLICY_DEFAULT ;
219
237
220
238
* hParams = params ;
221
239
@@ -292,6 +310,32 @@ umf_result_t umfLevelZeroMemoryProviderParamsSetResidentDevices(
292
310
return UMF_RESULT_SUCCESS ;
293
311
}
294
312
313
+ umf_result_t umfLevelZeroMemoryProviderParamsSetFreePolicy (
314
+ umf_level_zero_memory_provider_params_handle_t hParams ,
315
+ umf_level_zero_memory_provider_free_policy_t policy ) {
316
+ if (!hParams ) {
317
+ LOG_ERR ("Level zero memory provider params handle is NULL" );
318
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
319
+ }
320
+
321
+ hParams -> freePolicy = policy ;
322
+ return UMF_RESULT_SUCCESS ;
323
+ }
324
+
325
+ static ze_driver_memory_free_policy_ext_flags_t
326
+ umfFreePolicyToZePolicy (umf_level_zero_memory_provider_free_policy_t policy ) {
327
+ switch (policy ) {
328
+ case UMF_LEVEL_ZERO_MEMORY_PROVIDER_FREE_POLICY_DEFAULT :
329
+ return 0 ;
330
+ case UMF_LEVEL_ZERO_MEMORY_PROVIDER_FREE_POLICY_BLOCKING_FREE :
331
+ return ZE_DRIVER_MEMORY_FREE_POLICY_EXT_FLAG_BLOCKING_FREE ;
332
+ case UMF_LEVEL_ZERO_MEMORY_PROVIDER_FREE_POLICY_DEFER_FREE :
333
+ return ZE_DRIVER_MEMORY_FREE_POLICY_EXT_FLAG_DEFER_FREE ;
334
+ default :
335
+ return 0 ;
336
+ }
337
+ }
338
+
295
339
static umf_result_t ze_memory_provider_initialize (void * params ,
296
340
void * * provider ) {
297
341
if (params == NULL ) {
@@ -335,6 +379,8 @@ static umf_result_t ze_memory_provider_initialize(void *params,
335
379
ze_provider -> context = ze_params -> level_zero_context_handle ;
336
380
ze_provider -> device = ze_params -> level_zero_device_handle ;
337
381
ze_provider -> memory_type = (ze_memory_type_t )ze_params -> memory_type ;
382
+ ze_provider -> freePolicyFlags =
383
+ umfFreePolicyToZePolicy (ze_params -> freePolicy );
338
384
339
385
if (ze_provider -> device ) {
340
386
umf_result_t ret = ze2umf_result (g_ze_ops .zeDeviceGetProperties (
@@ -476,8 +522,18 @@ static umf_result_t ze_memory_provider_free(void *provider, void *ptr,
476
522
}
477
523
478
524
ze_memory_provider_t * ze_provider = (ze_memory_provider_t * )provider ;
479
- ze_result_t ze_result = g_ze_ops .zeMemFree (ze_provider -> context , ptr );
480
- return ze2umf_result (ze_result );
525
+
526
+ if (ze_provider -> freePolicyFlags == 0 ) {
527
+ return ze2umf_result (g_ze_ops .zeMemFree (ze_provider -> context , ptr ));
528
+ }
529
+
530
+ ze_memory_free_ext_desc_t desc = {
531
+ .stype = ZE_STRUCTURE_TYPE_MEMORY_FREE_EXT_DESC ,
532
+ .pNext = NULL ,
533
+ .freePolicy = ze_provider -> freePolicyFlags };
534
+
535
+ return ze2umf_result (
536
+ g_ze_ops .zeMemFreeExt (ze_provider -> context , & desc , ptr ));
481
537
}
482
538
483
539
static void ze_memory_provider_get_last_native_error (void * provider ,
0 commit comments