Skip to content

L0 provider: do not accept device handle for HOST memory type #703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions src/provider/provider_level_zero.c
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
level_zero_memory_provider_params_t *ze_params =
(level_zero_memory_provider_params_t *)params;

if (!ze_params->level_zero_context_handle) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

if ((ze_params->memory_type == UMF_MEMORY_TYPE_HOST) ==
(bool)ze_params->level_zero_device_handle) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

if ((bool)ze_params->resident_device_count !=
(bool)ze_params->resident_device_handles) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

util_init_once(&ze_is_initialized, init_ze_global_state);
if (Init_ze_global_state_failed) {
LOG_ERR("Loading Level Zero symbols failed");
Expand All @@ -149,12 +163,17 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
ze_provider->device = ze_params->level_zero_device_handle;
ze_provider->memory_type = (ze_memory_type_t)ze_params->memory_type;

umf_result_t ret = ze2umf_result(g_ze_ops.zeDeviceGetProperties(
ze_provider->device, &ze_provider->device_properties));
if (ze_provider->device) {
umf_result_t ret = ze2umf_result(g_ze_ops.zeDeviceGetProperties(
ze_provider->device, &ze_provider->device_properties));

if (ret != UMF_RESULT_SUCCESS) {
umf_ba_global_free(ze_provider);
return ret;
if (ret != UMF_RESULT_SUCCESS) {
umf_ba_global_free(ze_provider);
return ret;
}
} else {
memset(&ze_provider->device_properties, 0,
sizeof(ze_provider->device_properties));
}

*provider = ze_provider;
Expand All @@ -173,6 +192,18 @@ void ze_memory_provider_finalize(void *provider) {
memcpy(&ze_is_initialized, &is_initialized, sizeof(ze_is_initialized));
}

static bool use_relaxed_allocation(ze_memory_provider_t *ze_provider,
size_t size) {
assert(ze_provider->device);
assert(ze_provider->device_properties.maxMemAllocSize > 0);
return size > ze_provider->device_properties.maxMemAllocSize;
}

static ze_relaxed_allocation_limits_exp_desc_t relaxed_device_allocation_desc =
{.stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC,
.pNext = NULL,
.flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE};

static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
size_t alignment,
void **resultPtr) {
Expand All @@ -181,16 +212,6 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,

ze_memory_provider_t *ze_provider = (ze_memory_provider_t *)provider;

bool useRelaxedAllocationFlag =
size > ze_provider->device_properties.maxMemAllocSize;
ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = {
.stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC,
.pNext = NULL,
.flags = 0};
if (useRelaxedAllocationFlag) {
relaxed_desc.flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE;
}

ze_result_t ze_result = ZE_RESULT_SUCCESS;
switch (ze_provider->memory_type) {
case UMF_MEMORY_TYPE_HOST: {
Expand All @@ -204,8 +225,10 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
}
case UMF_MEMORY_TYPE_DEVICE: {
ze_device_mem_alloc_desc_t dev_desc = {
.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
.pNext = useRelaxedAllocationFlag ? &relaxed_desc : NULL,
.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
.pNext = use_relaxed_allocation(ze_provider, size)
? &relaxed_device_allocation_desc
: NULL,
.flags = 0,
.ordinal = 0 // TODO
};
Expand All @@ -221,7 +244,9 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
.flags = 0};
ze_device_mem_alloc_desc_t dev_desc = {
.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
.pNext = useRelaxedAllocationFlag ? &relaxed_desc : NULL,
.pNext = use_relaxed_allocation(ze_provider, size)
? &relaxed_device_allocation_desc
: NULL,
.flags = 0,
.ordinal = 0 // TODO
};
Expand Down
12 changes: 9 additions & 3 deletions test/providers/level_zero_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,8 +668,8 @@ int init_level_zero() {

level_zero_memory_provider_params_t
create_level_zero_prov_params(umf_usm_memory_type_t memory_type) {
level_zero_memory_provider_params_t params = {NULL, NULL,
UMF_MEMORY_TYPE_UNKNOWN};
level_zero_memory_provider_params_t params = {
NULL, NULL, UMF_MEMORY_TYPE_UNKNOWN, NULL, 0};
uint32_t driver_idx = 0;
ze_driver_handle_t hDriver;
ze_device_handle_t hDevice;
Expand Down Expand Up @@ -701,7 +701,13 @@ create_level_zero_prov_params(umf_usm_memory_type_t memory_type) {
}

params.level_zero_context_handle = hContext;
params.level_zero_device_handle = hDevice;

if (memory_type == UMF_MEMORY_TYPE_HOST) {
params.level_zero_device_handle = NULL;
} else {
params.level_zero_device_handle = hDevice;
}

params.memory_type = memory_type;

return params;
Expand Down
89 changes: 88 additions & 1 deletion test/providers/provider_level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,89 @@
using umf_test::test;
using namespace umf_test;

struct LevelZeroProviderInit
: public test,
public ::testing::WithParamInterface<umf_usm_memory_type_t> {};

INSTANTIATE_TEST_SUITE_P(, LevelZeroProviderInit,
::testing::Values(UMF_MEMORY_TYPE_HOST,
UMF_MEMORY_TYPE_DEVICE,
UMF_MEMORY_TYPE_SHARED));

TEST_P(LevelZeroProviderInit, FailNullContext) {
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
ASSERT_NE(ops, nullptr);

auto memory_type = GetParam();

level_zero_memory_provider_params_t params = {nullptr, nullptr, memory_type,
nullptr, 0};

umf_memory_provider_handle_t provider = nullptr;
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
}

TEST_P(LevelZeroProviderInit, FailNullDevice) {
if (GetParam() == UMF_MEMORY_TYPE_HOST) {
GTEST_SKIP() << "Host memory does not require device handle";
}

umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
ASSERT_NE(ops, nullptr);

auto memory_type = GetParam();
auto params = create_level_zero_prov_params(memory_type);
params.level_zero_device_handle = nullptr;

umf_memory_provider_handle_t provider = nullptr;
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
}

TEST_F(test, FailNonNullDevice) {
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
ASSERT_NE(ops, nullptr);

auto memory_type = UMF_MEMORY_TYPE_HOST;

// prepare params for device to get non-null device handle
auto params = create_level_zero_prov_params(UMF_MEMORY_TYPE_DEVICE);
params.memory_type = memory_type;

umf_memory_provider_handle_t provider = nullptr;
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
}

TEST_F(test, FailMismatchedResidentHandlesCount) {
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
ASSERT_NE(ops, nullptr);

auto memory_type = UMF_MEMORY_TYPE_DEVICE;

auto params = create_level_zero_prov_params(memory_type);
params.resident_device_count = 99;

umf_memory_provider_handle_t provider = nullptr;
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
}

TEST_F(test, FailMismatchedResidentHandlesPtr) {
umf_memory_provider_ops_t *ops = umfLevelZeroMemoryProviderOps();
ASSERT_NE(ops, nullptr);

auto memory_type = UMF_MEMORY_TYPE_DEVICE;

auto params = create_level_zero_prov_params(memory_type);
params.resident_device_handles = &params.level_zero_device_handle;

umf_memory_provider_handle_t provider = nullptr;
umf_result_t result = umfMemoryProviderCreate(ops, &params, &provider);
ASSERT_EQ(result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
}

class LevelZeroMemoryAccessor : public MemoryAccessor {
public:
LevelZeroMemoryAccessor(ze_context_handle_t hContext,
Expand Down Expand Up @@ -61,7 +144,11 @@ struct umfLevelZeroProviderTest
hDevice = (ze_device_handle_t)params.level_zero_device_handle;
hContext = (ze_context_handle_t)params.level_zero_context_handle;

ASSERT_NE(hDevice, nullptr);
if (params.memory_type == UMF_MEMORY_TYPE_HOST) {
ASSERT_EQ(hDevice, nullptr);
} else {
ASSERT_NE(hDevice, nullptr);
}
ASSERT_NE(hContext, nullptr);

switch (params.memory_type) {
Expand Down
Loading