@@ -34,7 +34,17 @@ ur_result_t getProviderNativeError(const char *providerName,
34
34
}
35
35
} // namespace umf
36
36
37
- static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig () {
37
+ static std::optional<usm::DisjointPoolAllConfigs>
38
+ initializeDisjointPoolConfig () {
39
+ const char *UrRetDisable = std::getenv (" UR_L0_DISABLE_USM_ALLOCATOR" );
40
+ const char *PiRetDisable =
41
+ std::getenv (" SYCL_PI_LEVEL_ZERO_DISABLE_USM_ALLOCATOR" );
42
+ const char *Disable =
43
+ UrRetDisable ? UrRetDisable : (PiRetDisable ? PiRetDisable : nullptr );
44
+ if (Disable != nullptr && Disable != std::string (" " )) {
45
+ return std::nullopt;
46
+ }
47
+
38
48
const char *PoolUrTraceVal = std::getenv (" UR_L0_USM_ALLOCATOR_TRACE" );
39
49
40
50
int PoolTrace = 0 ;
@@ -47,7 +57,14 @@ static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig() {
47
57
return usm::DisjointPoolAllConfigs (PoolTrace);
48
58
}
49
59
50
- return usm::parseDisjointPoolConfig (PoolUrConfigVal, PoolTrace);
60
+ // TODO: rework parseDisjointPoolConfig to return optional,
61
+ // once EnableBuffers is no longer used (by legacy L0)
62
+ auto configs = usm::parseDisjointPoolConfig (PoolUrConfigVal, PoolTrace);
63
+ if (configs.EnableBuffers ) {
64
+ return configs;
65
+ }
66
+
67
+ return std::nullopt;
51
68
}
52
69
53
70
inline umf_usm_memory_type_t urToUmfMemoryType (ur_usm_type_t type) {
@@ -81,32 +98,35 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
81
98
}
82
99
}
83
100
84
- static umf::pool_unique_handle_t
85
- makePool (usm::umf_disjoint_pool_config_t *poolParams,
86
- usm::pool_descriptor poolDescriptor) {
87
- umf_level_zero_memory_provider_params_handle_t params = NULL ;
88
- umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (¶ms);
101
+ static umf::provider_unique_handle_t
102
+ makeProvider (usm::pool_descriptor poolDescriptor) {
103
+ umf_level_zero_memory_provider_params_handle_t hParams;
104
+ umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (&hParams);
89
105
if (umf_ret != UMF_RESULT_SUCCESS) {
90
106
throw umf::umf2urResult (umf_ret);
91
107
}
92
108
109
+ std::unique_ptr<umf_level_zero_memory_provider_params_t ,
110
+ decltype (&umfLevelZeroMemoryProviderParamsDestroy)>
111
+ params (hParams, &umfLevelZeroMemoryProviderParamsDestroy);
112
+
93
113
umf_ret = umfLevelZeroMemoryProviderParamsSetContext (
94
- params , poolDescriptor.hContext ->getZeHandle ());
114
+ hParams , poolDescriptor.hContext ->getZeHandle ());
95
115
if (umf_ret != UMF_RESULT_SUCCESS) {
96
116
throw umf::umf2urResult (umf_ret);
97
117
};
98
118
99
119
ze_device_handle_t level_zero_device_handle =
100
120
poolDescriptor.hDevice ? poolDescriptor.hDevice ->ZeDevice : nullptr ;
101
121
102
- umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (params ,
122
+ umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (hParams ,
103
123
level_zero_device_handle);
104
124
if (umf_ret != UMF_RESULT_SUCCESS) {
105
125
throw umf::umf2urResult (umf_ret);
106
126
}
107
127
108
128
umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType (
109
- params , urToUmfMemoryType (poolDescriptor.type ));
129
+ hParams , urToUmfMemoryType (poolDescriptor.type ));
110
130
if (umf_ret != UMF_RESULT_SUCCESS) {
111
131
throw umf::umf2urResult (umf_ret);
112
132
}
@@ -123,46 +143,59 @@ makePool(usm::umf_disjoint_pool_config_t *poolParams,
123
143
}
124
144
125
145
umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices (
126
- params , residentZeHandles.data (), residentZeHandles.size ());
146
+ hParams , residentZeHandles.data (), residentZeHandles.size ());
127
147
if (umf_ret != UMF_RESULT_SUCCESS) {
128
148
throw umf::umf2urResult (umf_ret);
129
149
}
130
150
}
131
151
132
152
auto [ret, provider] =
133
- umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), params );
153
+ umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), hParams );
134
154
if (ret != UMF_RESULT_SUCCESS) {
135
155
throw umf::umf2urResult (ret);
136
156
}
137
157
138
- if (!poolParams) {
139
- auto [ret, poolHandle] = umf::poolMakeUniqueFromOps (
140
- umfProxyPoolOps (), std::move (provider), nullptr );
141
- if (ret != UMF_RESULT_SUCCESS)
142
- throw umf::umf2urResult (ret);
143
- return std::move (poolHandle);
144
- } else {
145
- auto umfParams = getUmfParamsHandle (*poolParams);
158
+ return std::move (provider);
159
+ }
146
160
147
- auto [ret, poolHandle] =
148
- umf::poolMakeUniqueFromOps (umfDisjointPoolOps (), std::move (provider),
149
- static_cast <void *>(umfParams.get ()));
150
- if (ret != UMF_RESULT_SUCCESS)
151
- throw umf::umf2urResult (ret);
152
- return std::move (poolHandle);
153
- }
161
+ static umf::pool_unique_handle_t
162
+ makeDisjointPool (umf::provider_unique_handle_t &&provider,
163
+ usm::umf_disjoint_pool_config_t &poolParams) {
164
+ auto umfParams = getUmfParamsHandle (poolParams);
165
+ auto [ret, poolHandle] =
166
+ umf::poolMakeUniqueFromOps (umfDisjointPoolOps (), std::move (provider),
167
+ static_cast <void *>(umfParams.get ()));
168
+ if (ret != UMF_RESULT_SUCCESS)
169
+ throw umf::umf2urResult (ret);
170
+ return std::move (poolHandle);
171
+ }
172
+
173
+ static umf::pool_unique_handle_t
174
+ makeProxyPool (umf::provider_unique_handle_t &&provider) {
175
+ auto [ret, poolHandle] = umf::poolMakeUniqueFromOps (
176
+ umfProxyPoolOps (), std::move (provider), nullptr );
177
+ if (ret != UMF_RESULT_SUCCESS)
178
+ throw umf::umf2urResult (ret);
179
+
180
+ return std::move (poolHandle);
154
181
}
155
182
156
183
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t hContext,
157
184
ur_usm_pool_desc_t *pPoolDesc)
158
185
: hContext(hContext) {
159
186
// TODO: handle UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK from pPoolDesc
160
187
auto disjointPoolConfigs = initializeDisjointPoolConfig ();
161
- if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
162
- for (auto &config : disjointPoolConfigs.Configs ) {
163
- config.MaxPoolableSize = limits->maxPoolableSize ;
164
- config.SlabMinSize = limits->minDriverAllocSize ;
188
+
189
+ if (disjointPoolConfigs.has_value ()) {
190
+ if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
191
+ for (auto &config : disjointPoolConfigs.value ().Configs ) {
192
+ config.MaxPoolableSize = limits->maxPoolableSize ;
193
+ config.SlabMinSize = limits->minDriverAllocSize ;
194
+ }
165
195
}
196
+ } else {
197
+ // If pooling is disabled, do nothing.
198
+ logger::info (" USM pooling is disabled. Skiping pool limits adjustment." );
166
199
}
167
200
168
201
auto [result, descriptors] = usm::pool_descriptor::create (this , hContext);
@@ -171,12 +204,13 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
171
204
}
172
205
173
206
for (auto &desc : descriptors) {
174
- if (disjointPoolConfigs.EnableBuffers ) {
207
+ if (disjointPoolConfigs.has_value () ) {
175
208
auto &poolConfig =
176
- disjointPoolConfigs.Configs [descToDisjoinPoolMemType (desc)];
177
- poolManager.addPool (desc, makePool (&poolConfig, desc));
209
+ disjointPoolConfigs.value ().Configs [descToDisjoinPoolMemType (desc)];
210
+ poolManager.addPool (desc,
211
+ makeDisjointPool (makeProvider (desc), poolConfig));
178
212
} else {
179
- poolManager.addPool (desc, makePool ( nullptr , desc));
213
+ poolManager.addPool (desc, makeProxyPool ( makeProvider ( desc) ));
180
214
}
181
215
}
182
216
}
0 commit comments