Skip to content

Commit 56a1afc

Browse files
committed
L0 provider: do not accept device handle for HOST memory type
and do not call zeDeviceGetProperties for HOST provider. Also, add extra checks for parameters.
1 parent 89b660c commit 56a1afc

File tree

3 files changed

+140
-22
lines changed

3 files changed

+140
-22
lines changed

src/provider/provider_level_zero.c

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,20 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
133133
level_zero_memory_provider_params_t *ze_params =
134134
(level_zero_memory_provider_params_t *)params;
135135

136+
if (!ze_params->level_zero_context_handle) {
137+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
138+
}
139+
140+
if ((ze_params->memory_type == UMF_MEMORY_TYPE_HOST) ==
141+
(bool)ze_params->level_zero_device_handle) {
142+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
143+
}
144+
145+
if ((bool)ze_params->resident_device_count !=
146+
(bool)ze_params->resident_device_handles) {
147+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
148+
}
149+
136150
util_init_once(&ze_is_initialized, init_ze_global_state);
137151
if (Init_ze_global_state_failed) {
138152
LOG_ERR("Loading Level Zero symbols failed");
@@ -149,12 +163,17 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
149163
ze_provider->device = ze_params->level_zero_device_handle;
150164
ze_provider->memory_type = (ze_memory_type_t)ze_params->memory_type;
151165

152-
umf_result_t ret = ze2umf_result(g_ze_ops.zeDeviceGetProperties(
153-
ze_provider->device, &ze_provider->device_properties));
166+
if (ze_provider->device) {
167+
umf_result_t ret = ze2umf_result(g_ze_ops.zeDeviceGetProperties(
168+
ze_provider->device, &ze_provider->device_properties));
154169

155-
if (ret != UMF_RESULT_SUCCESS) {
156-
umf_ba_global_free(ze_provider);
157-
return ret;
170+
if (ret != UMF_RESULT_SUCCESS) {
171+
umf_ba_global_free(ze_provider);
172+
return ret;
173+
}
174+
} else {
175+
memset(&ze_provider->device_properties, 0,
176+
sizeof(ze_provider->device_properties));
158177
}
159178

160179
*provider = ze_provider;
@@ -173,6 +192,18 @@ void ze_memory_provider_finalize(void *provider) {
173192
memcpy(&ze_is_initialized, &is_initialized, sizeof(ze_is_initialized));
174193
}
175194

195+
static bool use_relaxed_allocation(ze_memory_provider_t *ze_provider,
196+
size_t size) {
197+
assert(ze_provider->device);
198+
assert(ze_provider->device_properties.maxMemAllocSize > 0);
199+
return size > ze_provider->device_properties.maxMemAllocSize;
200+
}
201+
202+
static ze_relaxed_allocation_limits_exp_desc_t relaxed_device_allocation_desc =
203+
{.stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC,
204+
.pNext = NULL,
205+
.flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE};
206+
176207
static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
177208
size_t alignment,
178209
void **resultPtr) {
@@ -181,16 +212,6 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
181212

182213
ze_memory_provider_t *ze_provider = (ze_memory_provider_t *)provider;
183214

184-
bool useRelaxedAllocationFlag =
185-
size > ze_provider->device_properties.maxMemAllocSize;
186-
ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = {
187-
.stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC,
188-
.pNext = NULL,
189-
.flags = 0};
190-
if (useRelaxedAllocationFlag) {
191-
relaxed_desc.flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE;
192-
}
193-
194215
ze_result_t ze_result = ZE_RESULT_SUCCESS;
195216
switch (ze_provider->memory_type) {
196217
case UMF_MEMORY_TYPE_HOST: {
@@ -205,7 +226,9 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
205226
case UMF_MEMORY_TYPE_DEVICE: {
206227
ze_device_mem_alloc_desc_t dev_desc = {
207228
.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
208-
.pNext = useRelaxedAllocationFlag ? &relaxed_desc : NULL,
229+
.pNext = use_relaxed_allocation(ze_provider, size)
230+
? &relaxed_device_allocation_desc
231+
: NULL,
209232
.flags = 0,
210233
.ordinal = 0 // TODO
211234
};
@@ -220,8 +243,10 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
220243
.pNext = NULL,
221244
.flags = 0};
222245
ze_device_mem_alloc_desc_t dev_desc = {
223-
.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
224-
.pNext = useRelaxedAllocationFlag ? &relaxed_desc : NULL,
246+
.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
247+
.pNext = use_relaxed_allocation(ze_provider, size)
248+
? &relaxed_device_allocation_desc
249+
: NULL,
225250
.flags = 0,
226251
.ordinal = 0 // TODO
227252
};

test/providers/level_zero_helpers.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -668,8 +668,8 @@ int init_level_zero() {
668668

669669
level_zero_memory_provider_params_t
670670
create_level_zero_prov_params(umf_usm_memory_type_t memory_type) {
671-
level_zero_memory_provider_params_t params = {NULL, NULL,
672-
UMF_MEMORY_TYPE_UNKNOWN};
671+
level_zero_memory_provider_params_t params = {
672+
NULL, NULL, UMF_MEMORY_TYPE_UNKNOWN, NULL, 0};
673673
uint32_t driver_idx = 0;
674674
ze_driver_handle_t hDriver;
675675
ze_device_handle_t hDevice;
@@ -701,7 +701,13 @@ create_level_zero_prov_params(umf_usm_memory_type_t memory_type) {
701701
}
702702

703703
params.level_zero_context_handle = hContext;
704-
params.level_zero_device_handle = hDevice;
704+
705+
if (memory_type == UMF_MEMORY_TYPE_HOST) {
706+
params.level_zero_device_handle = NULL;
707+
} else {
708+
params.level_zero_device_handle = hDevice;
709+
}
710+
705711
params.memory_type = memory_type;
706712

707713
return params;

test/providers/provider_level_zero.cpp

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,89 @@
1919
using umf_test::test;
2020
using namespace umf_test;
2121

22+
struct LevelZeroProviderInit
23+
: public test,
24+
public ::testing::WithParamInterface<umf_usm_memory_type_t> {};
25+
26+
INSTANTIATE_TEST_SUITE_P(, LevelZeroProviderInit,
27+
::testing::Values(UMF_MEMORY_TYPE_HOST,
28+
UMF_MEMORY_TYPE_DEVICE,
29+
UMF_MEMORY_TYPE_SHARED));
30+
31+
TEST_P(LevelZeroProviderInit, FailNullContext) {
32+
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
33+
ASSERT_NE(ops, nullptr);
34+
35+
auto memory_type = GetParam();
36+
37+
level_zero_memory_provider_params_t params = {nullptr, nullptr, memory_type,
38+
nullptr, 0};
39+
40+
umf_memory_provider_handle_t provider = nullptr;
41+
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
42+
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
43+
}
44+
45+
TEST_P(LevelZeroProviderInit, FailNullDevice) {
46+
if (GetParam() == UMF_MEMORY_TYPE_HOST) {
47+
GTEST_SKIP() << "Host memory does not require device handle";
48+
}
49+
50+
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
51+
ASSERT_NE(ops, nullptr);
52+
53+
auto memory_type = GetParam();
54+
auto params = create_level_zero_prov_params(memory_type);
55+
params.level_zero_device_handle = nullptr;
56+
57+
umf_memory_provider_handle_t provider = nullptr;
58+
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
59+
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
60+
}
61+
62+
TEST_F(test, FailNonNullDevice) {
63+
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
64+
ASSERT_NE(ops, nullptr);
65+
66+
auto memory_type = UMF_MEMORY_TYPE_HOST;
67+
68+
// prepare params for device to get non-null device handle
69+
auto params = create_level_zero_prov_params(UMF_MEMORY_TYPE_DEVICE);
70+
params.memory_type = memory_type;
71+
72+
umf_memory_provider_handle_t provider = nullptr;
73+
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
74+
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
75+
}
76+
77+
TEST_F(test, FailMismatchedResidentHandlesCount) {
78+
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
79+
ASSERT_NE(ops, nullptr);
80+
81+
auto memory_type = UMF_MEMORY_TYPE_DEVICE;
82+
83+
auto params = create_level_zero_prov_params(memory_type);
84+
params.resident_device_count = 99;
85+
86+
umf_memory_provider_handle_t provider = nullptr;
87+
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
88+
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
89+
}
90+
91+
TEST_F(test, FailMismatchedResidentHandlesPtr) {
92+
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
93+
ASSERT_NE(ops, nullptr);
94+
95+
auto memory_type = UMF_MEMORY_TYPE_DEVICE;
96+
97+
auto params = create_level_zero_prov_params(memory_type);
98+
params.resident_device_handles = &params.level_zero_device_handle;
99+
100+
umf_memory_provider_handle_t provider = nullptr;
101+
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
102+
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
103+
}
104+
22105
class LevelZeroMemoryAccessor : public MemoryAccessor {
23106
public:
24107
LevelZeroMemoryAccessor(ze_context_handle_t hContext,
@@ -61,7 +144,11 @@ struct umfLevelZeroProviderTest
61144
hDevice = (ze_device_handle_t)params.level_zero_device_handle;
62145
hContext = (ze_context_handle_t)params.level_zero_context_handle;
63146

64-
ASSERT_NE(hDevice, nullptr);
147+
if (params.memory_type == UMF_MEMORY_TYPE_HOST) {
148+
ASSERT_EQ(hDevice, nullptr);
149+
} else {
150+
ASSERT_NE(hDevice, nullptr);
151+
}
65152
ASSERT_NE(hContext, nullptr);
66153

67154
switch (params.memory_type) {

0 commit comments

Comments
 (0)