@@ -27,6 +27,11 @@ typedef struct ze_memory_provider_t {
27
27
ze_context_handle_t context ;
28
28
ze_device_handle_t device ;
29
29
ze_memory_type_t memory_type ;
30
+
31
+ ze_device_handle_t * resident_device_handles ;
32
+ uint32_t resident_device_count ;
33
+
34
+ ze_device_properties_t device_properties ;
30
35
} ze_memory_provider_t ;
31
36
32
37
typedef struct ze_ops_t {
@@ -48,11 +53,35 @@ typedef struct ze_ops_t {
48
53
ze_ipc_mem_handle_t ,
49
54
ze_ipc_memory_flags_t , void * * );
50
55
ze_result_t (* zeMemCloseIpcHandle )(ze_context_handle_t , void * );
56
+ ze_result_t (* zeContextMakeMemoryResident )(ze_context_handle_t ,
57
+ ze_device_handle_t , void * ,
58
+ size_t );
59
+ ze_result_t (* zeDeviceGetProperties )(ze_device_handle_t ,
60
+ ze_device_properties_t * );
51
61
} ze_ops_t ;
52
62
53
63
static ze_ops_t g_ze_ops ;
54
64
static UTIL_ONCE_FLAG ze_is_initialized = UTIL_ONCE_FLAG_INIT ;
55
65
static bool Init_ze_global_state_failed ;
66
+ static __TLS ze_result_t TLS_last_native_error ;
67
+
68
+ static void store_last_native_error (int32_t native_error ) {
69
+ TLS_last_native_error = native_error ;
70
+ }
71
+
72
+ umf_result_t ze2umf_result (ze_result_t result ) {
73
+ switch (result ) {
74
+ case ZE_RESULT_SUCCESS :
75
+ return UMF_RESULT_SUCCESS ;
76
+ case ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY :
77
+ return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY ;
78
+ case ZE_RESULT_ERROR_INVALID_ARGUMENT :
79
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
80
+ default :
81
+ store_last_native_error (result );
82
+ return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC ;
83
+ }
84
+ }
56
85
57
86
static void init_ze_global_state (void ) {
58
87
#ifdef _WIN32
@@ -78,11 +107,17 @@ static void init_ze_global_state(void) {
78
107
util_get_symbol_addr (0 , "zeMemOpenIpcHandle" , lib_name );
79
108
* (void * * )& g_ze_ops .zeMemCloseIpcHandle =
80
109
util_get_symbol_addr (0 , "zeMemCloseIpcHandle" , lib_name );
110
+ * (void * * )& g_ze_ops .zeContextMakeMemoryResident =
111
+ util_get_symbol_addr (0 , "zeContextMakeMemoryResident" , lib_name );
112
+ * (void * * )& g_ze_ops .zeDeviceGetProperties =
113
+ util_get_symbol_addr (0 , "zeDeviceGetProperties" , lib_name );
81
114
82
115
if (!g_ze_ops .zeMemAllocHost || !g_ze_ops .zeMemAllocDevice ||
83
116
!g_ze_ops .zeMemAllocShared || !g_ze_ops .zeMemFree ||
84
117
!g_ze_ops .zeMemGetIpcHandle || !g_ze_ops .zeMemOpenIpcHandle ||
85
- !g_ze_ops .zeMemCloseIpcHandle ) {
118
+ !g_ze_ops .zeMemCloseIpcHandle ||
119
+ !g_ze_ops .zeContextMakeMemoryResident ||
120
+ !g_ze_ops .zeDeviceGetProperties ) {
86
121
// g_ze_ops.zeMemPutIpcHandle can be NULL because it was introduced
87
122
// starting from Level Zero 1.6
88
123
LOG_ERR ("Required Level Zero symbols not found." );
@@ -114,6 +149,14 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
114
149
ze_provider -> device = ze_params -> level_zero_device_handle ;
115
150
ze_provider -> memory_type = (ze_memory_type_t )ze_params -> memory_type ;
116
151
152
+ umf_result_t ret = ze2umf_result (g_ze_ops .zeDeviceGetProperties (
153
+ ze_provider -> device , & ze_provider -> device_properties ));
154
+
155
+ if (ret != UMF_RESULT_SUCCESS ) {
156
+ umf_ba_global_free (ze_provider );
157
+ return ret ;
158
+ }
159
+
117
160
* provider = ze_provider ;
118
161
119
162
return UMF_RESULT_SUCCESS ;
@@ -138,6 +181,16 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
138
181
139
182
ze_memory_provider_t * ze_provider = (ze_memory_provider_t * )provider ;
140
183
184
+ bool useRelaxedAllocationFlag =
185
+ size > ze_provider -> device_properties .maxMemAllocSize ;
186
+ ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = {
187
+ .stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC ,
188
+ .pNext = NULL ,
189
+ .flags = 0 };
190
+ if (useRelaxedAllocationFlag ) {
191
+ relaxed_desc .flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE ;
192
+ }
193
+
141
194
ze_result_t ze_result = ZE_RESULT_SUCCESS ;
142
195
switch (ze_provider -> memory_type ) {
143
196
case UMF_MEMORY_TYPE_HOST : {
@@ -152,7 +205,7 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
152
205
case UMF_MEMORY_TYPE_DEVICE : {
153
206
ze_device_mem_alloc_desc_t dev_desc = {
154
207
.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC ,
155
- .pNext = NULL ,
208
+ .pNext = useRelaxedAllocationFlag ? & relaxed_desc : NULL ,
156
209
.flags = 0 ,
157
210
.ordinal = 0 // TODO
158
211
};
@@ -168,7 +221,7 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
168
221
.flags = 0 };
169
222
ze_device_mem_alloc_desc_t dev_desc = {
170
223
.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC ,
171
- .pNext = NULL ,
224
+ .pNext = useRelaxedAllocationFlag ? & relaxed_desc : NULL ,
172
225
.flags = 0 ,
173
226
.ordinal = 0 // TODO
174
227
};
@@ -178,13 +231,23 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
178
231
break ;
179
232
}
180
233
default :
181
- return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC ;
234
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
182
235
}
183
236
184
- // TODO add error reporting
185
- return (ze_result == ZE_RESULT_SUCCESS )
186
- ? UMF_RESULT_SUCCESS
187
- : UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC ;
237
+ if (ze_result != ZE_RESULT_SUCCESS ) {
238
+ return ze2umf_result (ze_result );
239
+ }
240
+
241
+ for (uint32_t i = 0 ; i < ze_provider -> resident_device_count ; i ++ ) {
242
+ ze_result = g_ze_ops .zeContextMakeMemoryResident (
243
+ ze_provider -> context , ze_provider -> resident_device_handles [i ],
244
+ * resultPtr , size );
245
+ if (ze_result != ZE_RESULT_SUCCESS ) {
246
+ return ze2umf_result (ze_result );
247
+ }
248
+ }
249
+
250
+ return ze2umf_result (ze_result );
188
251
}
189
252
190
253
static umf_result_t ze_memory_provider_free (void * provider , void * ptr ,
@@ -194,11 +257,7 @@ static umf_result_t ze_memory_provider_free(void *provider, void *ptr,
194
257
assert (provider );
195
258
ze_memory_provider_t * ze_provider = (ze_memory_provider_t * )provider ;
196
259
ze_result_t ze_result = g_ze_ops .zeMemFree (ze_provider -> context , ptr );
197
-
198
- // TODO add error reporting
199
- return (ze_result == ZE_RESULT_SUCCESS )
200
- ? UMF_RESULT_SUCCESS
201
- : UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC ;
260
+ return ze2umf_result (ze_result );
202
261
}
203
262
204
263
void ze_memory_provider_get_last_native_error (void * provider ,
@@ -207,9 +266,9 @@ void ze_memory_provider_get_last_native_error(void *provider,
207
266
(void )provider ;
208
267
(void )ppMessage ;
209
268
210
- // TODO
211
269
assert (pError );
212
- * pError = 0 ;
270
+
271
+ * pError = TLS_last_native_error ;
213
272
}
214
273
215
274
static umf_result_t ze_memory_provider_get_min_page_size (void * provider ,
@@ -314,7 +373,7 @@ static umf_result_t ze_memory_provider_get_ipc_handle(void *provider,
314
373
& ze_ipc_data -> ze_handle );
315
374
if (ze_result != ZE_RESULT_SUCCESS ) {
316
375
LOG_ERR ("zeMemGetIpcHandle() failed." );
317
- return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC ;
376
+ return ze2umf_result ( ze_result ) ;
318
377
}
319
378
320
379
ze_ipc_data -> pid = utils_getpid ();
@@ -342,7 +401,7 @@ static umf_result_t ze_memory_provider_put_ipc_handle(void *provider,
342
401
ze_ipc_data -> ze_handle );
343
402
if (ze_result != ZE_RESULT_SUCCESS ) {
344
403
LOG_ERR ("zeMemPutIpcHandle() failed." );
345
- return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC ;
404
+ return ze2umf_result ( ze_result ) ;
346
405
}
347
406
return UMF_RESULT_SUCCESS ;
348
407
}
@@ -379,7 +438,7 @@ static umf_result_t ze_memory_provider_open_ipc_handle(void *provider,
379
438
}
380
439
if (ze_result != ZE_RESULT_SUCCESS ) {
381
440
LOG_ERR ("zeMemOpenIpcHandle() failed." );
382
- return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC ;
441
+ return ze2umf_result ( ze_result ) ;
383
442
}
384
443
385
444
return UMF_RESULT_SUCCESS ;
@@ -397,7 +456,7 @@ ze_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
397
456
ze_result = g_ze_ops .zeMemCloseIpcHandle (ze_provider -> context , ptr );
398
457
if (ze_result != ZE_RESULT_SUCCESS ) {
399
458
LOG_ERR ("zeMemCloseIpcHandle() failed." );
400
- return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC ;
459
+ return ze2umf_result ( ze_result ) ;
401
460
}
402
461
403
462
return UMF_RESULT_SUCCESS ;
0 commit comments