Skip to content

Commit 563842d

Browse files
committed
[L0 v2] support SYCL_PI_LEVEL_ZERO_DISABLE_USM_ALLOCATOR
and fix potential memory leak when creating UMF params.
1 parent 607bfb7 commit 563842d

File tree

1 file changed

+69
-35
lines changed
  • source/adapters/level_zero/v2

1 file changed

+69
-35
lines changed

source/adapters/level_zero/v2/usm.cpp

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,17 @@ ur_result_t getProviderNativeError(const char *providerName,
3434
}
3535
} // namespace umf
3636

37-
static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig() {
37+
static std::optional<usm::DisjointPoolAllConfigs>
38+
initializeDisjointPoolConfig() {
39+
const char *UrRetDisable = std::getenv("UR_L0_DISABLE_USM_ALLOCATOR");
40+
const char *PiRetDisable =
41+
std::getenv("SYCL_PI_LEVEL_ZERO_DISABLE_USM_ALLOCATOR");
42+
const char *Disable =
43+
UrRetDisable ? UrRetDisable : (PiRetDisable ? PiRetDisable : nullptr);
44+
if (Disable != nullptr && Disable != std::string("")) {
45+
return std::nullopt;
46+
}
47+
3848
const char *PoolUrTraceVal = std::getenv("UR_L0_USM_ALLOCATOR_TRACE");
3949

4050
int PoolTrace = 0;
@@ -47,7 +57,14 @@ static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig() {
4757
return usm::DisjointPoolAllConfigs(PoolTrace);
4858
}
4959

50-
return usm::parseDisjointPoolConfig(PoolUrConfigVal, PoolTrace);
60+
// TODO: rework parseDisjointPoolConfig to return optional,
61+
// once EnableBuffers is no longer used (by legacy L0)
62+
auto configs = usm::parseDisjointPoolConfig(PoolUrConfigVal, PoolTrace);
63+
if (configs.EnableBuffers) {
64+
return configs;
65+
}
66+
67+
return std::nullopt;
5168
}
5269

5370
inline umf_usm_memory_type_t urToUmfMemoryType(ur_usm_type_t type) {
@@ -81,32 +98,35 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
8198
}
8299
}
83100

84-
static umf::pool_unique_handle_t
85-
makePool(usm::umf_disjoint_pool_config_t *poolParams,
86-
usm::pool_descriptor poolDescriptor) {
87-
umf_level_zero_memory_provider_params_handle_t params = NULL;
88-
umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate(&params);
101+
static umf::provider_unique_handle_t
102+
makeProvider(usm::pool_descriptor poolDescriptor) {
103+
umf_level_zero_memory_provider_params_handle_t hParams;
104+
umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate(&hParams);
89105
if (umf_ret != UMF_RESULT_SUCCESS) {
90106
throw umf::umf2urResult(umf_ret);
91107
}
92108

109+
std::unique_ptr<umf_level_zero_memory_provider_params_t,
110+
decltype(&umfLevelZeroMemoryProviderParamsDestroy)>
111+
params(hParams, &umfLevelZeroMemoryProviderParamsDestroy);
112+
93113
umf_ret = umfLevelZeroMemoryProviderParamsSetContext(
94-
params, poolDescriptor.hContext->getZeHandle());
114+
hParams, poolDescriptor.hContext->getZeHandle());
95115
if (umf_ret != UMF_RESULT_SUCCESS) {
96116
throw umf::umf2urResult(umf_ret);
97117
};
98118

99119
ze_device_handle_t level_zero_device_handle =
100120
poolDescriptor.hDevice ? poolDescriptor.hDevice->ZeDevice : nullptr;
101121

102-
umf_ret = umfLevelZeroMemoryProviderParamsSetDevice(params,
122+
umf_ret = umfLevelZeroMemoryProviderParamsSetDevice(hParams,
103123
level_zero_device_handle);
104124
if (umf_ret != UMF_RESULT_SUCCESS) {
105125
throw umf::umf2urResult(umf_ret);
106126
}
107127

108128
umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType(
109-
params, urToUmfMemoryType(poolDescriptor.type));
129+
hParams, urToUmfMemoryType(poolDescriptor.type));
110130
if (umf_ret != UMF_RESULT_SUCCESS) {
111131
throw umf::umf2urResult(umf_ret);
112132
}
@@ -123,46 +143,59 @@ makePool(usm::umf_disjoint_pool_config_t *poolParams,
123143
}
124144

125145
umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices(
126-
params, residentZeHandles.data(), residentZeHandles.size());
146+
hParams, residentZeHandles.data(), residentZeHandles.size());
127147
if (umf_ret != UMF_RESULT_SUCCESS) {
128148
throw umf::umf2urResult(umf_ret);
129149
}
130150
}
131151

132152
auto [ret, provider] =
133-
umf::providerMakeUniqueFromOps(umfLevelZeroMemoryProviderOps(), params);
153+
umf::providerMakeUniqueFromOps(umfLevelZeroMemoryProviderOps(), hParams);
134154
if (ret != UMF_RESULT_SUCCESS) {
135155
throw umf::umf2urResult(ret);
136156
}
137157

138-
if (!poolParams) {
139-
auto [ret, poolHandle] = umf::poolMakeUniqueFromOps(
140-
umfProxyPoolOps(), std::move(provider), nullptr);
141-
if (ret != UMF_RESULT_SUCCESS)
142-
throw umf::umf2urResult(ret);
143-
return std::move(poolHandle);
144-
} else {
145-
auto umfParams = getUmfParamsHandle(*poolParams);
158+
return std::move(provider);
159+
}
146160

147-
auto [ret, poolHandle] =
148-
umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(provider),
149-
static_cast<void *>(umfParams.get()));
150-
if (ret != UMF_RESULT_SUCCESS)
151-
throw umf::umf2urResult(ret);
152-
return std::move(poolHandle);
153-
}
161+
static umf::pool_unique_handle_t
162+
makeDisjointPool(umf::provider_unique_handle_t &&provider,
163+
usm::umf_disjoint_pool_config_t &poolParams) {
164+
auto umfParams = getUmfParamsHandle(poolParams);
165+
auto [ret, poolHandle] =
166+
umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(provider),
167+
static_cast<void *>(umfParams.get()));
168+
if (ret != UMF_RESULT_SUCCESS)
169+
throw umf::umf2urResult(ret);
170+
return std::move(poolHandle);
171+
}
172+
173+
static umf::pool_unique_handle_t
174+
makeProxyPool(umf::provider_unique_handle_t &&provider) {
175+
auto [ret, poolHandle] = umf::poolMakeUniqueFromOps(
176+
umfProxyPoolOps(), std::move(provider), nullptr);
177+
if (ret != UMF_RESULT_SUCCESS)
178+
throw umf::umf2urResult(ret);
179+
180+
return std::move(poolHandle);
154181
}
155182

156183
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
157184
ur_usm_pool_desc_t *pPoolDesc)
158185
: hContext(hContext) {
159186
// TODO: handle UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK from pPoolDesc
160187
auto disjointPoolConfigs = initializeDisjointPoolConfig();
161-
if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t>(pPoolDesc)) {
162-
for (auto &config : disjointPoolConfigs.Configs) {
163-
config.MaxPoolableSize = limits->maxPoolableSize;
164-
config.SlabMinSize = limits->minDriverAllocSize;
188+
189+
if (disjointPoolConfigs.has_value()) {
190+
if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t>(pPoolDesc)) {
191+
for (auto &config : disjointPoolConfigs.value().Configs) {
192+
config.MaxPoolableSize = limits->maxPoolableSize;
193+
config.SlabMinSize = limits->minDriverAllocSize;
194+
}
165195
}
196+
} else {
197+
// If pooling is disabled, do nothing.
198+
logger::info("USM pooling is disabled. Skiping pool limits adjustment.");
166199
}
167200

168201
auto [result, descriptors] = usm::pool_descriptor::create(this, hContext);
@@ -171,12 +204,13 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
171204
}
172205

173206
for (auto &desc : descriptors) {
174-
if (disjointPoolConfigs.EnableBuffers) {
207+
if (disjointPoolConfigs.has_value()) {
175208
auto &poolConfig =
176-
disjointPoolConfigs.Configs[descToDisjoinPoolMemType(desc)];
177-
poolManager.addPool(desc, makePool(&poolConfig, desc));
209+
disjointPoolConfigs.value().Configs[descToDisjoinPoolMemType(desc)];
210+
poolManager.addPool(desc,
211+
makeDisjointPool(makeProvider(desc), poolConfig));
178212
} else {
179-
poolManager.addPool(desc, makePool(nullptr, desc));
213+
poolManager.addPool(desc, makeProxyPool(makeProvider(desc)));
180214
}
181215
}
182216
}

0 commit comments

Comments
 (0)