Skip to content

Commit 355ffa0

Browse files
committed
L0 provider: do not accept device handle for HOST memory type
and do not call zeDeviceGetProperties for HOST provider. Also, add extra checks for parameters.
1 parent 89b660c commit 355ffa0

File tree

3 files changed

+145
-22
lines changed

3 files changed

+145
-22
lines changed

src/provider/provider_level_zero.c

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,25 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
133133
level_zero_memory_provider_params_t *ze_params =
134134
(level_zero_memory_provider_params_t *)params;
135135

136+
if (!ze_params->level_zero_context_handle) {
137+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
138+
}
139+
140+
if (ze_params->memory_type == UMF_MEMORY_TYPE_HOST &&
141+
ze_params->level_zero_device_handle) {
142+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
143+
}
144+
145+
if (ze_params->memory_type != UMF_MEMORY_TYPE_HOST &&
146+
!ze_params->level_zero_device_handle) {
147+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
148+
}
149+
150+
if ((bool)ze_params->resident_device_count !=
151+
(bool)ze_params->resident_device_handles) {
152+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
153+
}
154+
136155
util_init_once(&ze_is_initialized, init_ze_global_state);
137156
if (Init_ze_global_state_failed) {
138157
LOG_ERR("Loading Level Zero symbols failed");
@@ -149,12 +168,17 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
149168
ze_provider->device = ze_params->level_zero_device_handle;
150169
ze_provider->memory_type = (ze_memory_type_t)ze_params->memory_type;
151170

152-
umf_result_t ret = ze2umf_result(g_ze_ops.zeDeviceGetProperties(
153-
ze_provider->device, &ze_provider->device_properties));
171+
if (ze_provider->device) {
172+
umf_result_t ret = ze2umf_result(g_ze_ops.zeDeviceGetProperties(
173+
ze_provider->device, &ze_provider->device_properties));
154174

155-
if (ret != UMF_RESULT_SUCCESS) {
156-
umf_ba_global_free(ze_provider);
157-
return ret;
175+
if (ret != UMF_RESULT_SUCCESS) {
176+
umf_ba_global_free(ze_provider);
177+
return ret;
178+
}
179+
} else {
180+
memset(&ze_provider->device_properties, 0,
181+
sizeof(ze_provider->device_properties));
158182
}
159183

160184
*provider = ze_provider;
@@ -173,6 +197,18 @@ void ze_memory_provider_finalize(void *provider) {
173197
memcpy(&ze_is_initialized, &is_initialized, sizeof(ze_is_initialized));
174198
}
175199

200+
static bool use_relaxed_allocation(ze_memory_provider_t *ze_provider,
201+
size_t size) {
202+
assert(ze_provider->device);
203+
assert(ze_provider->device_properties.maxMemAllocSize > 0);
204+
return size > ze_provider->device_properties.maxMemAllocSize;
205+
}
206+
207+
static ze_relaxed_allocation_limits_exp_desc_t relaxed_device_allocation_desc =
208+
{.stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC,
209+
.pNext = NULL,
210+
.flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE};
211+
176212
static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
177213
size_t alignment,
178214
void **resultPtr) {
@@ -181,16 +217,6 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
181217

182218
ze_memory_provider_t *ze_provider = (ze_memory_provider_t *)provider;
183219

184-
bool useRelaxedAllocationFlag =
185-
size > ze_provider->device_properties.maxMemAllocSize;
186-
ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = {
187-
.stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC,
188-
.pNext = NULL,
189-
.flags = 0};
190-
if (useRelaxedAllocationFlag) {
191-
relaxed_desc.flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE;
192-
}
193-
194220
ze_result_t ze_result = ZE_RESULT_SUCCESS;
195221
switch (ze_provider->memory_type) {
196222
case UMF_MEMORY_TYPE_HOST: {
@@ -205,7 +231,9 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
205231
case UMF_MEMORY_TYPE_DEVICE: {
206232
ze_device_mem_alloc_desc_t dev_desc = {
207233
.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
208-
.pNext = useRelaxedAllocationFlag ? &relaxed_desc : NULL,
234+
.pNext = use_relaxed_allocation(ze_provider, size)
235+
? &relaxed_device_allocation_desc
236+
: NULL,
209237
.flags = 0,
210238
.ordinal = 0 // TODO
211239
};
@@ -220,8 +248,10 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
220248
.pNext = NULL,
221249
.flags = 0};
222250
ze_device_mem_alloc_desc_t dev_desc = {
223-
.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
224-
.pNext = useRelaxedAllocationFlag ? &relaxed_desc : NULL,
251+
.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
252+
.pNext = use_relaxed_allocation(ze_provider, size)
253+
? &relaxed_device_allocation_desc
254+
: NULL,
225255
.flags = 0,
226256
.ordinal = 0 // TODO
227257
};

test/providers/level_zero_helpers.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -668,8 +668,8 @@ int init_level_zero() {
668668

669669
level_zero_memory_provider_params_t
670670
create_level_zero_prov_params(umf_usm_memory_type_t memory_type) {
671-
level_zero_memory_provider_params_t params = {NULL, NULL,
672-
UMF_MEMORY_TYPE_UNKNOWN};
671+
level_zero_memory_provider_params_t params = {
672+
NULL, NULL, UMF_MEMORY_TYPE_UNKNOWN, NULL, 0};
673673
uint32_t driver_idx = 0;
674674
ze_driver_handle_t hDriver;
675675
ze_device_handle_t hDevice;
@@ -701,7 +701,13 @@ create_level_zero_prov_params(umf_usm_memory_type_t memory_type) {
701701
}
702702

703703
params.level_zero_context_handle = hContext;
704-
params.level_zero_device_handle = hDevice;
704+
705+
if (memory_type == UMF_MEMORY_TYPE_HOST) {
706+
params.level_zero_device_handle = NULL;
707+
} else {
708+
params.level_zero_device_handle = hDevice;
709+
}
710+
705711
params.memory_type = memory_type;
706712

707713
return params;

test/providers/provider_level_zero.cpp

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,89 @@
1919
using umf_test::test;
2020
using namespace umf_test;
2121

22+
struct LevelZeroProviderInit
23+
: public test,
24+
public ::testing::WithParamInterface<umf_usm_memory_type_t> {};
25+
26+
INSTANTIATE_TEST_SUITE_P(, LevelZeroProviderInit,
27+
::testing::Values(UMF_MEMORY_TYPE_HOST,
28+
UMF_MEMORY_TYPE_DEVICE,
29+
UMF_MEMORY_TYPE_SHARED));
30+
31+
TEST_P(LevelZeroProviderInit, FailNullContext) {
32+
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
33+
ASSERT_NE(ops, nullptr);
34+
35+
auto memory_type = GetParam();
36+
37+
level_zero_memory_provider_params_t params = {nullptr, nullptr, memory_type,
38+
nullptr, 0};
39+
40+
umf_memory_provider_handle_t provider = nullptr;
41+
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
42+
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
43+
}
44+
45+
TEST_P(LevelZeroProviderInit, FailNullDevice) {
46+
if (GetParam() == UMF_MEMORY_TYPE_HOST) {
47+
GTEST_SKIP() << "Host memory does not require device handle";
48+
}
49+
50+
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
51+
ASSERT_NE(ops, nullptr);
52+
53+
auto memory_type = GetParam();
54+
auto params = create_level_zero_prov_params(memory_type);
55+
params.level_zero_device_handle = nullptr;
56+
57+
umf_memory_provider_handle_t provider = nullptr;
58+
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
59+
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
60+
}
61+
62+
TEST_F(test, FailNonNullDevice) {
63+
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
64+
ASSERT_NE(ops, nullptr);
65+
66+
auto memory_type = UMF_MEMORY_TYPE_HOST;
67+
68+
// prepare params for device to get non-null device handle
69+
auto params = create_level_zero_prov_params(UMF_MEMORY_TYPE_DEVICE);
70+
params.memory_type = memory_type;
71+
72+
umf_memory_provider_handle_t provider = nullptr;
73+
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
74+
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
75+
}
76+
77+
TEST_F(test, FailMismatchedResidentHandlesCount) {
78+
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
79+
ASSERT_NE(ops, nullptr);
80+
81+
auto memory_type = UMF_MEMORY_TYPE_DEVICE;
82+
83+
auto params = create_level_zero_prov_params(memory_type);
84+
params.resident_device_count = 99;
85+
86+
umf_memory_provider_handle_t provider = nullptr;
87+
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
88+
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
89+
}
90+
91+
TEST_F(test, FailMismatchedResidentHandlesPtr) {
92+
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
93+
ASSERT_NE(ops, nullptr);
94+
95+
auto memory_type = UMF_MEMORY_TYPE_DEVICE;
96+
97+
auto params = create_level_zero_prov_params(memory_type);
98+
params.resident_device_handles = &params.level_zero_device_handle;
99+
100+
umf_memory_provider_handle_t provider = nullptr;
101+
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
102+
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
103+
}
104+
22105
class LevelZeroMemoryAccessor : public MemoryAccessor {
23106
public:
24107
LevelZeroMemoryAccessor(ze_context_handle_t hContext,
@@ -61,7 +144,11 @@ struct umfLevelZeroProviderTest
61144
hDevice = (ze_device_handle_t)params.level_zero_device_handle;
62145
hContext = (ze_context_handle_t)params.level_zero_context_handle;
63146

64-
ASSERT_NE(hDevice, nullptr);
147+
if (params.memory_type == UMF_MEMORY_TYPE_HOST) {
148+
ASSERT_EQ(hDevice, nullptr);
149+
} else {
150+
ASSERT_NE(hDevice, nullptr);
151+
}
65152
ASSERT_NE(hContext, nullptr);
66153

67154
switch (params.memory_type) {

0 commit comments

Comments
 (0)