@@ -76,38 +76,6 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
76
76
// /
77
77
// /
78
78
79
// Removed-side helper (every line below carries the diff's '-' marker).
// Builds a UMF CUDA memory provider for host memory (UMF_MEMORY_TYPE_HOST)
// from DeviceHandle's native context and wraps it in a proxy pool.
//
// DeviceHandle:       device whose getNativeContext() supplies the CUcontext.
// MemoryProviderHost: out-parameter; reset to nullptr first, then receives the
//                     provider created by umfMemoryProviderCreate.
// MemoryPoolHost:     out-parameter; receives the pool created by
//                     umfPoolCreate.
// Returns UR_RESULT_SUCCESS when every step succeeds. Each UMF call is
// followed by UMF_RETURN_UR_ERROR, which presumably returns early with a
// translated error code -- the macro is defined elsewhere, TODO confirm.
//
// NOTE(review): the pool is created with flags 0, and the provider handle is
// handed back to the caller separately. If umfPoolCreate fails, the provider
// created just before it is still stored in *MemoryProviderHost and is not
// destroyed in this function.
- static ur_result_t
80
- CreateHostMemoryProviderPool (ur_device_handle_t_ *DeviceHandle,
81
- umf_memory_provider_handle_t *MemoryProviderHost,
82
- umf_memory_pool_handle_t *MemoryPoolHost) {
83
-
84
- *MemoryProviderHost = nullptr ;
85
- CUcontext context = DeviceHandle->getNativeContext ();
86
-
87
- umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams = nullptr ;
88
- umf_result_t UmfResult =
89
- umfCUDAMemoryProviderParamsCreate (&CUMemoryProviderParams);
90
- UMF_RETURN_UR_ERROR (UmfResult);
// The params object is handed to OnScopeExit together with its destroy call;
// presumably this runs when the guard goes out of scope -- OnScopeExit is
// defined elsewhere, TODO confirm.
91
- OnScopeExit Cleanup (
92
- [=]() { umfCUDAMemoryProviderParamsDestroy (CUMemoryProviderParams); });
93
-
94
- UmfResult = umf::setCUMemoryProviderParams (
95
- CUMemoryProviderParams, 0 /* cuDevice */ , context, UMF_MEMORY_TYPE_HOST);
96
- UMF_RETURN_UR_ERROR (UmfResult);
97
-
98
- // create UMF CUDA memory provider and pool for the host memory
99
- // (UMF_MEMORY_TYPE_HOST)
100
- UmfResult = umfMemoryProviderCreate (
101
- umfCUDAMemoryProviderOps (), CUMemoryProviderParams, MemoryProviderHost);
102
- UMF_RETURN_UR_ERROR (UmfResult);
103
-
104
- UmfResult = umfPoolCreate (umfProxyPoolOps (), *MemoryProviderHost, nullptr , 0 ,
105
- MemoryPoolHost);
106
- UMF_RETURN_UR_ERROR (UmfResult);
107
-
108
- return UR_RESULT_SUCCESS;
109
- }
110
-
111
79
struct ur_context_handle_t_ {
112
80
113
81
struct deleter_data {
@@ -120,30 +88,42 @@ struct ur_context_handle_t_ {
120
88
std::vector<ur_device_handle_t > Devices;
121
89
std::atomic_uint32_t RefCount;
122
90
123
- // UMF CUDA memory provider and pool for the host memory
91
+ // UMF CUDA memory pool for the host memory
124
92
// (UMF_MEMORY_TYPE_HOST)
125
- umf_memory_provider_handle_t MemoryProviderHost = nullptr ;
126
93
umf_memory_pool_handle_t MemoryPoolHost = nullptr ;
127
94
95
+ // UMF CUDA memory pools for the device memory
96
+ // (UMF_MEMORY_TYPE_DEVICE)
97
+ std::vector<umf_memory_pool_handle_t > MemoryDevicePools;
98
+
99
+ // UMF CUDA memory pools for the shared memory
100
+ // (UMF_MEMORY_TYPE_SHARED)
101
+ std::vector<umf_memory_pool_handle_t > MemorySharedPools;
102
+
128
103
// Constructor. Copies the NumDevices handles starting at Devs into Devices,
// starts RefCount at 1 and calls urDeviceRetain on every stored device.
// Removed side ('-'): created the host provider and pool through
// CreateHostMemoryProviderPool using Devices[0].
// Added side ('+'): calls createHostMemoryPool() and then
// createDeviceMemoryPools(); both results go through UR_CHECK_ERROR (defined
// elsewhere).
// NOTE(review): if UR_CHECK_ERROR throws, the destructor of this object will
// not run, so the devices retained above and any pool already created would
// not be released -- TODO confirm how UR_CHECK_ERROR reports failure.
ur_context_handle_t_ (const ur_device_handle_t *Devs, uint32_t NumDevices)
129
104
: Devices{Devs, Devs + NumDevices}, RefCount{1 } {
130
105
for (auto &Dev : Devices) {
131
106
urDeviceRetain (Dev);
132
107
}
133
108
134
- // Create UMF CUDA memory provider for the host memory
135
- // (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because
136
- // it is guaranteed to exist).
137
- UR_CHECK_ERROR (CreateHostMemoryProviderPool (Devices[0 ], &MemoryProviderHost,
138
- &MemoryPoolHost));
109
+ // Create UMF CUDA memory provider and pool for the host memory
110
+ // (UMF_MEMORY_TYPE_HOST)
111
+ UR_CHECK_ERROR (createHostMemoryPool ());
112
+
113
+ // Create UMF CUDA memory providers and pools for the device memory
114
+ // (UMF_MEMORY_TYPE_DEVICE) and the shared memory (UMF_MEMORY_TYPE_SHARED).
115
+ UR_CHECK_ERROR (createDeviceMemoryPools ());
139
116
};
140
117
141
118
~ur_context_handle_t_ () {
142
119
if (MemoryPoolHost) {
143
120
umfPoolDestroy (MemoryPoolHost);
144
121
}
145
- if (MemoryProviderHost) {
146
- umfMemoryProviderDestroy (MemoryProviderHost);
122
+ for (auto &Pool : MemoryDevicePools) {
123
+ umfPoolDestroy (Pool);
124
+ }
125
+ for (auto &Pool : MemorySharedPools) {
126
+ umfPoolDestroy (Pool);
147
127
}
148
128
for (auto &Dev : Devices) {
149
129
urDeviceRelease (Dev);
@@ -190,6 +170,59 @@ struct ur_context_handle_t_ {
190
170
std::mutex Mutex;
191
171
std::vector<deleter_data> ExtendedDeleters;
192
172
std::set<ur_usm_pool_handle_t > PoolHandles;
173
+
174
+ // Create UMF CUDA memory pool for the host memory (UMF_MEMORY_TYPE_HOST)
175
+ ur_result_t createHostMemoryPool () {
176
+ umf_memory_provider_handle_t memoryProviderHost = nullptr ;
177
+ ur_result_t URResult = umf::createHostMemoryProvider (
178
+ Devices[0 ]->getNativeContext (), &memoryProviderHost);
179
+ if (URResult != UR_RESULT_SUCCESS) {
180
+ return URResult;
181
+ }
182
+
183
+ umf_result_t UmfResult =
184
+ umfPoolCreate (umfProxyPoolOps (), memoryProviderHost, nullptr ,
185
+ UMF_POOL_CREATE_FLAG_OWN_PROVIDER, &MemoryPoolHost);
186
+ UMF_RETURN_UR_ERROR (UmfResult);
187
+
188
+ return UR_RESULT_SUCCESS;
189
+ }
190
+
191
+ // Create UMF CUDA memory pools for the device memory (UMF_MEMORY_TYPE_DEVICE)
192
+ // and the shared memory (UMF_MEMORY_TYPE_SHARED)
193
+ ur_result_t createDeviceMemoryPools () {
194
+ for (auto &Device : Devices) {
195
+ umf_memory_provider_handle_t memoryDeviceProvider = nullptr ;
196
+ umf_memory_provider_handle_t memorySharedProvider = nullptr ;
197
+ ur_result_t URResult = umf::createDeviceMemoryProviders (
198
+ Device, &memoryDeviceProvider, &memorySharedProvider);
199
+ if (URResult != UR_RESULT_SUCCESS) {
200
+ return URResult;
201
+ }
202
+
203
+ // create UMF CUDA memory pool for the device memory
204
+ // (UMF_MEMORY_TYPE_DEVICE)
205
+ umf_result_t UmfResult;
206
+ umf_memory_pool_handle_t memoryDevicePool = nullptr ;
207
+ UmfResult =
208
+ umfPoolCreate (umfProxyPoolOps (), memoryDeviceProvider, nullptr ,
209
+ UMF_POOL_CREATE_FLAG_OWN_PROVIDER, &memoryDevicePool);
210
+ UMF_RETURN_UR_ERROR (UmfResult);
211
+
212
+ // create UMF CUDA memory pool for the shared memory
213
+ // (UMF_MEMORY_TYPE_SHARED)
214
+ umf_memory_pool_handle_t memorySharedPool = nullptr ;
215
+ UmfResult =
216
+ umfPoolCreate (umfProxyPoolOps (), memorySharedProvider, nullptr ,
217
+ UMF_POOL_CREATE_FLAG_OWN_PROVIDER, &memorySharedPool);
218
+ UMF_RETURN_UR_ERROR (UmfResult);
219
+
220
+ MemoryDevicePools.push_back (memoryDevicePool);
221
+ MemorySharedPools.push_back (memorySharedPool);
222
+ }
223
+
224
+ return UR_RESULT_SUCCESS;
225
+ }
193
226
};
194
227
195
228
namespace {
0 commit comments