@@ -43,17 +43,16 @@ class KernelProgramCache {
43
43
};
44
44
45
45
// / Denotes the state of a build.
46
- enum BuildState { BS_InProgress, BS_Done, BS_Failed };
46
+ enum class BuildState { BS_Initial, BS_InProgress, BS_Done, BS_Failed };
47
47
48
48
// / Denotes pointer to some entity with its general state and build error.
49
49
// / The pointer is not null if and only if the entity is usable.
50
50
// / State of the entity is provided by the user of cache instance.
51
51
// / Currently there is only a single user - ProgramManager class.
52
52
template <typename T> struct BuildResult {
53
- std::atomic<T *> Ptr;
54
53
T Val;
55
- std::atomic<BuildState> State;
56
- BuildError Error;
54
+ std::atomic<BuildState> State{BuildState::BS_Initial} ;
55
+ BuildError Error{ " " , 0 } ;
57
56
58
57
// / Condition variable to signal that build result is ready.
59
58
// / A per-object (i.e. kernel or program) condition variable is employed
@@ -69,10 +68,38 @@ class KernelProgramCache {
69
68
// / A mutex to be employed along with MBuildCV.
70
69
std::mutex MBuildResultMutex;
71
70
72
- BuildResult (T *P, BuildState S) : Ptr{P}, State{S}, Error{" " , 0 } {}
71
+ BuildState
72
+ waitUntilTransition (BuildState From = BuildState::BS_InProgress) {
73
+ BuildState To;
74
+ std::unique_lock Lock (MBuildResultMutex);
75
+ MBuildCV.wait (Lock, [&] {
76
+ To = State;
77
+ return State != From;
78
+ });
79
+ return To;
80
+ }
81
+
82
+ void updateAndNotify (BuildState DesiredState) {
83
+ {
84
+ std::lock_guard<std::mutex> Lock (MBuildResultMutex);
85
+ State.store (DesiredState);
86
+ }
87
+ MBuildCV.notify_all ();
88
+ }
89
+ };
90
+
91
+ struct ProgramBuildResult : public BuildResult <sycl::detail::pi::PiProgram> {
92
+ PluginPtr Plugin;
93
+ ProgramBuildResult (const PluginPtr &Plugin) : Plugin(Plugin) {
94
+ Val = nullptr ;
95
+ }
96
+ ~ProgramBuildResult () {
97
+ if (Val)
98
+ Plugin->call <PiApiKind::piProgramRelease>(Val);
99
+ }
73
100
};
101
+ using ProgramBuildResultPtr = std::shared_ptr<ProgramBuildResult>;
74
102
75
- using ProgramWithBuildStateT = BuildResult<sycl::detail::pi::PiProgram>;
76
103
/* Drop LinkOptions and CompileOptions from CacheKey since they are only used
77
104
* when debugging environment variables are set and we can just ignore them
78
105
* since all kernels will have their build options overridden with the same
@@ -83,7 +110,7 @@ class KernelProgramCache {
83
110
std::pair<std::uintptr_t , sycl::detail::pi::PiDevice>;
84
111
85
112
struct ProgramCache {
86
- ::boost::unordered_map<ProgramCacheKeyT, ProgramWithBuildStateT > Cache;
113
+ ::boost::unordered_map<ProgramCacheKeyT, ProgramBuildResultPtr > Cache;
87
114
::boost::unordered_multimap<CommonProgramKeyT, ProgramCacheKeyT> KeyMap;
88
115
89
116
size_t size () const noexcept { return Cache.size (); }
@@ -93,8 +120,20 @@ class KernelProgramCache {
93
120
94
121
using KernelArgMaskPairT =
95
122
std::pair<sycl::detail::pi::PiKernel, const KernelArgMask *>;
123
+ struct KernelBuildResult : public BuildResult <KernelArgMaskPairT> {
124
+ PluginPtr Plugin;
125
+ KernelBuildResult (const PluginPtr &Plugin) : Plugin(Plugin) {
126
+ Val.first = nullptr ;
127
+ }
128
+ ~KernelBuildResult () {
129
+ if (Val.first )
130
+ Plugin->call <PiApiKind::piKernelRelease>(Val.first );
131
+ }
132
+ };
133
+ using KernelBuildResultPtr = std::shared_ptr<KernelBuildResult>;
134
+
96
135
using KernelByNameT =
97
- ::boost::unordered_map<std::string, BuildResult<KernelArgMaskPairT> >;
136
+ ::boost::unordered_map<std::string, KernelBuildResultPtr >;
98
137
using KernelCacheT =
99
138
::boost::unordered_map<sycl::detail::pi::PiProgram, KernelByNameT>;
100
139
@@ -112,7 +151,7 @@ class KernelProgramCache {
112
151
using KernelFastCacheT =
113
152
::boost::unordered_flat_map<KernelFastCacheKeyT, KernelFastCacheValT>;
114
153
115
- ~KernelProgramCache ();
154
+ ~KernelProgramCache () = default ;
116
155
117
156
void setContextPtr (const ContextPtr &AContext) { MParentContext = AContext; }
118
157
@@ -124,61 +163,30 @@ class KernelProgramCache {
124
163
return {MKernelsPerProgramCache, MKernelsPerProgramCacheMutex};
125
164
}
126
165
127
- std::pair<ProgramWithBuildStateT * , bool >
166
+ std::pair<ProgramBuildResultPtr , bool >
128
167
getOrInsertProgram (const ProgramCacheKeyT &CacheKey) {
129
168
auto LockedCache = acquireCachedPrograms ();
130
169
auto &ProgCache = LockedCache.get ();
131
- auto Inserted = ProgCache.Cache .emplace (
132
- std::piecewise_construct, std::forward_as_tuple (CacheKey),
133
- std::forward_as_tuple (nullptr , BS_InProgress));
134
- if (Inserted.second ) {
170
+ auto [It, DidInsert] = ProgCache.Cache .try_emplace (CacheKey, nullptr );
171
+ if (DidInsert) {
172
+ It->second = std::make_shared<ProgramBuildResult>(getPlugin ());
135
173
// Save reference between the common key and the full key.
136
174
CommonProgramKeyT CommonKey =
137
175
std::make_pair (CacheKey.first .second , CacheKey.second );
138
- ProgCache.KeyMap .emplace (std::piecewise_construct,
139
- std::forward_as_tuple (CommonKey),
140
- std::forward_as_tuple (CacheKey));
176
+ ProgCache.KeyMap .emplace (CommonKey, CacheKey);
141
177
}
142
- return std::make_pair (&Inserted. first ->second , Inserted. second );
178
+ return std::make_pair (It ->second , DidInsert );
143
179
}
144
180
145
- std::pair<BuildResult<KernelArgMaskPairT> * , bool >
181
+ std::pair<KernelBuildResultPtr , bool >
146
182
getOrInsertKernel (sycl::detail::pi::PiProgram Program,
147
183
const std::string &KernelName) {
148
184
auto LockedCache = acquireKernelsPerProgramCache ();
149
185
auto &Cache = LockedCache.get ()[Program];
150
- auto Inserted = Cache.emplace (
151
- std::piecewise_construct, std::forward_as_tuple (KernelName),
152
- std::forward_as_tuple (nullptr , BS_InProgress));
153
- return std::make_pair (&Inserted.first ->second , Inserted.second );
154
- }
155
-
156
- template <typename T, class Predicate >
157
- void waitUntilBuilt (BuildResult<T> &BR, Predicate Pred) const {
158
- std::unique_lock<std::mutex> Lock (BR.MBuildResultMutex );
159
-
160
- BR.MBuildCV .wait (Lock, Pred);
161
- }
162
-
163
- template <typename ExceptionT, typename RetT>
164
- RetT *waitUntilBuilt (BuildResult<RetT> *BuildResult) {
165
- // Any thread which will find nullptr in cache will wait until the pointer
166
- // is not null anymore.
167
- waitUntilBuilt (*BuildResult, [BuildResult]() {
168
- int State = BuildResult->State .load ();
169
- return State == BuildState::BS_Done || State == BuildState::BS_Failed;
170
- });
171
-
172
- if (BuildResult->Error .isFilledIn ()) {
173
- const BuildError &Error = BuildResult->Error ;
174
- throw ExceptionT (Error.Msg , Error.Code );
175
- }
176
-
177
- return BuildResult->Ptr .load ();
178
- }
179
-
180
- template <typename T> void notifyAllBuild (BuildResult<T> &BR) const {
181
- BR.MBuildCV .notify_all ();
186
+ auto [It, DidInsert] = Cache.try_emplace (KernelName, nullptr );
187
+ if (DidInsert)
188
+ It->second = std::make_shared<KernelBuildResult>(getPlugin ());
189
+ return std::make_pair (It->second , DidInsert);
182
190
}
183
191
184
192
template <typename KeyT>
@@ -203,11 +211,94 @@ class KernelProgramCache {
203
211
// /
204
212
// / This member function should only be used in unit tests.
205
213
void reset () {
214
+ std::lock_guard<std::mutex> L1 (MProgramCacheMutex);
215
+ std::lock_guard<std::mutex> L2 (MKernelsPerProgramCacheMutex);
216
+ std::lock_guard<std::mutex> L3 (MKernelFastCacheMutex);
206
217
MCachedPrograms = ProgramCache{};
207
218
MKernelsPerProgramCache = KernelCacheT{};
208
219
MKernelFastCache = KernelFastCacheT{};
209
220
}
210
221
222
+ // / Try to fetch entity (kernel or program) from cache. If there is no such
223
+ // / entity try to build it. Throw any exception build process may throw.
224
+ // / This method eliminates unwanted builds by employing atomic variable with
225
+ // / build state and waiting until the entity is built in another thread.
226
+ // / If the building thread has failed the awaiting thread will fail either.
227
+ // / Exception thrown by build procedure are rethrown.
228
+ // /
229
+ // / \tparam RetT type of entity to get
230
+ // / \tparam ExceptionT type of exception to throw on awaiting thread if the
231
+ // / building thread fails build step.
232
+ // / \tparam KeyT key (in cache) to fetch built entity with
233
+ // / \tparam AcquireFT type of function which will acquire the locked version
234
+ // / of
235
+ // / the cache. Accept reference to KernelProgramCache.
236
+ // / \tparam GetCacheFT type of function which will fetch proper cache from
237
+ // / locked version. Accepts reference to locked version of cache.
238
+ // / \tparam BuildFT type of function which will build the entity if it is not
239
+ // / in
240
+ // / cache. Accepts nothing. Return pointer to built entity.
241
+ // /
242
+ // / \return a pointer to cached build result, return value must not be
243
+ // / nullptr.
244
+ template <typename ExceptionT, typename GetCachedBuildFT, typename BuildFT>
245
+ auto getOrBuild (GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build) {
246
+ using BuildState = KernelProgramCache::BuildState;
247
+ constexpr size_t MaxAttempts = 2 ;
248
+ for (size_t AttemptCounter = 0 ;; ++AttemptCounter) {
249
+ auto Res = GetCachedBuild ();
250
+ auto &BuildResult = Res.first ;
251
+ BuildState Expected = BuildState::BS_Initial;
252
+ BuildState Desired = BuildState::BS_InProgress;
253
+ if (!BuildResult->State .compare_exchange_strong (Expected, Desired)) {
254
+ // no insertion took place, thus some other thread has already inserted
255
+ // smth in the cache
256
+ BuildState NewState = BuildResult->waitUntilTransition ();
257
+
258
+ // Build succeeded.
259
+ if (NewState == BuildState::BS_Done)
260
+ return BuildResult;
261
+
262
+ // Build failed, or this is the last attempt.
263
+ if (NewState == BuildState::BS_Failed ||
264
+ AttemptCounter + 1 == MaxAttempts) {
265
+ if (BuildResult->Error .isFilledIn ())
266
+ throw ExceptionT (BuildResult->Error .Msg , BuildResult->Error .Code );
267
+ else
268
+ throw exception ();
269
+ }
270
+
271
+ // NewState == BuildState::BS_Initial
272
+ // Build state was set back to the initial state,
273
+ // which means to go back to the beginning of the
274
+ // loop and try again.
275
+ continue ;
276
+ }
277
+
278
+ // only the building thread will run this
279
+ try {
280
+ BuildResult->Val = Build ();
281
+
282
+ BuildResult->updateAndNotify (BuildState::BS_Done);
283
+ return BuildResult;
284
+ } catch (const exception &Ex) {
285
+ BuildResult->Error .Msg = Ex.what ();
286
+ BuildResult->Error .Code = Ex.get_cl_code ();
287
+ if (BuildResult->Error .Code == PI_ERROR_OUT_OF_RESOURCES) {
288
+ reset ();
289
+ BuildResult->updateAndNotify (BuildState::BS_Initial);
290
+ continue ;
291
+ }
292
+
293
+ BuildResult->updateAndNotify (BuildState::BS_Failed);
294
+ std::rethrow_exception (std::current_exception ());
295
+ } catch (...) {
296
+ BuildResult->updateAndNotify (BuildState::BS_Initial);
297
+ std::rethrow_exception (std::current_exception ());
298
+ }
299
+ }
300
+ }
301
+
211
302
private:
212
303
std::mutex MProgramCacheMutex;
213
304
std::mutex MKernelsPerProgramCacheMutex;
@@ -219,6 +310,8 @@ class KernelProgramCache {
219
310
std::mutex MKernelFastCacheMutex;
220
311
KernelFastCacheT MKernelFastCache;
221
312
friend class ::MockKernelProgramCache;
313
+
314
+ const PluginPtr &getPlugin ();
222
315
};
223
316
} // namespace detail
224
317
} // namespace _V1
0 commit comments