Skip to content

Commit 6cf1ae0

Browse files
authored
[SYCL] Evict program cache on PI_ERROR_OUT_OF_RESOURCES (#11987)
This PR adds resource cleanup mechanisms when `PI_ERROR_OUT_OF_RESOURCES` occurs on `piProgramBuild` and `piProgramLink`. Currently, the whole cache is cleared, but this behavior can be extended in the future. In order to ensure thread safety of the cache clearing mechanism, the maps of the cache now hold their `BuildResult`s in `shared_ptr`s, which allows the cache to be cleared even if some threads still hold references to a `BuildResult`.
1 parent 3e7f66b commit 6cf1ae0

File tree

7 files changed

+390
-208
lines changed

7 files changed

+390
-208
lines changed

sycl/source/detail/context_impl.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ std::optional<sycl::detail::pi::PiProgram> context_impl::getProgramForDevImgs(
447447
const device &Device, const std::set<std::uintptr_t> &ImgIdentifiers,
448448
const std::string &ObjectTypeName) {
449449

450-
KernelProgramCache::ProgramWithBuildStateT *BuildRes = nullptr;
450+
KernelProgramCache::ProgramBuildResultPtr BuildRes = nullptr;
451451
{
452452
auto LockedCache = MKernelProgramCache.acquireCachedPrograms();
453453
auto &KeyMap = LockedCache.get().KeyMap;
@@ -471,12 +471,18 @@ std::optional<sycl::detail::pi::PiProgram> context_impl::getProgramForDevImgs(
471471
assert(KeyMappingsIt != KeyMap.end());
472472
auto CachedProgIt = Cache.find(KeyMappingsIt->second);
473473
assert(CachedProgIt != Cache.end());
474-
BuildRes = &CachedProgIt->second;
474+
BuildRes = CachedProgIt->second;
475475
}
476476
}
477477
if (!BuildRes)
478478
return std::nullopt;
479-
return *MKernelProgramCache.waitUntilBuilt<compile_program_error>(BuildRes);
479+
using BuildState = KernelProgramCache::BuildState;
480+
BuildState NewState = BuildRes->waitUntilTransition();
481+
if (NewState == BuildState::BS_Failed)
482+
throw compile_program_error(BuildRes->Error.Msg, BuildRes->Error.Code);
483+
484+
assert(NewState == BuildState::BS_Done);
485+
return BuildRes->Val;
480486
}
481487

482488
std::optional<sycl::detail::pi::PiProgram>

sycl/source/detail/kernel_program_cache.cpp

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,32 +13,8 @@
1313
namespace sycl {
1414
inline namespace _V1 {
1515
namespace detail {
16-
KernelProgramCache::~KernelProgramCache() {
17-
for (auto &ProgIt : MCachedPrograms.Cache) {
18-
ProgramWithBuildStateT &ProgWithState = ProgIt.second;
19-
sycl::detail::pi::PiProgram *ToBeDeleted = ProgWithState.Ptr.load();
20-
21-
if (!ToBeDeleted)
22-
continue;
23-
24-
auto KernIt = MKernelsPerProgramCache.find(*ToBeDeleted);
25-
26-
if (KernIt != MKernelsPerProgramCache.end()) {
27-
for (auto &p : KernIt->second) {
28-
BuildResult<KernelArgMaskPairT> &KernelWithState = p.second;
29-
KernelArgMaskPairT *KernelArgMaskPair = KernelWithState.Ptr.load();
30-
31-
if (KernelArgMaskPair) {
32-
const PluginPtr &Plugin = MParentContext->getPlugin();
33-
Plugin->call<PiApiKind::piKernelRelease>(KernelArgMaskPair->first);
34-
}
35-
}
36-
MKernelsPerProgramCache.erase(KernIt);
37-
}
38-
39-
const PluginPtr &Plugin = MParentContext->getPlugin();
40-
Plugin->call<PiApiKind::piProgramRelease>(*ToBeDeleted);
41-
}
16+
const PluginPtr &KernelProgramCache::getPlugin() {
17+
return MParentContext->getPlugin();
4218
}
4319
} // namespace detail
4420
} // namespace _V1

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 144 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,16 @@ class KernelProgramCache {
4343
};
4444

4545
/// Denotes the state of a build.
46-
enum BuildState { BS_InProgress, BS_Done, BS_Failed };
46+
enum class BuildState { BS_Initial, BS_InProgress, BS_Done, BS_Failed };
4747

4848
/// Denotes pointer to some entity with its general state and build error.
4949
/// The pointer is not null if and only if the entity is usable.
5050
/// State of the entity is provided by the user of cache instance.
5151
/// Currently there is only a single user - ProgramManager class.
5252
template <typename T> struct BuildResult {
53-
std::atomic<T *> Ptr;
5453
T Val;
55-
std::atomic<BuildState> State;
56-
BuildError Error;
54+
std::atomic<BuildState> State{BuildState::BS_Initial};
55+
BuildError Error{"", 0};
5756

5857
/// Condition variable to signal that build result is ready.
5958
/// A per-object (i.e. kernel or program) condition variable is employed
@@ -69,10 +68,38 @@ class KernelProgramCache {
6968
/// A mutex to be employed along with MBuildCV.
7069
std::mutex MBuildResultMutex;
7170

72-
BuildResult(T *P, BuildState S) : Ptr{P}, State{S}, Error{"", 0} {}
71+
BuildState
72+
waitUntilTransition(BuildState From = BuildState::BS_InProgress) {
73+
BuildState To;
74+
std::unique_lock Lock(MBuildResultMutex);
75+
MBuildCV.wait(Lock, [&] {
76+
To = State;
77+
return State != From;
78+
});
79+
return To;
80+
}
81+
82+
void updateAndNotify(BuildState DesiredState) {
83+
{
84+
std::lock_guard<std::mutex> Lock(MBuildResultMutex);
85+
State.store(DesiredState);
86+
}
87+
MBuildCV.notify_all();
88+
}
89+
};
90+
91+
struct ProgramBuildResult : public BuildResult<sycl::detail::pi::PiProgram> {
92+
PluginPtr Plugin;
93+
ProgramBuildResult(const PluginPtr &Plugin) : Plugin(Plugin) {
94+
Val = nullptr;
95+
}
96+
~ProgramBuildResult() {
97+
if (Val)
98+
Plugin->call<PiApiKind::piProgramRelease>(Val);
99+
}
73100
};
101+
using ProgramBuildResultPtr = std::shared_ptr<ProgramBuildResult>;
74102

75-
using ProgramWithBuildStateT = BuildResult<sycl::detail::pi::PiProgram>;
76103
/* Drop LinkOptions and CompileOptions from CacheKey since they are only used
77104
* when debugging environment variables are set and we can just ignore them
78105
* since all kernels will have their build options overridden with the same
@@ -83,7 +110,7 @@ class KernelProgramCache {
83110
std::pair<std::uintptr_t, sycl::detail::pi::PiDevice>;
84111

85112
struct ProgramCache {
86-
::boost::unordered_map<ProgramCacheKeyT, ProgramWithBuildStateT> Cache;
113+
::boost::unordered_map<ProgramCacheKeyT, ProgramBuildResultPtr> Cache;
87114
::boost::unordered_multimap<CommonProgramKeyT, ProgramCacheKeyT> KeyMap;
88115

89116
size_t size() const noexcept { return Cache.size(); }
@@ -93,8 +120,20 @@ class KernelProgramCache {
93120

94121
using KernelArgMaskPairT =
95122
std::pair<sycl::detail::pi::PiKernel, const KernelArgMask *>;
123+
struct KernelBuildResult : public BuildResult<KernelArgMaskPairT> {
124+
PluginPtr Plugin;
125+
KernelBuildResult(const PluginPtr &Plugin) : Plugin(Plugin) {
126+
Val.first = nullptr;
127+
}
128+
~KernelBuildResult() {
129+
if (Val.first)
130+
Plugin->call<PiApiKind::piKernelRelease>(Val.first);
131+
}
132+
};
133+
using KernelBuildResultPtr = std::shared_ptr<KernelBuildResult>;
134+
96135
using KernelByNameT =
97-
::boost::unordered_map<std::string, BuildResult<KernelArgMaskPairT>>;
136+
::boost::unordered_map<std::string, KernelBuildResultPtr>;
98137
using KernelCacheT =
99138
::boost::unordered_map<sycl::detail::pi::PiProgram, KernelByNameT>;
100139

@@ -112,7 +151,7 @@ class KernelProgramCache {
112151
using KernelFastCacheT =
113152
::boost::unordered_flat_map<KernelFastCacheKeyT, KernelFastCacheValT>;
114153

115-
~KernelProgramCache();
154+
~KernelProgramCache() = default;
116155

117156
void setContextPtr(const ContextPtr &AContext) { MParentContext = AContext; }
118157

@@ -124,61 +163,30 @@ class KernelProgramCache {
124163
return {MKernelsPerProgramCache, MKernelsPerProgramCacheMutex};
125164
}
126165

127-
std::pair<ProgramWithBuildStateT *, bool>
166+
std::pair<ProgramBuildResultPtr, bool>
128167
getOrInsertProgram(const ProgramCacheKeyT &CacheKey) {
129168
auto LockedCache = acquireCachedPrograms();
130169
auto &ProgCache = LockedCache.get();
131-
auto Inserted = ProgCache.Cache.emplace(
132-
std::piecewise_construct, std::forward_as_tuple(CacheKey),
133-
std::forward_as_tuple(nullptr, BS_InProgress));
134-
if (Inserted.second) {
170+
auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr);
171+
if (DidInsert) {
172+
It->second = std::make_shared<ProgramBuildResult>(getPlugin());
135173
// Save reference between the common key and the full key.
136174
CommonProgramKeyT CommonKey =
137175
std::make_pair(CacheKey.first.second, CacheKey.second);
138-
ProgCache.KeyMap.emplace(std::piecewise_construct,
139-
std::forward_as_tuple(CommonKey),
140-
std::forward_as_tuple(CacheKey));
176+
ProgCache.KeyMap.emplace(CommonKey, CacheKey);
141177
}
142-
return std::make_pair(&Inserted.first->second, Inserted.second);
178+
return std::make_pair(It->second, DidInsert);
143179
}
144180

145-
std::pair<BuildResult<KernelArgMaskPairT> *, bool>
181+
std::pair<KernelBuildResultPtr, bool>
146182
getOrInsertKernel(sycl::detail::pi::PiProgram Program,
147183
const std::string &KernelName) {
148184
auto LockedCache = acquireKernelsPerProgramCache();
149185
auto &Cache = LockedCache.get()[Program];
150-
auto Inserted = Cache.emplace(
151-
std::piecewise_construct, std::forward_as_tuple(KernelName),
152-
std::forward_as_tuple(nullptr, BS_InProgress));
153-
return std::make_pair(&Inserted.first->second, Inserted.second);
154-
}
155-
156-
template <typename T, class Predicate>
157-
void waitUntilBuilt(BuildResult<T> &BR, Predicate Pred) const {
158-
std::unique_lock<std::mutex> Lock(BR.MBuildResultMutex);
159-
160-
BR.MBuildCV.wait(Lock, Pred);
161-
}
162-
163-
template <typename ExceptionT, typename RetT>
164-
RetT *waitUntilBuilt(BuildResult<RetT> *BuildResult) {
165-
// Any thread which will find nullptr in cache will wait until the pointer
166-
// is not null anymore.
167-
waitUntilBuilt(*BuildResult, [BuildResult]() {
168-
int State = BuildResult->State.load();
169-
return State == BuildState::BS_Done || State == BuildState::BS_Failed;
170-
});
171-
172-
if (BuildResult->Error.isFilledIn()) {
173-
const BuildError &Error = BuildResult->Error;
174-
throw ExceptionT(Error.Msg, Error.Code);
175-
}
176-
177-
return BuildResult->Ptr.load();
178-
}
179-
180-
template <typename T> void notifyAllBuild(BuildResult<T> &BR) const {
181-
BR.MBuildCV.notify_all();
186+
auto [It, DidInsert] = Cache.try_emplace(KernelName, nullptr);
187+
if (DidInsert)
188+
It->second = std::make_shared<KernelBuildResult>(getPlugin());
189+
return std::make_pair(It->second, DidInsert);
182190
}
183191

184192
template <typename KeyT>
@@ -203,11 +211,94 @@ class KernelProgramCache {
203211
///
204212
/// This member function should only be used in unit tests.
205213
void reset() {
214+
std::lock_guard<std::mutex> L1(MProgramCacheMutex);
215+
std::lock_guard<std::mutex> L2(MKernelsPerProgramCacheMutex);
216+
std::lock_guard<std::mutex> L3(MKernelFastCacheMutex);
206217
MCachedPrograms = ProgramCache{};
207218
MKernelsPerProgramCache = KernelCacheT{};
208219
MKernelFastCache = KernelFastCacheT{};
209220
}
210221

222+
/// Try to fetch entity (kernel or program) from cache. If there is no such
223+
/// entity try to build it. Throw any exception build process may throw.
224+
/// This method eliminates unwanted builds by employing atomic variable with
225+
/// build state and waiting until the entity is built in another thread.
226+
/// If the building thread has failed the awaiting thread will fail either.
227+
/// Exception thrown by build procedure are rethrown.
228+
///
229+
/// \tparam RetT type of entity to get
230+
/// \tparam ExceptionT type of exception to throw on awaiting thread if the
231+
/// building thread fails build step.
232+
/// \tparam KeyT key (in cache) to fetch built entity with
233+
/// \tparam AcquireFT type of function which will acquire the locked version
234+
/// of
235+
/// the cache. Accept reference to KernelProgramCache.
236+
/// \tparam GetCacheFT type of function which will fetch proper cache from
237+
/// locked version. Accepts reference to locked version of cache.
238+
/// \tparam BuildFT type of function which will build the entity if it is not
239+
/// in
240+
/// cache. Accepts nothing. Return pointer to built entity.
241+
///
242+
/// \return a pointer to cached build result, return value must not be
243+
/// nullptr.
244+
template <typename ExceptionT, typename GetCachedBuildFT, typename BuildFT>
245+
auto getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build) {
246+
using BuildState = KernelProgramCache::BuildState;
247+
constexpr size_t MaxAttempts = 2;
248+
for (size_t AttemptCounter = 0;; ++AttemptCounter) {
249+
auto Res = GetCachedBuild();
250+
auto &BuildResult = Res.first;
251+
BuildState Expected = BuildState::BS_Initial;
252+
BuildState Desired = BuildState::BS_InProgress;
253+
if (!BuildResult->State.compare_exchange_strong(Expected, Desired)) {
254+
// no insertion took place, thus some other thread has already inserted
255+
// smth in the cache
256+
BuildState NewState = BuildResult->waitUntilTransition();
257+
258+
// Build succeeded.
259+
if (NewState == BuildState::BS_Done)
260+
return BuildResult;
261+
262+
// Build failed, or this is the last attempt.
263+
if (NewState == BuildState::BS_Failed ||
264+
AttemptCounter + 1 == MaxAttempts) {
265+
if (BuildResult->Error.isFilledIn())
266+
throw ExceptionT(BuildResult->Error.Msg, BuildResult->Error.Code);
267+
else
268+
throw exception();
269+
}
270+
271+
// NewState == BuildState::BS_Initial
272+
// Build state was set back to the initial state,
273+
// which means to go back to the beginning of the
274+
// loop and try again.
275+
continue;
276+
}
277+
278+
// only the building thread will run this
279+
try {
280+
BuildResult->Val = Build();
281+
282+
BuildResult->updateAndNotify(BuildState::BS_Done);
283+
return BuildResult;
284+
} catch (const exception &Ex) {
285+
BuildResult->Error.Msg = Ex.what();
286+
BuildResult->Error.Code = Ex.get_cl_code();
287+
if (BuildResult->Error.Code == PI_ERROR_OUT_OF_RESOURCES) {
288+
reset();
289+
BuildResult->updateAndNotify(BuildState::BS_Initial);
290+
continue;
291+
}
292+
293+
BuildResult->updateAndNotify(BuildState::BS_Failed);
294+
std::rethrow_exception(std::current_exception());
295+
} catch (...) {
296+
BuildResult->updateAndNotify(BuildState::BS_Initial);
297+
std::rethrow_exception(std::current_exception());
298+
}
299+
}
300+
}
301+
211302
private:
212303
std::mutex MProgramCacheMutex;
213304
std::mutex MKernelsPerProgramCacheMutex;
@@ -219,6 +310,8 @@ class KernelProgramCache {
219310
std::mutex MKernelFastCacheMutex;
220311
KernelFastCacheT MKernelFastCache;
221312
friend class ::MockKernelProgramCache;
313+
314+
const PluginPtr &getPlugin();
222315
};
223316
} // namespace detail
224317
} // namespace _V1

0 commit comments

Comments
 (0)