Skip to content

Init TBB symbols only once #1064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/libumf.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "base_alloc_global.h"
#include "ipc_cache.h"
#include "memspace_internal.h"
#include "pool/pool_scalable_internal.h"
#include "provider_cuda_internal.h"
#include "provider_level_zero_internal.h"
#include "provider_tracking.h"
Expand Down Expand Up @@ -83,6 +84,7 @@ void umfTearDown(void) {
fini_umfTearDown:
fini_ze_global_state();
fini_cu_global_state();
fini_tbb_global_state();
LOG_DEBUG("UMF library finalized");
}
}
Expand Down
118 changes: 71 additions & 47 deletions src/pool/pool_scalable.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "base_alloc_global.h"
#include "libumf.h"
#include "pool_scalable_internal.h"
#include "utils_common.h"
#include "utils_concurrency.h"
#include "utils_load_library.h"
Expand All @@ -33,6 +34,7 @@ static __TLS umf_result_t TLS_last_allocation_error;
static __TLS umf_result_t TLS_last_free_error;

static const size_t DEFAULT_GRANULARITY = 2 * 1024 * 1024; // 2MB

typedef struct tbb_mem_pool_policy_t {
raw_alloc_tbb_type pAlloc;
raw_free_tbb_type pFree;
Expand Down Expand Up @@ -66,7 +68,6 @@ typedef struct tbb_callbacks_t {
typedef struct tbb_memory_pool_t {
umf_memory_provider_handle_t mem_provider;
void *tbb_pool;
tbb_callbacks_t tbb_callbacks;
} tbb_memory_pool_t;

typedef enum tbb_enums_t {
Expand All @@ -82,6 +83,10 @@ typedef enum tbb_enums_t {
TBB_POOL_SYMBOLS_MAX // it has to be the last one
} tbb_enums_t;

static UTIL_ONCE_FLAG tbb_initialized = UTIL_ONCE_FLAG_INIT;
static int tbb_init_result = 0;
static tbb_callbacks_t tbb_callbacks = {0};

static const char *tbb_symbol[TBB_POOL_SYMBOLS_MAX] = {
#ifdef _WIN32
// symbols copied from oneTBB/src/tbbmalloc/def/win64-tbbmalloc.def
Expand Down Expand Up @@ -109,46 +114,60 @@ static const char *tbb_symbol[TBB_POOL_SYMBOLS_MAX] = {
#endif
};

static int init_tbb_callbacks(tbb_callbacks_t *tbb_callbacks) {
assert(tbb_callbacks);

static void init_tbb_callbacks_once(void) {
const char *lib_name = tbb_symbol[TBB_LIB_NAME];
tbb_callbacks->lib_handle = utils_open_library(lib_name, 0);
if (!tbb_callbacks->lib_handle) {
tbb_callbacks.lib_handle = utils_open_library(lib_name, 0);
if (!tbb_callbacks.lib_handle) {
LOG_ERR("%s required by Scalable Pool not found - install TBB malloc "
"or make sure it is in the default search paths.",
lib_name);
return -1;
tbb_init_result = -1;
return;
}

*(void **)&tbb_callbacks->pool_malloc = utils_get_symbol_addr(
tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_MALLOC], lib_name);
*(void **)&tbb_callbacks->pool_realloc = utils_get_symbol_addr(
tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_REALLOC], lib_name);
*(void **)&tbb_callbacks->pool_aligned_malloc =
utils_get_symbol_addr(tbb_callbacks->lib_handle,
*(void **)&tbb_callbacks.pool_malloc = utils_get_symbol_addr(
tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_MALLOC], lib_name);
*(void **)&tbb_callbacks.pool_realloc = utils_get_symbol_addr(
tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_REALLOC], lib_name);
*(void **)&tbb_callbacks.pool_aligned_malloc =
utils_get_symbol_addr(tbb_callbacks.lib_handle,
tbb_symbol[TBB_POOL_ALIGNED_MALLOC], lib_name);
*(void **)&tbb_callbacks->pool_free = utils_get_symbol_addr(
tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_FREE], lib_name);
*(void **)&tbb_callbacks->pool_create_v1 = utils_get_symbol_addr(
tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_CREATE_V1], lib_name);
*(void **)&tbb_callbacks->pool_destroy = utils_get_symbol_addr(
tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_DESTROY], lib_name);
*(void **)&tbb_callbacks->pool_identify = utils_get_symbol_addr(
tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_IDENTIFY], lib_name);
*(void **)&tbb_callbacks->pool_msize = utils_get_symbol_addr(
tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_MSIZE], lib_name);

if (!tbb_callbacks->pool_malloc || !tbb_callbacks->pool_realloc ||
!tbb_callbacks->pool_aligned_malloc || !tbb_callbacks->pool_free ||
!tbb_callbacks->pool_create_v1 || !tbb_callbacks->pool_destroy ||
!tbb_callbacks->pool_identify) {
*(void **)&tbb_callbacks.pool_free = utils_get_symbol_addr(
tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_FREE], lib_name);
*(void **)&tbb_callbacks.pool_create_v1 = utils_get_symbol_addr(
tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_CREATE_V1], lib_name);
*(void **)&tbb_callbacks.pool_destroy = utils_get_symbol_addr(
tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_DESTROY], lib_name);
*(void **)&tbb_callbacks.pool_identify = utils_get_symbol_addr(
tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_IDENTIFY], lib_name);
*(void **)&tbb_callbacks.pool_msize = utils_get_symbol_addr(
tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_MSIZE], lib_name);

if (!tbb_callbacks.pool_malloc || !tbb_callbacks.pool_realloc ||
!tbb_callbacks.pool_aligned_malloc || !tbb_callbacks.pool_free ||
!tbb_callbacks.pool_create_v1 || !tbb_callbacks.pool_destroy ||
!tbb_callbacks.pool_identify) {
LOG_FATAL("Could not find all TBB symbols in %s", lib_name);
utils_close_library(tbb_callbacks->lib_handle);
return -1;
if (utils_close_library(tbb_callbacks.lib_handle)) {
LOG_ERR("Could not close %s library", lib_name);
}
tbb_init_result = -1;
}
}

return 0;
static int init_tbb_callbacks(void) {
utils_init_once(&tbb_initialized, init_tbb_callbacks_once);
return tbb_init_result;
}

void fini_tbb_global_state(void) {
if (tbb_callbacks.lib_handle) {
if (!utils_close_library(tbb_callbacks.lib_handle)) {
tbb_callbacks.lib_handle = NULL;
LOG_DEBUG("TBB library closed");
} else {
LOG_ERR("TBB library cannot be unloaded");
}
}
}

static void *tbb_raw_alloc_wrapper(intptr_t pool_id, size_t *raw_bytes) {
Expand Down Expand Up @@ -264,35 +283,41 @@ static umf_result_t tbb_pool_initialize(umf_memory_provider_handle_t provider,
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

int ret = init_tbb_callbacks(&pool_data->tbb_callbacks);
umf_result_t res = UMF_RESULT_SUCCESS;
int ret = init_tbb_callbacks();
if (ret != 0) {
LOG_FATAL("loading TBB symbols failed");
return UMF_RESULT_ERROR_UNKNOWN;
res = UMF_RESULT_ERROR_UNKNOWN;
goto err_tbb_init;
}

pool_data->mem_provider = provider;
ret = pool_data->tbb_callbacks.pool_create_v1((intptr_t)pool_data, &policy,
&(pool_data->tbb_pool));
ret = tbb_callbacks.pool_create_v1((intptr_t)pool_data, &policy,
&(pool_data->tbb_pool));
if (ret != 0 /* TBBMALLOC_OK */) {
return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
res = UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
goto err_tbb_init;
}

*pool = (void *)pool_data;

return UMF_RESULT_SUCCESS;
return res;

err_tbb_init:
umf_ba_global_free(pool_data);
return res;
}

static void tbb_pool_finalize(void *pool) {
tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
pool_data->tbb_callbacks.pool_destroy(pool_data->tbb_pool);
utils_close_library(pool_data->tbb_callbacks.lib_handle);
tbb_callbacks.pool_destroy(pool_data->tbb_pool);
umf_ba_global_free(pool_data);
}

static void *tbb_malloc(void *pool, size_t size) {
tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
TLS_last_allocation_error = UMF_RESULT_SUCCESS;
void *ptr = pool_data->tbb_callbacks.pool_malloc(pool_data->tbb_pool, size);
void *ptr = tbb_callbacks.pool_malloc(pool_data->tbb_pool, size);
if (ptr == NULL) {
if (TLS_last_allocation_error == UMF_RESULT_SUCCESS) {
TLS_last_allocation_error = UMF_RESULT_ERROR_UNKNOWN;
Expand All @@ -319,8 +344,7 @@ static void *tbb_calloc(void *pool, size_t num, size_t size) {
static void *tbb_realloc(void *pool, void *ptr, size_t size) {
tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
TLS_last_allocation_error = UMF_RESULT_SUCCESS;
void *new_ptr =
pool_data->tbb_callbacks.pool_realloc(pool_data->tbb_pool, ptr, size);
void *new_ptr = tbb_callbacks.pool_realloc(pool_data->tbb_pool, ptr, size);
if (new_ptr == NULL) {
if (TLS_last_allocation_error == UMF_RESULT_SUCCESS) {
TLS_last_allocation_error = UMF_RESULT_ERROR_UNKNOWN;
Expand All @@ -334,8 +358,8 @@ static void *tbb_realloc(void *pool, void *ptr, size_t size) {
static void *tbb_aligned_malloc(void *pool, size_t size, size_t alignment) {
tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
TLS_last_allocation_error = UMF_RESULT_SUCCESS;
void *ptr = pool_data->tbb_callbacks.pool_aligned_malloc(
pool_data->tbb_pool, size, alignment);
void *ptr =
tbb_callbacks.pool_aligned_malloc(pool_data->tbb_pool, size, alignment);
if (ptr == NULL) {
if (TLS_last_allocation_error == UMF_RESULT_SUCCESS) {
TLS_last_allocation_error = UMF_RESULT_ERROR_UNKNOWN;
Expand All @@ -360,7 +384,7 @@ static umf_result_t tbb_free(void *pool, void *ptr) {
utils_annotate_release(pool);

tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
if (pool_data->tbb_callbacks.pool_free(pool_data->tbb_pool, ptr)) {
if (tbb_callbacks.pool_free(pool_data->tbb_pool, ptr)) {
return UMF_RESULT_SUCCESS;
}

Expand All @@ -373,7 +397,7 @@ static umf_result_t tbb_free(void *pool, void *ptr) {

static size_t tbb_malloc_usable_size(void *pool, void *ptr) {
tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
return pool_data->tbb_callbacks.pool_msize(pool_data->tbb_pool, ptr);
return tbb_callbacks.pool_msize(pool_data->tbb_pool, ptr);
}

static umf_result_t tbb_get_last_allocation_error(void *pool) {
Expand Down
10 changes: 10 additions & 0 deletions src/pool/pool_scalable_internal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
*
* Copyright (C) 2025 Intel Corporation
*
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*
*/

void fini_tbb_global_state(void);
Loading