Skip to content

Commit 2032ce4

Browse files
[SYCL] Refactor program build and cache (#8104)
In order to help device_global find associated programs in the cache, the way the program manager and the cache interact needs refactoring. This commit performs that refactoring. Additionally, this adds the notion of a "common key" (CommonProgramKeyT) representing the common parts of the full cache key shared between builds of the same program on the same device, excluding specializations such as specialization constants and compilation options. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 38f28a2 commit 2032ce4

File tree

4 files changed

+114
-104
lines changed

4 files changed

+114
-104
lines changed

sycl/source/detail/kernel_program_cache.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace sycl {
1414
__SYCL_INLINE_VER_NAMESPACE(_V1) {
1515
namespace detail {
1616
KernelProgramCache::~KernelProgramCache() {
17-
for (auto &ProgIt : MCachedPrograms) {
17+
for (auto &ProgIt : MCachedPrograms.Cache) {
1818
ProgramWithBuildStateT &ProgWithState = ProgIt.second;
1919
PiProgramT *ToBeDeleted = ProgWithState.Ptr.load();
2020

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,16 @@ class KernelProgramCache {
3939
bool isFilledIn() const { return !Msg.empty(); }
4040
};
4141

42+
/// Denotes the state of a build.
43+
/// BS_InProgress is the state given to a cache entry when it is first
/// inserted; BS_Done and BS_Failed are the two states that end a wait in
/// waitUntilBuilt.
enum BuildState { BS_InProgress, BS_Done, BS_Failed };
44+
4245
/// Denotes pointer to some entity with its general state and build error.
4346
/// The pointer is not null if and only if the entity is usable.
4447
/// State of the entity is provided by the user of cache instance.
4548
/// Currently there is only a single user - ProgramManager class.
4649
template <typename T> struct BuildResult {
4750
std::atomic<T *> Ptr;
48-
std::atomic<int> State;
51+
std::atomic<BuildState> State;
4952
BuildError Error;
5053

5154
/// Condition variable to signal that build result is ready.
@@ -62,15 +65,23 @@ class KernelProgramCache {
6265
/// A mutex to be employed along with MBuildCV.
6366
std::mutex MBuildResultMutex;
6467

65-
BuildResult(T *P, int S) : Ptr{P}, State{S}, Error{"", 0} {}
68+
BuildResult(T *P, BuildState S) : Ptr{P}, State{S}, Error{"", 0} {}
6669
};
6770

6871
using PiProgramT = std::remove_pointer<RT::PiProgram>::type;
6972
using PiProgramPtrT = std::atomic<PiProgramT *>;
7073
using ProgramWithBuildStateT = BuildResult<PiProgramT>;
7174
using ProgramCacheKeyT = std::pair<std::pair<SerializedObj, std::uintptr_t>,
7275
std::pair<RT::PiDevice, std::string>>;
73-
using ProgramCacheT = std::map<ProgramCacheKeyT, ProgramWithBuildStateT>;
76+
using CommonProgramKeyT = std::pair<std::uintptr_t, RT::PiDevice>;
77+
78+
struct ProgramCache {
79+
std::map<ProgramCacheKeyT, ProgramWithBuildStateT> Cache;
80+
std::multimap<CommonProgramKeyT, ProgramCacheKeyT> KeyMap;
81+
82+
size_t size() const noexcept { return Cache.size(); }
83+
};
84+
7485
using ContextPtr = context_impl *;
7586

7687
using PiKernelT = std::remove_pointer<RT::PiKernel>::type;
@@ -91,21 +102,66 @@ class KernelProgramCache {
91102

92103
void setContextPtr(const ContextPtr &AContext) { MParentContext = AContext; }
93104

94-
Locked<ProgramCacheT> acquireCachedPrograms() {
105+
/// Returns the program cache wrapped in a Locked<> object built from the
/// cache and MProgramCacheMutex.
// NOTE(review): Locked is declared outside this view; presumably it keeps the
// mutex held for as long as the returned object lives - confirm.
Locked<ProgramCache> acquireCachedPrograms() {
  return Locked<ProgramCache>{MCachedPrograms, MProgramCacheMutex};
}
97108

98109
/// Returns the per-program kernel cache wrapped in a Locked<> object built
/// from the cache and MKernelsPerProgramCacheMutex.
// NOTE(review): Locked is declared outside this view; presumably it keeps the
// mutex held for as long as the returned object lives - confirm.
Locked<KernelCacheT> acquireKernelsPerProgramCache() {
  return Locked<KernelCacheT>{MKernelsPerProgramCache,
                              MKernelsPerProgramCacheMutex};
}
101112

113+
std::pair<ProgramWithBuildStateT *, bool>
114+
getOrInsertProgram(const ProgramCacheKeyT &CacheKey) {
115+
auto LockedCache = acquireCachedPrograms();
116+
auto &ProgCache = LockedCache.get();
117+
auto Inserted = ProgCache.Cache.emplace(
118+
std::piecewise_construct, std::forward_as_tuple(CacheKey),
119+
std::forward_as_tuple(nullptr, BS_InProgress));
120+
if (Inserted.second) {
121+
// Save reference between the common key and the full key.
122+
CommonProgramKeyT CommonKey =
123+
std::make_pair(CacheKey.first.second, CacheKey.second.first);
124+
ProgCache.KeyMap.emplace(std::piecewise_construct,
125+
std::forward_as_tuple(CommonKey),
126+
std::forward_as_tuple(CacheKey));
127+
}
128+
return std::make_pair(&Inserted.first->second, Inserted.second);
129+
}
130+
131+
std::pair<KernelWithBuildStateT *, bool>
132+
getOrInsertKernel(RT::PiProgram Program, const std::string &KernelName) {
133+
auto LockedCache = acquireKernelsPerProgramCache();
134+
auto &Cache = LockedCache.get()[Program];
135+
auto Inserted = Cache.emplace(
136+
std::piecewise_construct, std::forward_as_tuple(KernelName),
137+
std::forward_as_tuple(nullptr, BS_InProgress));
138+
return std::make_pair(&Inserted.first->second, Inserted.second);
139+
}
140+
102141
template <typename T, class Predicate>
103142
void waitUntilBuilt(BuildResult<T> &BR, Predicate Pred) const {
104143
std::unique_lock<std::mutex> Lock(BR.MBuildResultMutex);
105144

106145
BR.MBuildCV.wait(Lock, Pred);
107146
}
108147

148+
template <typename ExceptionT, typename RetT>
149+
RetT *waitUntilBuilt(BuildResult<RetT> *BuildResult) {
150+
// Any thread which will find nullptr in cache will wait until the pointer
151+
// is not null anymore.
152+
waitUntilBuilt(*BuildResult, [BuildResult]() {
153+
int State = BuildResult->State.load();
154+
return State == BuildState::BS_Done || State == BuildState::BS_Failed;
155+
});
156+
157+
if (BuildResult->Error.isFilledIn()) {
158+
const BuildError &Error = BuildResult->Error;
159+
throw ExceptionT(Error.Msg, Error.Code);
160+
}
161+
162+
return BuildResult->Ptr.load();
163+
}
164+
109165
/// Calls notify_all() on the condition variable of \p BR, waking the threads
/// that are waiting on it.
template <typename T>
void notifyAllBuild(BuildResult<T> &BR) const {
  auto &BuildCV = BR.MBuildCV;
  BuildCV.notify_all();
}
@@ -132,7 +188,7 @@ class KernelProgramCache {
132188
///
133189
/// This member function should only be used in unit tests.
134190
void reset() {
135-
MCachedPrograms = ProgramCacheT{};
191+
MCachedPrograms = ProgramCache{};
136192
MKernelsPerProgramCache = KernelCacheT{};
137193
MKernelFastCache = KernelFastCacheT{};
138194
}
@@ -141,7 +197,7 @@ class KernelProgramCache {
141197
std::mutex MProgramCacheMutex;
142198
std::mutex MKernelsPerProgramCacheMutex;
143199

144-
ProgramCacheT MCachedPrograms;
200+
ProgramCache MCachedPrograms;
145201
KernelCacheT MKernelsPerProgramCache;
146202
ContextPtr MParentContext;
147203

0 commit comments

Comments
 (0)