@@ -28,6 +28,7 @@ limitations under the License.
28
28
29
29
#include " absl/base/thread_annotations.h"
30
30
#include " absl/cleanup/cleanup.h"
31
+ #include " absl/hash/hash.h"
31
32
#include " absl/strings/str_cat.h"
32
33
#include " absl/synchronization/mutex.h"
33
34
#include " absl/synchronization/notification.h"
@@ -78,34 +79,58 @@ class HashablePyDictIter {
78
79
nb::detail::dict_iterator& iter_;
79
80
};
80
81
82
+ struct HashableKey {
83
+ nb::object context;
84
+ nb::args args;
85
+ nb::kwargs kwargs;
86
+
87
+ template <typename H>
88
+ friend H AbslHashValue (H h, const HashableKey& key) {
89
+ // Note: Despite the fact this is an ABSL hash function, it's safe to call
90
+ // functions that may throw exceptions such as nb::hash(), because it is
91
+ // used by an LRUCache, which uses a std::unordered_map, which is
92
+ // exception-safe.
93
+ h = H::combine (std::move (h), nb::hash (key.context ), nb::hash (key.args ));
94
+ nb::detail::dict_iterator begin = key.kwargs .begin ();
95
+ nb::detail::dict_iterator end = key.kwargs .end ();
96
+ h = H::combine_unordered (std::move (h), HashablePyDictIter (begin),
97
+ HashablePyDictIter (end));
98
+ h = H::combine (std::move (h), key.kwargs .size ());
99
+ return h;
100
+ }
101
+ };
102
+
81
103
} // namespace
82
104
83
105
class WeakrefLRUCache : public std ::enable_shared_from_this<WeakrefLRUCache> {
84
106
public:
85
- struct Key {
86
- nb::object context;
87
- nb::args args;
88
- nb::kwargs kwargs;
107
+ class Key {
108
+ public:
109
+ Key (nb::object context, nb::args args, nb::kwargs kwargs)
110
+ : context_(std::move(context)),
111
+ args_ (std::move(args)),
112
+ kwargs_(std::move(kwargs)),
113
+ cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {}
89
114
90
115
bool operator ==(const Key& other) const {
91
- return context .equal (other.context ) && args .equal (other.args ) &&
92
- kwargs .equal (other.kwargs );
116
+ return context_ .equal (other.context_ ) && args_ .equal (other.args_ ) &&
117
+ kwargs_ .equal (other.kwargs_ );
93
118
}
94
119
95
120
template <typename H>
96
121
friend H AbslHashValue (H h, const Key& key) {
97
- // Note: Despite the fact this is an ABSL hash function, it's safe to call
98
- // functions that may throw exceptions such as nb::hash(), because it is
99
- // used by an LRUCache, which uses a std::unordered_map, which is
100
- // exception-safe.
101
- h = H::combine (std::move (h), nb::hash (key.context ), nb::hash (key.args ));
102
- nb::detail::dict_iterator begin = key.kwargs .begin ();
103
- nb::detail::dict_iterator end = key.kwargs .end ();
104
- h = H::combine_unordered (std::move (h), HashablePyDictIter (begin),
105
- HashablePyDictIter (end));
106
- h = H::combine (std::move (h), key.kwargs .size ());
107
- return h;
122
+ return H::combine (std::move (h), key.cached_hash_ );
108
123
}
124
+
125
+ nb::object context () const { return context_; }
126
+ nb::args args () const { return args_; }
127
+ nb::kwargs kwargs () const { return kwargs_; }
128
+
129
+ private:
130
+ nb::object context_;
131
+ nb::args args_;
132
+ nb::kwargs kwargs_;
133
+ size_t cached_hash_;
109
134
};
110
135
111
136
struct CacheEntry {
@@ -123,14 +148,13 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
123
148
};
124
149
125
150
struct WeakrefCacheKey {
126
- nb::handle object ;
151
+ nb::weakref ref ;
127
152
size_t cached_hash;
128
153
};
129
154
130
155
using Cache = xla::LRUCache<Key, std::shared_ptr<CacheEntry>>;
131
156
132
157
struct WeakrefCacheValue {
133
- std::optional<nb::weakref> weakref;
134
158
std::shared_ptr<Cache> cache;
135
159
};
136
160
@@ -141,7 +165,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
141
165
struct WeakrefKeyEq {
142
166
bool operator ()(const WeakrefCacheKey& lhs,
143
167
const WeakrefCacheKey& rhs) const {
144
- return lhs.object .equal (rhs.object );
168
+ return lhs.ref .equal (rhs.ref );
145
169
}
146
170
};
147
171
@@ -150,43 +174,49 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
150
174
: cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {}
151
175
152
176
std::shared_ptr<Cache> GetCache (WeakrefCacheKey key) {
153
- auto [it, inserted] = entries_. emplace ( key, WeakrefCacheValue ()) ;
154
- if (!inserted ) {
155
- return it-> second .cache ;
177
+ WeakrefCacheValue& value = entries_[ key] ;
178
+ if (!value. cache ) {
179
+ value .cache = std::make_shared<Cache>(&lru_list_) ;
156
180
}
181
+ return value.cache ;
182
+ }
157
183
158
- auto & value = it->second ;
184
+ nb::object Call (nb::object weakref_key, nb::args args,
185
+ nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
186
+ nb::object context = cache_context_fn_ ();
187
+
188
+ // We precompute all of the hash values needed by the various maps rather
189
+ // than computing them during the std::unordered_map insertions. At the very
190
+ // least, MSVC's std::unordered_map has undefined behavior if the hash
191
+ // function throws an exception
192
+ // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace).
193
+ Key key (context, args, kwargs);
194
+ size_t wrcache_hash = static_cast <size_t >(nb::hash (weakref_key));
195
+
196
+ // No hash computations after this point.
159
197
160
- value.cache = std::make_shared<Cache>(&lru_list_);
161
198
auto weakref_gc_callback = nb::cpp_function (
162
- [this_weak = weak_from_this (), key ](nb::handle weakref) {
199
+ [this_weak = weak_from_this (), wrcache_hash ](nb::handle weakref) {
163
200
auto cache = this_weak.lock ();
164
201
if (cache == nullptr ) {
165
202
return ;
166
203
}
167
- auto it = cache->entries_ .find (key);
204
+ // The object the reference referred to is now in the process of being
205
+ // destroyed, so we cannot refer to its contents. Python weakref
206
+ // objects compare based on identity if the object they refer to is
207
+ // gone, so the hash lookup will work fine.
208
+ auto it = cache->entries_ .find (
209
+ WeakrefCacheKey{nb::borrow<nb::weakref>(weakref), wrcache_hash});
168
210
if (it == cache->entries_ .end ()) {
169
211
return ;
170
212
}
171
213
// Create temp-var to avoid re-entrant erase.
172
214
auto tmp = std::move (it->second );
173
215
cache->entries_ .erase (it);
174
216
});
175
- PyObject* ref =
176
- PyWeakref_NewRef (key.object .ptr (), weakref_gc_callback.ptr ());
177
- if (!ref) {
178
- entries_.erase (it);
179
- throw nb::python_error ();
180
- }
181
- value.weakref = nb::steal<nb::weakref>(ref);
182
- return value.cache ;
183
- }
184
-
185
- nb::object Call (nb::object weakref_key, nb::args args,
186
- nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
187
- nb::object context = cache_context_fn_ ();
188
- std::shared_ptr<Cache> cache_ptr = GetCache (WeakrefCacheKey{
189
- weakref_key, static_cast <size_t >(nb::hash (weakref_key))});
217
+ nb::weakref weakref = nb::weakref (weakref_key, weakref_gc_callback);
218
+ WeakrefCacheKey wrcache_key{weakref, wrcache_hash};
219
+ std::shared_ptr<Cache> cache_ptr = GetCache (wrcache_key);
190
220
Cache& cache = *cache_ptr;
191
221
++total_queries_;
192
222
@@ -206,7 +236,6 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
206
236
// released if that happens.
207
237
absl::Cleanup unlock = [this ]()
208
238
ABSL_UNLOCK_FUNCTION (mu_) { mu_.Unlock (); };
209
- Key key{context, args, kwargs};
210
239
entry = cache.GetOrCreateIfAbsent (key, [&inserted](const Key& key) {
211
240
inserted = true ;
212
241
return std::make_shared<CacheEntry>();
@@ -245,8 +274,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
245
274
for (const auto & wr_entry : entries_) {
246
275
for (const auto & rest : *wr_entry.second .cache ) {
247
276
nb::tuple result =
248
- nb::make_tuple (*wr_entry.second . weakref , rest.first .context ,
249
- rest.first .args , rest.first .kwargs );
277
+ nb::make_tuple (*wr_entry.first . ref , rest.first .context () ,
278
+ rest.first .args () , rest.first .kwargs () );
250
279
results.push_back (std::move (result));
251
280
}
252
281
}
0 commit comments