@@ -43,7 +43,7 @@ class KernelProgramCache {
43
43
};
44
44
45
45
// / Denotes the state of a build.
46
- enum BuildState { BS_InProgress, BS_Done, BS_Failed };
46
+ enum class BuildState { BS_Initial, BS_InProgress, BS_Done, BS_Failed };
47
47
48
48
// / Denotes pointer to some entity with its general state and build error.
49
49
// / The pointer is not null if and only if the entity is usable.
@@ -52,8 +52,8 @@ class KernelProgramCache {
52
52
template <typename T> struct BuildResult {
53
53
std::atomic<T *> Ptr;
54
54
T Val;
55
- std::atomic<BuildState> State;
56
- BuildError Error;
55
+ std::atomic<BuildState> State{BuildState::BS_Initial} ;
56
+ BuildError Error{ " " , 0 } ;
57
57
58
58
// / Condition variable to signal that build result is ready.
59
59
// / A per-object (i.e. kernel or program) condition variable is employed
@@ -69,10 +69,38 @@ class KernelProgramCache {
69
69
// / A mutex to be employed along with MBuildCV.
70
70
std::mutex MBuildResultMutex;
71
71
72
- BuildResult (T *P, BuildState S) : Ptr{P}, State{S}, Error{" " , 0 } {}
72
+ BuildState
73
+ waitUntilTransition (BuildState From = BuildState::BS_InProgress) {
74
+ BuildState To;
75
+ std::unique_lock Lock (MBuildResultMutex);
76
+ MBuildCV.wait (Lock, [&] {
77
+ To = State;
78
+ return State != From;
79
+ });
80
+ return To;
81
+ }
82
+
83
+ void updateAndNotify (BuildState DesiredState) {
84
+ {
85
+ std::lock_guard<std::mutex> Lock (MBuildResultMutex);
86
+ State.store (DesiredState);
87
+ }
88
+ MBuildCV.notify_all ();
89
+ }
90
+ };
91
+
92
+ struct ProgramBuildResult : public BuildResult <sycl::detail::pi::PiProgram> {
93
+ PluginPtr Plugin;
94
+ ProgramBuildResult (const PluginPtr &Plugin) : Plugin(Plugin) {
95
+ Val = nullptr ;
96
+ }
97
+ ~ProgramBuildResult () {
98
+ if (Val)
99
+ Plugin->call <PiApiKind::piProgramRelease>(Val);
100
+ }
73
101
};
102
+ using ProgramBuildResultPtr = std::shared_ptr<ProgramBuildResult>;
74
103
75
- using ProgramWithBuildStateT = BuildResult<sycl::detail::pi::PiProgram>;
76
104
/* Drop LinkOptions and CompileOptions from CacheKey since they are only used
77
105
* when debugging environment variables are set and we can just ignore them
78
106
* since all kernels will have their build options overridden with the same
@@ -83,7 +111,7 @@ class KernelProgramCache {
83
111
std::pair<std::uintptr_t , sycl::detail::pi::PiDevice>;
84
112
85
113
struct ProgramCache {
86
- ::boost::unordered_map<ProgramCacheKeyT, ProgramWithBuildStateT > Cache;
114
+ ::boost::unordered_map<ProgramCacheKeyT, ProgramBuildResultPtr > Cache;
87
115
::boost::unordered_multimap<CommonProgramKeyT, ProgramCacheKeyT> KeyMap;
88
116
89
117
size_t size () const noexcept { return Cache.size (); }
@@ -93,8 +121,20 @@ class KernelProgramCache {
93
121
94
122
using KernelArgMaskPairT =
95
123
std::pair<sycl::detail::pi::PiKernel, const KernelArgMask *>;
124
+ struct KernelBuildResult : public BuildResult <KernelArgMaskPairT> {
125
+ PluginPtr Plugin;
126
+ KernelBuildResult (const PluginPtr &Plugin) : Plugin(Plugin) {
127
+ Val.first = nullptr ;
128
+ }
129
+ ~KernelBuildResult () {
130
+ if (Val.first )
131
+ Plugin->call <PiApiKind::piKernelRelease>(Val.first );
132
+ }
133
+ };
134
+ using KernelBuildResultPtr = std::shared_ptr<KernelBuildResult>;
135
+
96
136
using KernelByNameT =
97
- ::boost::unordered_map<std::string, BuildResult<KernelArgMaskPairT> >;
137
+ ::boost::unordered_map<std::string, KernelBuildResultPtr >;
98
138
using KernelCacheT =
99
139
::boost::unordered_map<sycl::detail::pi::PiProgram, KernelByNameT>;
100
140
@@ -112,7 +152,7 @@ class KernelProgramCache {
112
152
using KernelFastCacheT =
113
153
::boost::unordered_flat_map<KernelFastCacheKeyT, KernelFastCacheValT>;
114
154
115
- ~KernelProgramCache ();
155
+ ~KernelProgramCache () = default ;
116
156
117
157
void setContextPtr (const ContextPtr &AContext) { MParentContext = AContext; }
118
158
@@ -124,57 +164,30 @@ class KernelProgramCache {
124
164
return {MKernelsPerProgramCache, MKernelsPerProgramCacheMutex};
125
165
}
126
166
127
- std::pair<ProgramWithBuildStateT * , bool >
167
+ std::pair<ProgramBuildResultPtr , bool >
128
168
getOrInsertProgram (const ProgramCacheKeyT &CacheKey) {
129
169
auto LockedCache = acquireCachedPrograms ();
130
170
auto &ProgCache = LockedCache.get ();
131
- auto Inserted = ProgCache.Cache .emplace (
132
- std::piecewise_construct, std::forward_as_tuple (CacheKey),
133
- std::forward_as_tuple (nullptr , BS_InProgress));
134
- if (Inserted.second ) {
171
+ auto [It, DidInsert] = ProgCache.Cache .try_emplace (CacheKey, nullptr );
172
+ if (DidInsert) {
173
+ It->second = std::make_shared<ProgramBuildResult>(getPlugin ());
135
174
// Save reference between the common key and the full key.
136
175
CommonProgramKeyT CommonKey =
137
176
std::make_pair (CacheKey.first .second , CacheKey.second );
138
- ProgCache.KeyMap .emplace (std::piecewise_construct,
139
- std::forward_as_tuple (CommonKey),
140
- std::forward_as_tuple (CacheKey));
177
+ ProgCache.KeyMap .emplace (CommonKey, CacheKey);
141
178
}
142
- return std::make_pair (&Inserted. first ->second , Inserted. second );
179
+ return std::make_pair (It ->second , DidInsert );
143
180
}
144
181
145
- std::pair<BuildResult<KernelArgMaskPairT> * , bool >
182
+ std::pair<KernelBuildResultPtr , bool >
146
183
getOrInsertKernel (sycl::detail::pi::PiProgram Program,
147
184
const std::string &KernelName) {
148
185
auto LockedCache = acquireKernelsPerProgramCache ();
149
186
auto &Cache = LockedCache.get ()[Program];
150
- auto Inserted = Cache.emplace (
151
- std::piecewise_construct, std::forward_as_tuple (KernelName),
152
- std::forward_as_tuple (nullptr , BS_InProgress));
153
- return std::make_pair (&Inserted.first ->second , Inserted.second );
154
- }
155
-
156
- template <typename T, class Predicate >
157
- void waitUntilBuilt (BuildResult<T> &BR, Predicate Pred) const {
158
- std::unique_lock<std::mutex> Lock (BR.MBuildResultMutex );
159
-
160
- BR.MBuildCV .wait (Lock, Pred);
161
- }
162
-
163
- template <typename ExceptionT, typename RetT>
164
- RetT *waitUntilBuilt (BuildResult<RetT> *BuildResult) {
165
- // Any thread which will find nullptr in cache will wait until the pointer
166
- // is not null anymore.
167
- waitUntilBuilt (*BuildResult, [BuildResult]() {
168
- int State = BuildResult->State .load ();
169
- return State == BuildState::BS_Done || State == BuildState::BS_Failed;
170
- });
171
-
172
- if (BuildResult->Error .isFilledIn ()) {
173
- const BuildError &Error = BuildResult->Error ;
174
- throw ExceptionT (Error.Msg , Error.Code );
175
- }
176
-
177
- return BuildResult->Ptr .load ();
187
+ auto [It, DidInsert] = Cache.try_emplace (KernelName, nullptr );
188
+ if (DidInsert)
189
+ It->second = std::make_shared<KernelBuildResult>(getPlugin ());
190
+ return std::make_pair (It->second , DidInsert);
178
191
}
179
192
180
193
template <typename T> void notifyAllBuild (BuildResult<T> &BR) const {
@@ -208,6 +221,88 @@ class KernelProgramCache {
208
221
MKernelFastCache = KernelFastCacheT{};
209
222
}
210
223
224
+ // / Try to fetch entity (kernel or program) from cache. If there is no such
225
+ // / entity try to build it. Throw any exception build process may throw.
226
+ // / This method eliminates unwanted builds by employing atomic variable with
227
+ // / build state and waiting until the entity is built in another thread.
228
+ // / If the building thread has failed the awaiting thread will fail either.
229
+ // / Exception thrown by build procedure are rethrown.
230
+ // /
231
+ // / \tparam RetT type of entity to get
232
+ // / \tparam ExceptionT type of exception to throw on awaiting thread if the
233
+ // / building thread fails build step.
234
+ // / \tparam KeyT key (in cache) to fetch built entity with
235
+ // / \tparam AcquireFT type of function which will acquire the locked version
236
+ // / of
237
+ // / the cache. Accept reference to KernelProgramCache.
238
+ // / \tparam GetCacheFT type of function which will fetch proper cache from
239
+ // / locked version. Accepts reference to locked version of cache.
240
+ // / \tparam BuildFT type of function which will build the entity if it is not
241
+ // / in
242
+ // / cache. Accepts nothing. Return pointer to built entity.
243
+ // /
244
+ // / \return a pointer to cached build result, return value must not be
245
+ // / nullptr.
246
+ template <typename ExceptionT, typename GetCachedBuildFT, typename BuildFT>
247
+ auto getOrBuild (GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build) {
248
+ using BuildState = KernelProgramCache::BuildState;
249
+ constexpr size_t MaxAttempts = 2 ;
250
+ for (size_t AttemptCounter = 0 ;; ++AttemptCounter) {
251
+ auto [BuildResult, InsertionTookPlace] = GetCachedBuild ();
252
+ BuildState Expected = BuildState::BS_Initial;
253
+ BuildState Desired = BuildState::BS_InProgress;
254
+ if (!BuildResult->State .compare_exchange_strong (Expected, Desired)) {
255
+ // no insertion took place, thus some other thread has already inserted
256
+ // smth in the cache
257
+ BuildState NewState = BuildResult->waitUntilTransition ();
258
+
259
+ // Build succeeded.
260
+ if (NewState == BuildState::BS_Done)
261
+ return BuildResult;
262
+
263
+ // Build failed, or this is the last attempt.
264
+ if (NewState == BuildState::BS_Failed ||
265
+ AttemptCounter + 1 == MaxAttempts) {
266
+ if (BuildResult->Error .isFilledIn ())
267
+ throw ExceptionT (BuildResult->Error .Msg , BuildResult->Error .Code );
268
+ else
269
+ throw exception ();
270
+ }
271
+
272
+ // NewState == BuildState::BS_Initial
273
+ // Build state was set back to the initial state,
274
+ // which means to go back to the beginning of the
275
+ // loop and try again.
276
+ continue ;
277
+ }
278
+
279
+ // only the building thread will run this
280
+ try {
281
+ BuildResult->Val = Build ();
282
+
283
+ BuildResult->updateAndNotify (BuildState::BS_Done);
284
+ return BuildResult;
285
+ } catch (const exception &Ex) {
286
+ BuildResult->Error .Msg = Ex.what ();
287
+ BuildResult->Error .Code = Ex.get_cl_code ();
288
+ if (BuildResult->Error .Code == PI_ERROR_OUT_OF_RESOURCES) {
289
+ std::lock_guard<std::mutex> L1 (MProgramCacheMutex);
290
+ std::lock_guard<std::mutex> L2 (MKernelsPerProgramCacheMutex);
291
+ std::lock_guard<std::mutex> L3 (MKernelFastCacheMutex);
292
+ reset ();
293
+ BuildResult->updateAndNotify (BuildState::BS_Initial);
294
+ continue ;
295
+ }
296
+
297
+ BuildResult->updateAndNotify (BuildState::BS_Failed);
298
+ std::rethrow_exception (std::current_exception ());
299
+ } catch (...) {
300
+ BuildResult->updateAndNotify (BuildState::BS_Initial);
301
+ std::rethrow_exception (std::current_exception ());
302
+ }
303
+ }
304
+ }
305
+
211
306
private:
212
307
std::mutex MProgramCacheMutex;
213
308
std::mutex MKernelsPerProgramCacheMutex;
@@ -219,6 +314,8 @@ class KernelProgramCache {
219
314
std::mutex MKernelFastCacheMutex;
220
315
KernelFastCacheT MKernelFastCache;
221
316
friend class ::MockKernelProgramCache;
317
+
318
+ const PluginPtr &getPlugin ();
222
319
};
223
320
} // namespace detail
224
321
} // namespace _V1
0 commit comments