Skip to content

Commit 9dce9fb

Browse files
committed
Use shared_ptr approach
1 parent dfd64b6 commit 9dce9fb

File tree

7 files changed

+312
-190
lines changed

7 files changed

+312
-190
lines changed

sycl/source/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
set(CMAKE_BUILD_TYPE Debug)
12
#To-Do:
23
#1. Figure out why CMP0057 has to be set. Should have been taken care of earlier in the build
34
#2. Use AddLLVM to modify the build and access config options

sycl/source/detail/context_impl.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ std::optional<sycl::detail::pi::PiProgram> context_impl::getProgramForDevImgs(
447447
const device &Device, const std::set<std::uintptr_t> &ImgIdentifiers,
448448
const std::string &ObjectTypeName) {
449449

450-
KernelProgramCache::ProgramWithBuildStateT *BuildRes = nullptr;
450+
KernelProgramCache::ProgramBuildResultPtr BuildRes = nullptr;
451451
{
452452
auto LockedCache = MKernelProgramCache.acquireCachedPrograms();
453453
auto &KeyMap = LockedCache.get().KeyMap;
@@ -471,12 +471,18 @@ std::optional<sycl::detail::pi::PiProgram> context_impl::getProgramForDevImgs(
471471
assert(KeyMappingsIt != KeyMap.end());
472472
auto CachedProgIt = Cache.find(KeyMappingsIt->second);
473473
assert(CachedProgIt != Cache.end());
474-
BuildRes = &CachedProgIt->second;
474+
BuildRes = CachedProgIt->second;
475475
}
476476
}
477477
if (!BuildRes)
478478
return std::nullopt;
479-
return *MKernelProgramCache.waitUntilBuilt<compile_program_error>(BuildRes);
479+
using BuildState = KernelProgramCache::BuildState;
480+
BuildState NewState = BuildRes->waitUntilTransition();
481+
if (NewState == BuildState::BS_Failed)
482+
throw compile_program_error(BuildRes->Error.Msg, BuildRes->Error.Code);
483+
484+
assert(NewState == BuildState::BS_Done);
485+
return BuildRes->Val;
480486
}
481487

482488
std::optional<sycl::detail::pi::PiProgram>

sycl/source/detail/kernel_program_cache.cpp

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,32 +13,8 @@
1313
namespace sycl {
1414
inline namespace _V1 {
1515
namespace detail {
16-
KernelProgramCache::~KernelProgramCache() {
17-
for (auto &ProgIt : MCachedPrograms.Cache) {
18-
ProgramWithBuildStateT &ProgWithState = ProgIt.second;
19-
sycl::detail::pi::PiProgram *ToBeDeleted = ProgWithState.Ptr.load();
20-
21-
if (!ToBeDeleted)
22-
continue;
23-
24-
auto KernIt = MKernelsPerProgramCache.find(*ToBeDeleted);
25-
26-
if (KernIt != MKernelsPerProgramCache.end()) {
27-
for (auto &p : KernIt->second) {
28-
BuildResult<KernelArgMaskPairT> &KernelWithState = p.second;
29-
KernelArgMaskPairT *KernelArgMaskPair = KernelWithState.Ptr.load();
30-
31-
if (KernelArgMaskPair) {
32-
const PluginPtr &Plugin = MParentContext->getPlugin();
33-
Plugin->call<PiApiKind::piKernelRelease>(KernelArgMaskPair->first);
34-
}
35-
}
36-
MKernelsPerProgramCache.erase(KernIt);
37-
}
38-
39-
const PluginPtr &Plugin = MParentContext->getPlugin();
40-
Plugin->call<PiApiKind::piProgramRelease>(*ToBeDeleted);
41-
}
16+
const PluginPtr &KernelProgramCache::getPlugin() {
17+
return MParentContext->getPlugin();
4218
}
4319
} // namespace detail
4420
} // namespace _V1

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 143 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class KernelProgramCache {
4343
};
4444

4545
/// Denotes the state of a build.
46-
enum BuildState { BS_InProgress, BS_Done, BS_Failed };
46+
enum class BuildState { BS_Initial, BS_InProgress, BS_Done, BS_Failed };
4747

4848
/// Denotes pointer to some entity with its general state and build error.
4949
/// The pointer is not null if and only if the entity is usable.
@@ -52,8 +52,8 @@ class KernelProgramCache {
5252
template <typename T> struct BuildResult {
5353
std::atomic<T *> Ptr;
5454
T Val;
55-
std::atomic<BuildState> State;
56-
BuildError Error;
55+
std::atomic<BuildState> State{BuildState::BS_Initial};
56+
BuildError Error{"", 0};
5757

5858
/// Condition variable to signal that build result is ready.
5959
/// A per-object (i.e. kernel or program) condition variable is employed
@@ -69,10 +69,38 @@ class KernelProgramCache {
6969
/// A mutex to be employed along with MBuildCV.
7070
std::mutex MBuildResultMutex;
7171

72-
BuildResult(T *P, BuildState S) : Ptr{P}, State{S}, Error{"", 0} {}
72+
BuildState
73+
waitUntilTransition(BuildState From = BuildState::BS_InProgress) {
74+
BuildState To;
75+
std::unique_lock Lock(MBuildResultMutex);
76+
MBuildCV.wait(Lock, [&] {
77+
To = State;
78+
return State != From;
79+
});
80+
return To;
81+
}
82+
83+
void updateAndNotify(BuildState DesiredState) {
84+
{
85+
std::lock_guard<std::mutex> Lock(MBuildResultMutex);
86+
State.store(DesiredState);
87+
}
88+
MBuildCV.notify_all();
89+
}
90+
};
91+
92+
struct ProgramBuildResult : public BuildResult<sycl::detail::pi::PiProgram> {
93+
PluginPtr Plugin;
94+
ProgramBuildResult(const PluginPtr &Plugin) : Plugin(Plugin) {
95+
Val = nullptr;
96+
}
97+
~ProgramBuildResult() {
98+
if (Val)
99+
Plugin->call<PiApiKind::piProgramRelease>(Val);
100+
}
73101
};
102+
using ProgramBuildResultPtr = std::shared_ptr<ProgramBuildResult>;
74103

75-
using ProgramWithBuildStateT = BuildResult<sycl::detail::pi::PiProgram>;
76104
/* Drop LinkOptions and CompileOptions from CacheKey since they are only used
77105
* when debugging environment variables are set and we can just ignore them
78106
* since all kernels will have their build options overridden with the same
@@ -83,7 +111,7 @@ class KernelProgramCache {
83111
std::pair<std::uintptr_t, sycl::detail::pi::PiDevice>;
84112

85113
struct ProgramCache {
86-
::boost::unordered_map<ProgramCacheKeyT, ProgramWithBuildStateT> Cache;
114+
::boost::unordered_map<ProgramCacheKeyT, ProgramBuildResultPtr> Cache;
87115
::boost::unordered_multimap<CommonProgramKeyT, ProgramCacheKeyT> KeyMap;
88116

89117
size_t size() const noexcept { return Cache.size(); }
@@ -93,8 +121,20 @@ class KernelProgramCache {
93121

94122
using KernelArgMaskPairT =
95123
std::pair<sycl::detail::pi::PiKernel, const KernelArgMask *>;
124+
struct KernelBuildResult : public BuildResult<KernelArgMaskPairT> {
125+
PluginPtr Plugin;
126+
KernelBuildResult(const PluginPtr &Plugin) : Plugin(Plugin) {
127+
Val.first = nullptr;
128+
}
129+
~KernelBuildResult() {
130+
if (Val.first)
131+
Plugin->call<PiApiKind::piKernelRelease>(Val.first);
132+
}
133+
};
134+
using KernelBuildResultPtr = std::shared_ptr<KernelBuildResult>;
135+
96136
using KernelByNameT =
97-
::boost::unordered_map<std::string, BuildResult<KernelArgMaskPairT>>;
137+
::boost::unordered_map<std::string, KernelBuildResultPtr>;
98138
using KernelCacheT =
99139
::boost::unordered_map<sycl::detail::pi::PiProgram, KernelByNameT>;
100140

@@ -112,7 +152,7 @@ class KernelProgramCache {
112152
using KernelFastCacheT =
113153
::boost::unordered_flat_map<KernelFastCacheKeyT, KernelFastCacheValT>;
114154

115-
~KernelProgramCache();
155+
~KernelProgramCache() = default;
116156

117157
void setContextPtr(const ContextPtr &AContext) { MParentContext = AContext; }
118158

@@ -124,57 +164,30 @@ class KernelProgramCache {
124164
return {MKernelsPerProgramCache, MKernelsPerProgramCacheMutex};
125165
}
126166

127-
std::pair<ProgramWithBuildStateT *, bool>
167+
std::pair<ProgramBuildResultPtr, bool>
128168
getOrInsertProgram(const ProgramCacheKeyT &CacheKey) {
129169
auto LockedCache = acquireCachedPrograms();
130170
auto &ProgCache = LockedCache.get();
131-
auto Inserted = ProgCache.Cache.emplace(
132-
std::piecewise_construct, std::forward_as_tuple(CacheKey),
133-
std::forward_as_tuple(nullptr, BS_InProgress));
134-
if (Inserted.second) {
171+
auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr);
172+
if (DidInsert) {
173+
It->second = std::make_shared<ProgramBuildResult>(getPlugin());
135174
// Save reference between the common key and the full key.
136175
CommonProgramKeyT CommonKey =
137176
std::make_pair(CacheKey.first.second, CacheKey.second);
138-
ProgCache.KeyMap.emplace(std::piecewise_construct,
139-
std::forward_as_tuple(CommonKey),
140-
std::forward_as_tuple(CacheKey));
177+
ProgCache.KeyMap.emplace(CommonKey, CacheKey);
141178
}
142-
return std::make_pair(&Inserted.first->second, Inserted.second);
179+
return std::make_pair(It->second, DidInsert);
143180
}
144181

145-
std::pair<BuildResult<KernelArgMaskPairT> *, bool>
182+
std::pair<KernelBuildResultPtr, bool>
146183
getOrInsertKernel(sycl::detail::pi::PiProgram Program,
147184
const std::string &KernelName) {
148185
auto LockedCache = acquireKernelsPerProgramCache();
149186
auto &Cache = LockedCache.get()[Program];
150-
auto Inserted = Cache.emplace(
151-
std::piecewise_construct, std::forward_as_tuple(KernelName),
152-
std::forward_as_tuple(nullptr, BS_InProgress));
153-
return std::make_pair(&Inserted.first->second, Inserted.second);
154-
}
155-
156-
template <typename T, class Predicate>
157-
void waitUntilBuilt(BuildResult<T> &BR, Predicate Pred) const {
158-
std::unique_lock<std::mutex> Lock(BR.MBuildResultMutex);
159-
160-
BR.MBuildCV.wait(Lock, Pred);
161-
}
162-
163-
template <typename ExceptionT, typename RetT>
164-
RetT *waitUntilBuilt(BuildResult<RetT> *BuildResult) {
165-
// Any thread which will find nullptr in cache will wait until the pointer
166-
// is not null anymore.
167-
waitUntilBuilt(*BuildResult, [BuildResult]() {
168-
int State = BuildResult->State.load();
169-
return State == BuildState::BS_Done || State == BuildState::BS_Failed;
170-
});
171-
172-
if (BuildResult->Error.isFilledIn()) {
173-
const BuildError &Error = BuildResult->Error;
174-
throw ExceptionT(Error.Msg, Error.Code);
175-
}
176-
177-
return BuildResult->Ptr.load();
187+
auto [It, DidInsert] = Cache.try_emplace(KernelName, nullptr);
188+
if (DidInsert)
189+
It->second = std::make_shared<KernelBuildResult>(getPlugin());
190+
return std::make_pair(It->second, DidInsert);
178191
}
179192

180193
template <typename T> void notifyAllBuild(BuildResult<T> &BR) const {
@@ -208,6 +221,88 @@ class KernelProgramCache {
208221
MKernelFastCache = KernelFastCacheT{};
209222
}
210223

224+
/// Try to fetch entity (kernel or program) from cache. If there is no such
225+
/// entity try to build it. Throw any exception build process may throw.
226+
/// This method eliminates unwanted builds by employing atomic variable with
227+
/// build state and waiting until the entity is built in another thread.
228+
/// If the building thread has failed, the awaiting thread will fail as well.
229+
/// Exceptions thrown by the build procedure are rethrown.
230+
///
231+
/// \tparam RetT type of entity to get
232+
/// \tparam ExceptionT type of exception to throw on awaiting thread if the
233+
/// building thread fails build step.
234+
/// \tparam KeyT key (in cache) to fetch built entity with
235+
/// \tparam AcquireFT type of function which will acquire the locked version
236+
/// of
237+
/// the cache. Accept reference to KernelProgramCache.
238+
/// \tparam GetCacheFT type of function which will fetch proper cache from
239+
/// locked version. Accepts reference to locked version of cache.
240+
/// \tparam BuildFT type of function which will build the entity if it is not
241+
/// in
242+
/// cache. Accepts nothing. Return pointer to built entity.
243+
///
244+
/// \return a pointer to cached build result, return value must not be
245+
/// nullptr.
246+
template <typename ExceptionT, typename GetCachedBuildFT, typename BuildFT>
247+
auto getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build) {
248+
using BuildState = KernelProgramCache::BuildState;
249+
constexpr size_t MaxAttempts = 2;
250+
for (size_t AttemptCounter = 0;; ++AttemptCounter) {
251+
auto [BuildResult, InsertionTookPlace] = GetCachedBuild();
252+
BuildState Expected = BuildState::BS_Initial;
253+
BuildState Desired = BuildState::BS_InProgress;
254+
if (!BuildResult->State.compare_exchange_strong(Expected, Desired)) {
255+
// the state was not BS_Initial, thus this build has already been claimed
256+
// by an earlier caller; wait for its outcome
257+
BuildState NewState = BuildResult->waitUntilTransition();
258+
259+
// Build succeeded.
260+
if (NewState == BuildState::BS_Done)
261+
return BuildResult;
262+
263+
// Build failed, or this is the last attempt.
264+
if (NewState == BuildState::BS_Failed ||
265+
AttemptCounter + 1 == MaxAttempts) {
266+
if (BuildResult->Error.isFilledIn())
267+
throw ExceptionT(BuildResult->Error.Msg, BuildResult->Error.Code);
268+
else
269+
throw exception();
270+
}
271+
272+
// NewState == BuildState::BS_Initial
273+
// Build state was set back to the initial state,
274+
// which means to go back to the beginning of the
275+
// loop and try again.
276+
continue;
277+
}
278+
279+
// only the building thread will run this
280+
try {
281+
BuildResult->Val = Build();
282+
283+
BuildResult->updateAndNotify(BuildState::BS_Done);
284+
return BuildResult;
285+
} catch (const exception &Ex) {
286+
BuildResult->Error.Msg = Ex.what();
287+
BuildResult->Error.Code = Ex.get_cl_code();
288+
if (BuildResult->Error.Code == PI_ERROR_OUT_OF_RESOURCES) {
289+
std::lock_guard<std::mutex> L1(MProgramCacheMutex);
290+
std::lock_guard<std::mutex> L2(MKernelsPerProgramCacheMutex);
291+
std::lock_guard<std::mutex> L3(MKernelFastCacheMutex);
292+
reset();
293+
BuildResult->updateAndNotify(BuildState::BS_Initial);
294+
continue;
295+
}
296+
297+
BuildResult->updateAndNotify(BuildState::BS_Failed);
298+
std::rethrow_exception(std::current_exception());
299+
} catch (...) {
300+
BuildResult->updateAndNotify(BuildState::BS_Initial);
301+
std::rethrow_exception(std::current_exception());
302+
}
303+
}
304+
}
305+
211306
private:
212307
std::mutex MProgramCacheMutex;
213308
std::mutex MKernelsPerProgramCacheMutex;
@@ -219,6 +314,8 @@ class KernelProgramCache {
219314
std::mutex MKernelFastCacheMutex;
220315
KernelFastCacheT MKernelFastCache;
221316
friend class ::MockKernelProgramCache;
317+
318+
const PluginPtr &getPlugin();
222319
};
223320
} // namespace detail
224321
} // namespace _V1

0 commit comments

Comments
 (0)