@@ -21,10 +21,8 @@ using namespace umf_test;
21
21
22
22
class CUDAMemoryAccessor : public MemoryAccessor {
23
23
public:
24
- void init (CUcontext hContext, CUdevice hDevice) {
25
- hDevice_ = hDevice;
26
- hContext_ = hContext;
27
- }
24
+ CUDAMemoryAccessor (CUcontext hContext, CUdevice hDevice)
25
+ : hDevice_(hDevice), hContext_(hContext) {}
28
26
29
27
void fill (void *ptr, size_t size, const void *pattern,
30
28
size_t pattern_size) {
@@ -53,7 +51,7 @@ class CUDAMemoryAccessor : public MemoryAccessor {
53
51
};
54
52
55
53
using CUDAProviderTestParams =
56
- std::tuple<umf_usm_memory_type_t , MemoryAccessor *>;
54
+ std::tuple<cuda_memory_provider_params_t , MemoryAccessor *>;
57
55
58
56
struct umfCUDAProviderTest
59
57
: umf_test::test,
@@ -62,23 +60,12 @@ struct umfCUDAProviderTest
62
60
void SetUp () override {
63
61
test::SetUp ();
64
62
65
- auto [memory_type , accessor] = this ->GetParam ();
66
- params = create_cuda_prov_params (memory_type) ;
63
+ auto [cuda_params , accessor] = this ->GetParam ();
64
+ params = cuda_params ;
67
65
memAccessor = accessor;
68
- if (memory_type == UMF_MEMORY_TYPE_DEVICE) {
69
- ((CUDAMemoryAccessor *)memAccessor)
70
- ->init ((CUcontext)params.cuda_context_handle ,
71
- params.cuda_device_handle );
72
- }
73
66
}
74
67
75
- void TearDown () override {
76
- if (params.cuda_context_handle ) {
77
- int ret = destroy_context ((CUcontext)params.cuda_context_handle );
78
- ASSERT_EQ (ret, 0 );
79
- }
80
- test::TearDown ();
81
- }
68
+ void TearDown () override { test::TearDown (); }
82
69
83
70
cuda_memory_provider_params_t params;
84
71
MemoryAccessor *memAccessor = nullptr ;
@@ -87,6 +74,7 @@ struct umfCUDAProviderTest
87
74
TEST_P (umfCUDAProviderTest, basic) {
88
75
const size_t size = 1024 * 8 ;
89
76
const uint32_t pattern = 0xAB ;
77
+ CUcontext expected_current_context = get_current_context ();
90
78
91
79
// create CUDA provider
92
80
umf_memory_provider_handle_t provider = nullptr ;
@@ -113,6 +101,12 @@ TEST_P(umfCUDAProviderTest, basic) {
113
101
// use the allocated memory - fill it with a 0xAB pattern
114
102
memAccessor->fill (ptr, size, &pattern, sizeof (pattern));
115
103
104
+ CUcontext actual_mem_context = get_mem_context (ptr);
105
+ ASSERT_EQ (actual_mem_context, (CUcontext)params.cuda_context_handle );
106
+
107
+ CUcontext actual_current_context = get_current_context ();
108
+ ASSERT_EQ (actual_current_context, expected_current_context);
109
+
116
110
umf_usm_memory_type_t memoryTypeActual =
117
111
get_mem_type ((CUcontext)params.cuda_context_handle , ptr);
118
112
ASSERT_EQ (memoryTypeActual, params.memory_type );
@@ -132,6 +126,7 @@ TEST_P(umfCUDAProviderTest, basic) {
132
126
}
133
127
134
128
TEST_P (umfCUDAProviderTest, allocInvalidSize) {
129
+ CUcontext expected_current_context = get_current_context ();
135
130
// create CUDA provider
136
131
umf_memory_provider_handle_t provider = nullptr ;
137
132
umf_result_t umf_result =
@@ -151,32 +146,32 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) {
151
146
ASSERT_EQ (umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
152
147
}
153
148
154
- // destroy context and try to alloc some memory
155
- destroy_context ((CUcontext)params.cuda_context_handle );
156
- params.cuda_context_handle = 0 ;
157
- umf_result = umfMemoryProviderAlloc (provider, 128 , 0 , &ptr);
158
- ASSERT_EQ (umf_result, UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC);
159
-
160
- const char *message;
161
- int32_t error;
162
- umfMemoryProviderGetLastNativeError (provider, &message, &error);
163
- ASSERT_EQ (error, CUDA_ERROR_INVALID_CONTEXT);
164
- const char *expected_message =
165
- " CUDA_ERROR_INVALID_CONTEXT - invalid device context" ;
166
- ASSERT_EQ (strncmp (message, expected_message, strlen (expected_message)), 0 );
149
+ CUcontext actual_current_context = get_current_context ();
150
+ ASSERT_EQ (actual_current_context, expected_current_context);
151
+
152
+ umfMemoryProviderDestroy (provider);
167
153
}
168
154
169
155
// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
170
156
171
- CUDAMemoryAccessor cuAccessor;
157
+ cuda_memory_provider_params_t cuParams_device_memory =
158
+ create_cuda_prov_params (UMF_MEMORY_TYPE_DEVICE);
159
+ cuda_memory_provider_params_t cuParams_shared_memory =
160
+ create_cuda_prov_params (UMF_MEMORY_TYPE_SHARED);
161
+ cuda_memory_provider_params_t cuParams_host_memory =
162
+ create_cuda_prov_params (UMF_MEMORY_TYPE_HOST);
163
+
164
+ CUDAMemoryAccessor
165
+ cuAccessor ((CUcontext)cuParams_device_memory.cuda_context_handle,
166
+ (CUdevice)cuParams_device_memory.cuda_device_handle);
172
167
HostMemoryAccessor hostAccessor;
173
168
174
169
INSTANTIATE_TEST_SUITE_P (
175
170
umfCUDAProviderTestSuite, umfCUDAProviderTest,
176
171
::testing::Values (
177
- CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE , &cuAccessor},
178
- CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED , &hostAccessor},
179
- CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST , &hostAccessor}));
172
+ CUDAProviderTestParams{cuParams_device_memory , &cuAccessor},
173
+ CUDAProviderTestParams{cuParams_shared_memory , &hostAccessor},
174
+ CUDAProviderTestParams{cuParams_host_memory , &hostAccessor}));
180
175
181
176
// TODO: add IPC API
182
177
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST (umfIpcTest);
@@ -185,5 +180,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest,
185
180
::testing::Values(ipcTestParams{
186
181
umfProxyPoolOps(), nullptr,
187
182
umfCUDAMemoryProviderOps(),
188
- &cuParams_device_memory, &l0Accessor }));
183
+ &cuParams_device_memory, &cuAccessor }));
189
184
*/
0 commit comments