Skip to content

Extend L0 provider to match UR capabilities #692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions include/umf/providers/provider_level_zero.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
extern "C" {
#endif

typedef struct _ze_context_handle_t *ze_context_handle_t;
typedef struct _ze_device_handle_t *ze_device_handle_t;

/// @brief USM memory allocation type
typedef enum umf_usm_memory_type_t {
UMF_MEMORY_TYPE_UNKNOWN = 0, ///< The memory pointed to is of unknown type
Expand All @@ -24,9 +27,17 @@ typedef enum umf_usm_memory_type_t {

/// @brief Level Zero Memory Provider settings struct
typedef struct level_zero_memory_provider_params_t {
void *level_zero_context_handle; ///< Handle to the Level Zero context
void *level_zero_device_handle; ///< Handle to the Level Zero device
ze_context_handle_t
level_zero_context_handle; ///< Handle to the Level Zero context
ze_device_handle_t
level_zero_device_handle; ///< Handle to the Level Zero device

umf_usm_memory_type_t memory_type; ///< Allocation memory type

ze_device_handle_t *
resident_device_handles; ///< Array of devices for which the memory should be made resident
uint32_t
resident_device_count; ///< Number of devices for which the memory should be made resident
} level_zero_memory_provider_params_t;

umf_memory_provider_ops_t *umfLevelZeroMemoryProviderOps(void);
Expand Down
97 changes: 78 additions & 19 deletions src/provider/provider_level_zero.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ typedef struct ze_memory_provider_t {
ze_context_handle_t context;
ze_device_handle_t device;
ze_memory_type_t memory_type;

ze_device_handle_t *resident_device_handles;
uint32_t resident_device_count;

ze_device_properties_t device_properties;
} ze_memory_provider_t;

typedef struct ze_ops_t {
Expand All @@ -48,11 +53,35 @@ typedef struct ze_ops_t {
ze_ipc_mem_handle_t,
ze_ipc_memory_flags_t, void **);
ze_result_t (*zeMemCloseIpcHandle)(ze_context_handle_t, void *);
ze_result_t (*zeContextMakeMemoryResident)(ze_context_handle_t,
ze_device_handle_t, void *,
size_t);
ze_result_t (*zeDeviceGetProperties)(ze_device_handle_t,
ze_device_properties_t *);
} ze_ops_t;

static ze_ops_t g_ze_ops;
static UTIL_ONCE_FLAG ze_is_initialized = UTIL_ONCE_FLAG_INIT;
static bool Init_ze_global_state_failed;
static __TLS ze_result_t TLS_last_native_error;

static void store_last_native_error(int32_t native_error) {
TLS_last_native_error = native_error;
}

umf_result_t ze2umf_result(ze_result_t result) {
switch (result) {
case ZE_RESULT_SUCCESS:
return UMF_RESULT_SUCCESS;
case ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY:
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
case ZE_RESULT_ERROR_INVALID_ARGUMENT:
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
default:
store_last_native_error(result);
return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
}
}

static void init_ze_global_state(void) {
#ifdef _WIN32
Expand All @@ -78,11 +107,17 @@ static void init_ze_global_state(void) {
util_get_symbol_addr(0, "zeMemOpenIpcHandle", lib_name);
*(void **)&g_ze_ops.zeMemCloseIpcHandle =
util_get_symbol_addr(0, "zeMemCloseIpcHandle", lib_name);
*(void **)&g_ze_ops.zeContextMakeMemoryResident =
util_get_symbol_addr(0, "zeContextMakeMemoryResident", lib_name);
*(void **)&g_ze_ops.zeDeviceGetProperties =
util_get_symbol_addr(0, "zeDeviceGetProperties", lib_name);

if (!g_ze_ops.zeMemAllocHost || !g_ze_ops.zeMemAllocDevice ||
!g_ze_ops.zeMemAllocShared || !g_ze_ops.zeMemFree ||
!g_ze_ops.zeMemGetIpcHandle || !g_ze_ops.zeMemOpenIpcHandle ||
!g_ze_ops.zeMemCloseIpcHandle) {
!g_ze_ops.zeMemCloseIpcHandle ||
!g_ze_ops.zeContextMakeMemoryResident ||
!g_ze_ops.zeDeviceGetProperties) {
// g_ze_ops.zeMemPutIpcHandle can be NULL because it was introduced
// starting from Level Zero 1.6
LOG_ERR("Required Level Zero symbols not found.");
Expand Down Expand Up @@ -114,6 +149,14 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
ze_provider->device = ze_params->level_zero_device_handle;
ze_provider->memory_type = (ze_memory_type_t)ze_params->memory_type;

umf_result_t ret = ze2umf_result(g_ze_ops.zeDeviceGetProperties(
ze_provider->device, &ze_provider->device_properties));

if (ret != UMF_RESULT_SUCCESS) {
umf_ba_global_free(ze_provider);
return ret;
}

*provider = ze_provider;

return UMF_RESULT_SUCCESS;
Expand All @@ -138,6 +181,16 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,

ze_memory_provider_t *ze_provider = (ze_memory_provider_t *)provider;

bool useRelaxedAllocationFlag =
size > ze_provider->device_properties.maxMemAllocSize;
ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = {
.stype = ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC,
.pNext = NULL,
.flags = 0};
if (useRelaxedAllocationFlag) {
relaxed_desc.flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE;
}

ze_result_t ze_result = ZE_RESULT_SUCCESS;
switch (ze_provider->memory_type) {
case UMF_MEMORY_TYPE_HOST: {
Expand All @@ -152,7 +205,7 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
case UMF_MEMORY_TYPE_DEVICE: {
ze_device_mem_alloc_desc_t dev_desc = {
.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC,
.pNext = NULL,
.pNext = useRelaxedAllocationFlag ? &relaxed_desc : NULL,
.flags = 0,
.ordinal = 0 // TODO
};
Expand All @@ -168,7 +221,7 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
.flags = 0};
ze_device_mem_alloc_desc_t dev_desc = {
.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
.pNext = NULL,
.pNext = useRelaxedAllocationFlag ? &relaxed_desc : NULL,
.flags = 0,
.ordinal = 0 // TODO
};
Expand All @@ -178,13 +231,23 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
break;
}
default:
return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

// TODO add error reporting
return (ze_result == ZE_RESULT_SUCCESS)
? UMF_RESULT_SUCCESS
: UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
if (ze_result != ZE_RESULT_SUCCESS) {
return ze2umf_result(ze_result);
}

for (uint32_t i = 0; i < ze_provider->resident_device_count; i++) {
ze_result = g_ze_ops.zeContextMakeMemoryResident(
ze_provider->context, ze_provider->resident_device_handles[i],
*resultPtr, size);
if (ze_result != ZE_RESULT_SUCCESS) {
return ze2umf_result(ze_result);
}
}

return ze2umf_result(ze_result);
}

static umf_result_t ze_memory_provider_free(void *provider, void *ptr,
Expand All @@ -194,11 +257,7 @@ static umf_result_t ze_memory_provider_free(void *provider, void *ptr,
assert(provider);
ze_memory_provider_t *ze_provider = (ze_memory_provider_t *)provider;
ze_result_t ze_result = g_ze_ops.zeMemFree(ze_provider->context, ptr);

// TODO add error reporting
return (ze_result == ZE_RESULT_SUCCESS)
? UMF_RESULT_SUCCESS
: UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
return ze2umf_result(ze_result);
}

void ze_memory_provider_get_last_native_error(void *provider,
Expand All @@ -207,9 +266,9 @@ void ze_memory_provider_get_last_native_error(void *provider,
(void)provider;
(void)ppMessage;

// TODO
assert(pError);
*pError = 0;

*pError = TLS_last_native_error;
}

static umf_result_t ze_memory_provider_get_min_page_size(void *provider,
Expand Down Expand Up @@ -314,7 +373,7 @@ static umf_result_t ze_memory_provider_get_ipc_handle(void *provider,
&ze_ipc_data->ze_handle);
if (ze_result != ZE_RESULT_SUCCESS) {
LOG_ERR("zeMemGetIpcHandle() failed.");
return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
return ze2umf_result(ze_result);
}

ze_ipc_data->pid = utils_getpid();
Expand Down Expand Up @@ -342,7 +401,7 @@ static umf_result_t ze_memory_provider_put_ipc_handle(void *provider,
ze_ipc_data->ze_handle);
if (ze_result != ZE_RESULT_SUCCESS) {
LOG_ERR("zeMemPutIpcHandle() failed.");
return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
return ze2umf_result(ze_result);
}
return UMF_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -379,7 +438,7 @@ static umf_result_t ze_memory_provider_open_ipc_handle(void *provider,
}
if (ze_result != ZE_RESULT_SUCCESS) {
LOG_ERR("zeMemOpenIpcHandle() failed.");
return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
return ze2umf_result(ze_result);
}

return UMF_RESULT_SUCCESS;
Expand All @@ -397,7 +456,7 @@ ze_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
ze_result = g_ze_ops.zeMemCloseIpcHandle(ze_provider->context, ptr);
if (ze_result != ZE_RESULT_SUCCESS) {
LOG_ERR("zeMemCloseIpcHandle() failed.");
return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
return ze2umf_result(ze_result);
}

return UMF_RESULT_SUCCESS;
Expand Down
Loading