@@ -30,6 +30,8 @@ typedef struct ze_memory_provider_t {
30
30
31
31
ze_device_handle_t * resident_device_handles ;
32
32
uint32_t resident_device_count ;
33
+
34
+ ze_device_properties_t device_properties ;
33
35
} ze_memory_provider_t ;
34
36
35
37
typedef struct ze_ops_t {
@@ -54,6 +56,8 @@ typedef struct ze_ops_t {
54
56
ze_result_t (* zeContextMakeMemoryResident )(ze_context_handle_t ,
55
57
ze_device_handle_t , void * ,
56
58
size_t );
59
+ ze_result_t (* zeDeviceGetProperties )(ze_device_handle_t ,
60
+ ze_device_properties_t * );
57
61
} ze_ops_t ;
58
62
59
63
static ze_ops_t g_ze_ops ;
@@ -86,12 +90,15 @@ static void init_ze_global_state(void) {
86
90
util_get_symbol_addr (0 , "zeMemCloseIpcHandle" , lib_name );
87
91
* (void * * )& g_ze_ops .zeContextMakeMemoryResident =
88
92
util_get_symbol_addr (0 , "zeContextMakeMemoryResident" , lib_name );
93
+ * (void * * )& g_ze_ops .zeDeviceGetProperties =
94
+ util_get_symbol_addr (0 , "zeDeviceGetProperties" , lib_name );
89
95
90
96
if (!g_ze_ops .zeMemAllocHost || !g_ze_ops .zeMemAllocDevice ||
91
97
!g_ze_ops .zeMemAllocShared || !g_ze_ops .zeMemFree ||
92
98
!g_ze_ops .zeMemGetIpcHandle || !g_ze_ops .zeMemOpenIpcHandle ||
93
99
!g_ze_ops .zeMemCloseIpcHandle ||
94
- !g_ze_ops .zeContextMakeMemoryResident ) {
100
+ !g_ze_ops .zeContextMakeMemoryResident ||
101
+ !g_ze_ops .zeDeviceGetProperties ) {
95
102
// g_ze_ops.zeMemPutIpcHandle can be NULL because it was introduced
96
103
// starting from Level Zero 1.6
97
104
LOG_ERR ("Required Level Zero symbols not found." );
@@ -123,6 +130,13 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
123
130
ze_provider -> device = ze_params -> level_zero_device_handle ;
124
131
ze_provider -> memory_type = (ze_memory_type_t )ze_params -> memory_type ;
125
132
133
+ ze_result_t ret = g_ze_ops .zeDeviceGetProperties (
134
+ ze_provider -> device , & ze_provider -> device_properties );
135
+ if (ret != ZE_RESULT_SUCCESS ) {
136
+ umf_ba_global_free (ze_provider );
137
+ return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC ;
138
+ }
139
+
126
140
* provider = ze_provider ;
127
141
128
142
return UMF_RESULT_SUCCESS ;
@@ -147,6 +161,16 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
147
161
148
162
ze_memory_provider_t * ze_provider = (ze_memory_provider_t * )provider ;
149
163
164
+ bool useRelaxedAllocationFlag =
165
+ size > ze_provider -> device_properties .maxMemAllocSize ;
166
+ ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = {
167
+ .stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC ,
168
+ .pNext = NULL ,
169
+ .flags = 0 };
170
+ if (useRelaxedAllocationFlag ) {
171
+ relaxed_desc .flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE ;
172
+ }
173
+
150
174
ze_result_t ze_result = ZE_RESULT_SUCCESS ;
151
175
switch (ze_provider -> memory_type ) {
152
176
case UMF_MEMORY_TYPE_HOST : {
@@ -161,7 +185,7 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
161
185
case UMF_MEMORY_TYPE_DEVICE : {
162
186
ze_device_mem_alloc_desc_t dev_desc = {
163
187
.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC ,
164
- .pNext = NULL ,
188
+ .pNext = useRelaxedAllocationFlag ? & relaxed_desc : NULL ,
165
189
.flags = 0 ,
166
190
.ordinal = 0 // TODO
167
191
};
@@ -177,8 +201,8 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
177
201
.flags = 0 };
178
202
ze_device_mem_alloc_desc_t dev_desc = {
179
203
.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC ,
180
- .pNext = NULL ,
181
- .flags = 0 ,
204
+ .flags = NULL ,
205
+ .pNext = useRelaxedAllocationFlag ? & relaxed_desc : NULL ,
182
206
.ordinal = 0 // TODO
183
207
};
184
208
ze_result = g_ze_ops .zeMemAllocShared (ze_provider -> context , & dev_desc ,
0 commit comments