@@ -39,13 +39,16 @@ class KernelProgramCache {
39
39
bool isFilledIn () const { return !Msg.empty (); }
40
40
};
41
41
42
+ // / Denotes the state of a build.
43
+ enum BuildState { BS_InProgress, BS_Done, BS_Failed };
44
+
42
45
// / Denotes pointer to some entity with its general state and build error.
43
46
// / The pointer is not null if and only if the entity is usable.
44
47
// / State of the entity is provided by the user of cache instance.
45
48
// / Currently there is only a single user - ProgramManager class.
46
49
template <typename T> struct BuildResult {
47
50
std::atomic<T *> Ptr;
48
- std::atomic<int > State;
51
+ std::atomic<BuildState > State;
49
52
BuildError Error;
50
53
51
54
// / Condition variable to signal that build result is ready.
@@ -62,15 +65,23 @@ class KernelProgramCache {
62
65
// / A mutex to be employed along with MBuildCV.
63
66
std::mutex MBuildResultMutex;
64
67
65
- BuildResult (T *P, int S) : Ptr{P}, State{S}, Error{" " , 0 } {}
68
+ BuildResult (T *P, BuildState S) : Ptr{P}, State{S}, Error{" " , 0 } {}
66
69
};
67
70
68
71
using PiProgramT = std::remove_pointer<RT::PiProgram>::type;
69
72
using PiProgramPtrT = std::atomic<PiProgramT *>;
70
73
using ProgramWithBuildStateT = BuildResult<PiProgramT>;
71
74
using ProgramCacheKeyT = std::pair<std::pair<SerializedObj, std::uintptr_t >,
72
75
std::pair<RT::PiDevice, std::string>>;
73
- using ProgramCacheT = std::map<ProgramCacheKeyT, ProgramWithBuildStateT>;
76
+ using CommonProgramKeyT = std::pair<std::uintptr_t , RT::PiDevice>;
77
+
78
+ struct ProgramCache {
79
+ std::map<ProgramCacheKeyT, ProgramWithBuildStateT> Cache;
80
+ std::multimap<CommonProgramKeyT, ProgramCacheKeyT> KeyMap;
81
+
82
+ size_t size () const noexcept { return Cache.size (); }
83
+ };
84
+
74
85
using ContextPtr = context_impl *;
75
86
76
87
using PiKernelT = std::remove_pointer<RT::PiKernel>::type;
@@ -91,21 +102,66 @@ class KernelProgramCache {
91
102
92
103
void setContextPtr (const ContextPtr &AContext) { MParentContext = AContext; }
93
104
94
- Locked<ProgramCacheT > acquireCachedPrograms () {
105
+ Locked<ProgramCache > acquireCachedPrograms () {
95
106
return {MCachedPrograms, MProgramCacheMutex};
96
107
}
97
108
98
109
Locked<KernelCacheT> acquireKernelsPerProgramCache () {
99
110
return {MKernelsPerProgramCache, MKernelsPerProgramCacheMutex};
100
111
}
101
112
113
+ std::pair<ProgramWithBuildStateT *, bool >
114
+ getOrInsertProgram (const ProgramCacheKeyT &CacheKey) {
115
+ auto LockedCache = acquireCachedPrograms ();
116
+ auto &ProgCache = LockedCache.get ();
117
+ auto Inserted = ProgCache.Cache .emplace (
118
+ std::piecewise_construct, std::forward_as_tuple (CacheKey),
119
+ std::forward_as_tuple (nullptr , BS_InProgress));
120
+ if (Inserted.second ) {
121
+ // Save reference between the common key and the full key.
122
+ CommonProgramKeyT CommonKey =
123
+ std::make_pair (CacheKey.first .second , CacheKey.second .first );
124
+ ProgCache.KeyMap .emplace (std::piecewise_construct,
125
+ std::forward_as_tuple (CommonKey),
126
+ std::forward_as_tuple (CacheKey));
127
+ }
128
+ return std::make_pair (&Inserted.first ->second , Inserted.second );
129
+ }
130
+
131
+ std::pair<KernelWithBuildStateT *, bool >
132
+ getOrInsertKernel (RT::PiProgram Program, const std::string &KernelName) {
133
+ auto LockedCache = acquireKernelsPerProgramCache ();
134
+ auto &Cache = LockedCache.get ()[Program];
135
+ auto Inserted = Cache.emplace (
136
+ std::piecewise_construct, std::forward_as_tuple (KernelName),
137
+ std::forward_as_tuple (nullptr , BS_InProgress));
138
+ return std::make_pair (&Inserted.first ->second , Inserted.second );
139
+ }
140
+
102
141
template <typename T, class Predicate >
103
142
void waitUntilBuilt (BuildResult<T> &BR, Predicate Pred) const {
104
143
std::unique_lock<std::mutex> Lock (BR.MBuildResultMutex );
105
144
106
145
BR.MBuildCV .wait (Lock, Pred);
107
146
}
108
147
148
+ template <typename ExceptionT, typename RetT>
149
+ RetT *waitUntilBuilt (BuildResult<RetT> *BuildResult) {
150
+ // Any thread which will find nullptr in cache will wait until the pointer
151
+ // is not null anymore.
152
+ waitUntilBuilt (*BuildResult, [BuildResult]() {
153
+ int State = BuildResult->State .load ();
154
+ return State == BuildState::BS_Done || State == BuildState::BS_Failed;
155
+ });
156
+
157
+ if (BuildResult->Error .isFilledIn ()) {
158
+ const BuildError &Error = BuildResult->Error ;
159
+ throw ExceptionT (Error.Msg , Error.Code );
160
+ }
161
+
162
+ return BuildResult->Ptr .load ();
163
+ }
164
+
109
165
template <typename T> void notifyAllBuild (BuildResult<T> &BR) const {
110
166
BR.MBuildCV .notify_all ();
111
167
}
@@ -132,7 +188,7 @@ class KernelProgramCache {
132
188
// /
133
189
// / This member function should only be used in unit tests.
134
190
void reset () {
135
- MCachedPrograms = ProgramCacheT {};
191
+ MCachedPrograms = ProgramCache {};
136
192
MKernelsPerProgramCache = KernelCacheT{};
137
193
MKernelFastCache = KernelFastCacheT{};
138
194
}
@@ -141,7 +197,7 @@ class KernelProgramCache {
141
197
std::mutex MProgramCacheMutex;
142
198
std::mutex MKernelsPerProgramCacheMutex;
143
199
144
- ProgramCacheT MCachedPrograms;
200
+ ProgramCache MCachedPrograms;
145
201
KernelCacheT MKernelsPerProgramCache;
146
202
ContextPtr MParentContext;
147
203
0 commit comments