Skip to content

Commit 7df503e

Browse files
committed
Add NULL checks for pool and provider handles
1 parent 9c7c3a5 commit 7df503e

File tree

5 files changed

+105
-0
lines changed

5 files changed

+105
-0
lines changed

src/memory_pool.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
*/
99

1010
#include "memory_pool_internal.h"
11+
#include "utils_common.h"
1112

1213
#include <umf/memory_pool.h>
1314
#include <umf/memory_pool_ops.h>
@@ -47,30 +48,37 @@ umf_result_t umfPoolCreateEx(const umf_memory_pool_ops_t *pool_ops,
4748
}
4849

4950
void *umfPoolMalloc(umf_memory_pool_handle_t hPool, size_t size) {
51+
UMF_CHECK((hPool != NULL), NULL);
5052
return hPool->ops.malloc(hPool->pool_priv, size);
5153
}
5254

5355
void *umfPoolAlignedMalloc(umf_memory_pool_handle_t hPool, size_t size,
5456
size_t alignment) {
57+
UMF_CHECK((hPool != NULL), NULL);
5558
return hPool->ops.aligned_malloc(hPool->pool_priv, size, alignment);
5659
}
5760

5861
void *umfPoolCalloc(umf_memory_pool_handle_t hPool, size_t num, size_t size) {
62+
UMF_CHECK((hPool != NULL), NULL);
5963
return hPool->ops.calloc(hPool->pool_priv, num, size);
6064
}
6165

6266
void *umfPoolRealloc(umf_memory_pool_handle_t hPool, void *ptr, size_t size) {
67+
UMF_CHECK((hPool != NULL), NULL);
6368
return hPool->ops.realloc(hPool->pool_priv, ptr, size);
6469
}
6570

6671
size_t umfPoolMallocUsableSize(umf_memory_pool_handle_t hPool, void *ptr) {
72+
UMF_CHECK((hPool != NULL), 0);
6773
return hPool->ops.malloc_usable_size(hPool->pool_priv, ptr);
6874
}
6975

7076
umf_result_t umfPoolFree(umf_memory_pool_handle_t hPool, void *ptr) {
77+
UMF_CHECK((hPool != NULL), UMF_RESULT_ERROR_INVALID_ARGUMENT);
7178
return hPool->ops.free(hPool->pool_priv, ptr);
7279
}
7380

7481
umf_result_t umfPoolGetLastAllocationError(umf_memory_pool_handle_t hPool) {
82+
UMF_CHECK((hPool != NULL), UMF_RESULT_ERROR_INVALID_ARGUMENT);
7583
return hPool->ops.get_last_allocation_error(hPool->pool_priv);
7684
}

src/memory_provider.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
*/
99

1010
#include "memory_provider_internal.h"
11+
#include "utils_common.h"
12+
1113
#include <umf/memory_provider.h>
1214

1315
#include <assert.h>
@@ -64,6 +66,7 @@ checkErrorAndSetLastProvider(umf_result_t result,
6466

6567
umf_result_t umfMemoryProviderAlloc(umf_memory_provider_handle_t hProvider,
6668
size_t size, size_t alignment, void **ptr) {
69+
UMF_CHECK((hProvider != NULL), UMF_RESULT_ERROR_INVALID_ARGUMENT);
6770
umf_result_t res =
6871
hProvider->ops.alloc(hProvider->provider_priv, size, alignment, ptr);
6972
checkErrorAndSetLastProvider(res, hProvider);
@@ -72,6 +75,7 @@ umf_result_t umfMemoryProviderAlloc(umf_memory_provider_handle_t hProvider,
7275

7376
umf_result_t umfMemoryProviderFree(umf_memory_provider_handle_t hProvider,
7477
void *ptr, size_t size) {
78+
UMF_CHECK((hProvider != NULL), UMF_RESULT_ERROR_INVALID_ARGUMENT);
7579
umf_result_t res = hProvider->ops.free(hProvider->provider_priv, ptr, size);
7680
checkErrorAndSetLastProvider(res, hProvider);
7781
return res;
@@ -80,17 +84,20 @@ umf_result_t umfMemoryProviderFree(umf_memory_provider_handle_t hProvider,
8084
void umfMemoryProviderGetLastNativeError(umf_memory_provider_handle_t hProvider,
8185
const char **ppMessage,
8286
int32_t *pError) {
87+
ASSERT(hProvider != NULL);
8388
hProvider->ops.get_last_native_error(hProvider->provider_priv, ppMessage,
8489
pError);
8590
}
8691

8792
void *umfMemoryProviderGetPriv(umf_memory_provider_handle_t hProvider) {
93+
UMF_CHECK((hProvider != NULL), NULL);
8894
return hProvider->provider_priv;
8995
}
9096

9197
umf_result_t
9298
umfMemoryProviderGetRecommendedPageSize(umf_memory_provider_handle_t hProvider,
9399
size_t size, size_t *pageSize) {
100+
UMF_CHECK((hProvider != NULL), UMF_RESULT_ERROR_INVALID_ARGUMENT);
94101
umf_result_t res = hProvider->ops.get_recommended_page_size(
95102
hProvider->provider_priv, size, pageSize);
96103
checkErrorAndSetLastProvider(res, hProvider);
@@ -100,6 +107,7 @@ umfMemoryProviderGetRecommendedPageSize(umf_memory_provider_handle_t hProvider,
100107
umf_result_t
101108
umfMemoryProviderGetMinPageSize(umf_memory_provider_handle_t hProvider,
102109
void *ptr, size_t *pageSize) {
110+
UMF_CHECK((hProvider != NULL), UMF_RESULT_ERROR_INVALID_ARGUMENT);
103111
umf_result_t res = hProvider->ops.get_min_page_size(
104112
hProvider->provider_priv, ptr, pageSize);
105113
checkErrorAndSetLastProvider(res, hProvider);
@@ -108,6 +116,7 @@ umfMemoryProviderGetMinPageSize(umf_memory_provider_handle_t hProvider,
108116

109117
umf_result_t umfMemoryProviderPurgeLazy(umf_memory_provider_handle_t hProvider,
110118
void *ptr, size_t size) {
119+
UMF_CHECK((hProvider != NULL), UMF_RESULT_ERROR_INVALID_ARGUMENT);
111120
umf_result_t res =
112121
hProvider->ops.purge_lazy(hProvider->provider_priv, ptr, size);
113122
checkErrorAndSetLastProvider(res, hProvider);
@@ -116,13 +125,15 @@ umf_result_t umfMemoryProviderPurgeLazy(umf_memory_provider_handle_t hProvider,
116125

117126
umf_result_t umfMemoryProviderPurgeForce(umf_memory_provider_handle_t hProvider,
118127
void *ptr, size_t size) {
128+
UMF_CHECK((hProvider != NULL), UMF_RESULT_ERROR_INVALID_ARGUMENT);
119129
umf_result_t res =
120130
hProvider->ops.purge_force(hProvider->provider_priv, ptr, size);
121131
checkErrorAndSetLastProvider(res, hProvider);
122132
return res;
123133
}
124134

125135
const char *umfMemoryProviderGetName(umf_memory_provider_handle_t hProvider) {
136+
UMF_CHECK((hProvider != NULL), NULL);
126137
return hProvider->ops.get_name(hProvider->provider_priv);
127138
}
128139

test/common/base.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,22 @@ struct test : ::testing::Test {
3131
void SetUp() override { ::testing::Test::SetUp(); }
3232
void TearDown() override { ::testing::Test::TearDown(); }
3333
};
34+
35+
template <typename T> T generateArg() { return T{}; }
36+
37+
// returns Ret (*f)(void) that calls the original function
38+
// with all arguments created by calling generateArg()
39+
template <typename Ret, typename... Args>
40+
std::function<Ret(void)> withGeneratedArgs(Ret (*f)(Args...)) {
41+
std::tuple<Args...> tuple = {};
42+
auto args = std::apply(
43+
[](auto... x) {
44+
return std::make_tuple(generateArg<decltype(x)>()...);
45+
},
46+
tuple);
47+
return [=]() { return std::apply(f, args); };
48+
}
49+
3450
} // namespace umf_test
3551

3652
#endif /* UMF_TEST_BASE_HPP */

test/memoryPoolAPI.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44
// This file contains tests for UMF pool API
55

6+
#include "base.hpp"
67
#include "pool.hpp"
78
#include "poolFixtures.hpp"
89
#include "provider.hpp"
@@ -14,7 +15,9 @@
1415
#include <array>
1516
#include <string>
1617
#include <thread>
18+
#include <type_traits>
1719
#include <unordered_map>
20+
#include <variant>
1821

1922
using umf_test::test;
2023
using namespace umf_test;
@@ -282,3 +285,37 @@ TEST_F(test, getLastFailedMemoryProvider) {
282285
umfMemoryProviderGetName(umfGetLastFailedMemoryProvider())),
283286
"provider2");
284287
}
288+
289+
struct poolHandleCheck
290+
: umf_test::test,
291+
::testing::WithParamInterface<
292+
std::function<std::variant<void *, umf_result_t, size_t>(void)>> {};
293+
294+
TEST_P(poolHandleCheck, poolHandleCheckAll) {
295+
auto f = GetParam();
296+
auto ret = f();
297+
298+
std::visit(
299+
[&](auto arg) {
300+
using T = decltype(arg);
301+
if constexpr (std::is_same_v<T, umf_result_t>) {
302+
ASSERT_EQ(arg, UMF_RESULT_ERROR_INVALID_ARGUMENT);
303+
} else if constexpr (std::is_same_v<T, size_t>) {
304+
ASSERT_EQ(arg, 0U);
305+
} else {
306+
ASSERT_EQ(arg, nullptr);
307+
}
308+
},
309+
ret);
310+
}
311+
312+
INSTANTIATE_TEST_SUITE_P(
313+
poolHandleCheck, poolHandleCheck,
314+
::testing::Values(
315+
umf_test::withGeneratedArgs(umfPoolMalloc),
316+
umf_test::withGeneratedArgs(umfPoolAlignedMalloc),
317+
umf_test::withGeneratedArgs(umfPoolFree),
318+
umf_test::withGeneratedArgs(umfPoolCalloc),
319+
umf_test::withGeneratedArgs(umfPoolRealloc),
320+
umf_test::withGeneratedArgs(umfPoolMallocUsableSize),
321+
umf_test::withGeneratedArgs(umfPoolGetLastAllocationError)));

test/memoryProviderAPI.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <string>
1111
#include <unordered_map>
12+
#include <variant>
1213

1314
using umf_test::test;
1415

@@ -104,3 +105,35 @@ TEST_P(providerInitializeTest, errorPropagation) {
104105
&hProvider);
105106
ASSERT_EQ(ret, this->GetParam());
106107
}
108+
109+
struct providerHandleCheck
110+
: umf_test::test,
111+
::testing::WithParamInterface<
112+
std::function<std::variant<const char *, umf_result_t>(void)>> {};
113+
114+
TEST_P(providerHandleCheck, providerHandleCheckAll) {
115+
auto f = GetParam();
116+
auto ret = f();
117+
118+
std::visit(
119+
[&](auto arg) {
120+
using T = decltype(arg);
121+
if constexpr (std::is_same_v<T, umf_result_t>) {
122+
ASSERT_EQ(arg, UMF_RESULT_ERROR_INVALID_ARGUMENT);
123+
} else {
124+
ASSERT_EQ(arg, nullptr);
125+
}
126+
},
127+
ret);
128+
}
129+
130+
INSTANTIATE_TEST_SUITE_P(
131+
providerHandleCheck, providerHandleCheck,
132+
::testing::Values(
133+
umf_test::withGeneratedArgs(umfMemoryProviderAlloc),
134+
umf_test::withGeneratedArgs(umfMemoryProviderFree),
135+
umf_test::withGeneratedArgs(umfMemoryProviderGetRecommendedPageSize),
136+
umf_test::withGeneratedArgs(umfMemoryProviderGetMinPageSize),
137+
umf_test::withGeneratedArgs(umfMemoryProviderPurgeLazy),
138+
umf_test::withGeneratedArgs(umfMemoryProviderPurgeForce),
139+
umf_test::withGeneratedArgs(umfMemoryProviderGetName)));

0 commit comments

Comments
 (0)