@@ -133,6 +133,25 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
133
133
level_zero_memory_provider_params_t * ze_params =
134
134
(level_zero_memory_provider_params_t * )params ;
135
135
136
+ if (!ze_params -> level_zero_context_handle ) {
137
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
138
+ }
139
+
140
+ if (ze_params -> memory_type == UMF_MEMORY_TYPE_HOST &&
141
+ ze_params -> level_zero_device_handle ) {
142
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
143
+ }
144
+
145
+ if (ze_params -> memory_type != UMF_MEMORY_TYPE_HOST &&
146
+ !ze_params -> level_zero_device_handle ) {
147
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
148
+ }
149
+
150
+ if ((bool )ze_params -> resident_device_count !=
151
+ (bool )ze_params -> resident_device_handles ) {
152
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
153
+ }
154
+
136
155
util_init_once (& ze_is_initialized , init_ze_global_state );
137
156
if (Init_ze_global_state_failed ) {
138
157
LOG_ERR ("Loading Level Zero symbols failed" );
@@ -149,12 +168,17 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
149
168
ze_provider -> device = ze_params -> level_zero_device_handle ;
150
169
ze_provider -> memory_type = (ze_memory_type_t )ze_params -> memory_type ;
151
170
152
- umf_result_t ret = ze2umf_result (g_ze_ops .zeDeviceGetProperties (
153
- ze_provider -> device , & ze_provider -> device_properties ));
171
+ if (ze_provider -> device ) {
172
+ umf_result_t ret = ze2umf_result (g_ze_ops .zeDeviceGetProperties (
173
+ ze_provider -> device , & ze_provider -> device_properties ));
154
174
155
- if (ret != UMF_RESULT_SUCCESS ) {
156
- umf_ba_global_free (ze_provider );
157
- return ret ;
175
+ if (ret != UMF_RESULT_SUCCESS ) {
176
+ umf_ba_global_free (ze_provider );
177
+ return ret ;
178
+ }
179
+ } else {
180
+ memset (& ze_provider -> device_properties , 0 ,
181
+ sizeof (ze_provider -> device_properties ));
158
182
}
159
183
160
184
* provider = ze_provider ;
@@ -173,6 +197,18 @@ void ze_memory_provider_finalize(void *provider) {
173
197
memcpy (& ze_is_initialized , & is_initialized , sizeof (ze_is_initialized ));
174
198
}
175
199
200
+ static bool use_relaxed_allocation (ze_memory_provider_t * ze_provider ,
201
+ size_t size ) {
202
+ assert (ze_provider -> device );
203
+ assert (ze_provider -> device_properties .maxMemAllocSize > 0 );
204
+ return size > ze_provider -> device_properties .maxMemAllocSize ;
205
+ }
206
+
207
+ static ze_relaxed_allocation_limits_exp_desc_t relaxed_device_allocation_desc =
208
+ {.stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC ,
209
+ .pNext = NULL ,
210
+ .flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE };
211
+
176
212
static umf_result_t ze_memory_provider_alloc (void * provider , size_t size ,
177
213
size_t alignment ,
178
214
void * * resultPtr ) {
@@ -181,16 +217,6 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
181
217
182
218
ze_memory_provider_t * ze_provider = (ze_memory_provider_t * )provider ;
183
219
184
- bool useRelaxedAllocationFlag =
185
- size > ze_provider -> device_properties .maxMemAllocSize ;
186
- ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = {
187
- .stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC ,
188
- .pNext = NULL ,
189
- .flags = 0 };
190
- if (useRelaxedAllocationFlag ) {
191
- relaxed_desc .flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE ;
192
- }
193
-
194
220
ze_result_t ze_result = ZE_RESULT_SUCCESS ;
195
221
switch (ze_provider -> memory_type ) {
196
222
case UMF_MEMORY_TYPE_HOST : {
@@ -205,7 +231,9 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
205
231
case UMF_MEMORY_TYPE_DEVICE : {
206
232
ze_device_mem_alloc_desc_t dev_desc = {
207
233
.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC ,
208
- .pNext = useRelaxedAllocationFlag ? & relaxed_desc : NULL ,
234
+ .pNext = use_relaxed_allocation (ze_provider , size )
235
+ ? & relaxed_device_allocation_desc
236
+ : NULL ,
209
237
.flags = 0 ,
210
238
.ordinal = 0 // TODO
211
239
};
@@ -220,8 +248,10 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
220
248
.pNext = NULL ,
221
249
.flags = 0 };
222
250
ze_device_mem_alloc_desc_t dev_desc = {
223
- .stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC ,
224
- .pNext = useRelaxedAllocationFlag ? & relaxed_desc : NULL ,
251
+ .stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC ,
252
+ .pNext = use_relaxed_allocation (ze_provider , size )
253
+ ? & relaxed_device_allocation_desc
254
+ : NULL ,
225
255
.flags = 0 ,
226
256
.ordinal = 0 // TODO
227
257
};
0 commit comments