Skip to content

Commit cd6e808

Browse files
hawkinsp and Google-ML-Automation
authored and committed
[XLA:Python] Fix more bugs in the weakref_lru_cache implementation.
a) MSVC's std::unordered_map says behavior is undefined if the hash function throws an exception (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). That's easy to work around, though: we can just precompute all the hash values. b) my idiom for avoiding heterogeneous lookups had a use-after-free problem: the weakref callback is called after the object is already in an invalid state. However, there's a much simpler solution: just create the weakref object and use it as a key unconditionally. Yes, this will mean we create more weak references than perhaps we had to otherwise. But this is simple and obviously correct. PiperOrigin-RevId: 681522048
1 parent 6345728 commit cd6e808

File tree

3 files changed

+98
-45
lines changed

3 files changed

+98
-45
lines changed

xla/python/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,7 @@ cc_library(
11121112
# placeholder for index annotation deps
11131113
"@com_google_absl//absl/base:core_headers",
11141114
"@com_google_absl//absl/cleanup",
1115+
"@com_google_absl//absl/hash",
11151116
"@com_google_absl//absl/strings",
11161117
"@com_google_absl//absl/synchronization",
11171118
"@nanobind",

xla/python/weakref_lru_cache.cc

Lines changed: 74 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828

2929
#include "absl/base/thread_annotations.h"
3030
#include "absl/cleanup/cleanup.h"
31+
#include "absl/hash/hash.h"
3132
#include "absl/strings/str_cat.h"
3233
#include "absl/synchronization/mutex.h"
3334
#include "absl/synchronization/notification.h"
@@ -78,34 +79,58 @@ class HashablePyDictIter {
7879
nb::detail::dict_iterator& iter_;
7980
};
8081

82+
struct HashableKey {
83+
nb::object context;
84+
nb::args args;
85+
nb::kwargs kwargs;
86+
87+
template <typename H>
88+
friend H AbslHashValue(H h, const HashableKey& key) {
89+
// Note: Despite the fact this is an ABSL hash function, it's safe to call
90+
// functions that may throw exceptions such as nb::hash(), because it is
91+
// used by an LRUCache, which uses a std::unordered_map, which is
92+
// exception-safe.
93+
h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args));
94+
nb::detail::dict_iterator begin = key.kwargs.begin();
95+
nb::detail::dict_iterator end = key.kwargs.end();
96+
h = H::combine_unordered(std::move(h), HashablePyDictIter(begin),
97+
HashablePyDictIter(end));
98+
h = H::combine(std::move(h), key.kwargs.size());
99+
return h;
100+
}
101+
};
102+
81103
} // namespace
82104

83105
class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
84106
public:
85-
struct Key {
86-
nb::object context;
87-
nb::args args;
88-
nb::kwargs kwargs;
107+
class Key {
108+
public:
109+
Key(nb::object context, nb::args args, nb::kwargs kwargs)
110+
: context_(std::move(context)),
111+
args_(std::move(args)),
112+
kwargs_(std::move(kwargs)),
113+
cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {}
89114

90115
bool operator==(const Key& other) const {
91-
return context.equal(other.context) && args.equal(other.args) &&
92-
kwargs.equal(other.kwargs);
116+
return context_.equal(other.context_) && args_.equal(other.args_) &&
117+
kwargs_.equal(other.kwargs_);
93118
}
94119

95120
template <typename H>
96121
friend H AbslHashValue(H h, const Key& key) {
97-
// Note: Despite the fact this is an ABSL hash function, it's safe to call
98-
// functions that may throw exceptions such as nb::hash(), because it is
99-
// used by an LRUCache, which uses a std::unordered_map, which is
100-
// exception-safe.
101-
h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args));
102-
nb::detail::dict_iterator begin = key.kwargs.begin();
103-
nb::detail::dict_iterator end = key.kwargs.end();
104-
h = H::combine_unordered(std::move(h), HashablePyDictIter(begin),
105-
HashablePyDictIter(end));
106-
h = H::combine(std::move(h), key.kwargs.size());
107-
return h;
122+
return H::combine(std::move(h), key.cached_hash_);
108123
}
124+
125+
nb::object context() const { return context_; }
126+
nb::args args() const { return args_; }
127+
nb::kwargs kwargs() const { return kwargs_; }
128+
129+
private:
130+
nb::object context_;
131+
nb::args args_;
132+
nb::kwargs kwargs_;
133+
size_t cached_hash_;
109134
};
110135

111136
struct CacheEntry {
@@ -123,14 +148,13 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
123148
};
124149

125150
struct WeakrefCacheKey {
126-
nb::handle object;
151+
nb::weakref ref;
127152
size_t cached_hash;
128153
};
129154

130155
using Cache = xla::LRUCache<Key, std::shared_ptr<CacheEntry>>;
131156

132157
struct WeakrefCacheValue {
133-
std::optional<nb::weakref> weakref;
134158
std::shared_ptr<Cache> cache;
135159
};
136160

@@ -141,7 +165,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
141165
struct WeakrefKeyEq {
142166
bool operator()(const WeakrefCacheKey& lhs,
143167
const WeakrefCacheKey& rhs) const {
144-
return lhs.object.equal(rhs.object);
168+
return lhs.ref.equal(rhs.ref);
145169
}
146170
};
147171

@@ -150,43 +174,49 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
150174
: cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {}
151175

152176
std::shared_ptr<Cache> GetCache(WeakrefCacheKey key) {
153-
auto [it, inserted] = entries_.emplace(key, WeakrefCacheValue());
154-
if (!inserted) {
155-
return it->second.cache;
177+
WeakrefCacheValue& value = entries_[key];
178+
if (!value.cache) {
179+
value.cache = std::make_shared<Cache>(&lru_list_);
156180
}
181+
return value.cache;
182+
}
157183

158-
auto& value = it->second;
184+
nb::object Call(nb::object weakref_key, nb::args args,
185+
nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
186+
nb::object context = cache_context_fn_();
187+
188+
// We precompute all of the hash values needed by the various maps rather
189+
// than computing them during the std::unordered_map insertions. At the very
190+
// least, MSVC's std::unordered_map has undefined behavior if the hash
191+
// function throws an exception
192+
// (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace).
193+
Key key(context, args, kwargs);
194+
size_t wrcache_hash = static_cast<size_t>(nb::hash(weakref_key));
195+
196+
// No hash computations after this point.
159197

160-
value.cache = std::make_shared<Cache>(&lru_list_);
161198
auto weakref_gc_callback = nb::cpp_function(
162-
[this_weak = weak_from_this(), key](nb::handle weakref) {
199+
[this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) {
163200
auto cache = this_weak.lock();
164201
if (cache == nullptr) {
165202
return;
166203
}
167-
auto it = cache->entries_.find(key);
204+
// The object the reference referred to is now in the process of being
205+
// destroyed, so we cannot refer to its contents. Python weakref
206+
// objects compare based on identity if the object they refer to is
207+
// gone, so the hash lookup will work fine.
208+
auto it = cache->entries_.find(
209+
WeakrefCacheKey{nb::borrow<nb::weakref>(weakref), wrcache_hash});
168210
if (it == cache->entries_.end()) {
169211
return;
170212
}
171213
// Create temp-var to avoid re-entrant erase.
172214
auto tmp = std::move(it->second);
173215
cache->entries_.erase(it);
174216
});
175-
PyObject* ref =
176-
PyWeakref_NewRef(key.object.ptr(), weakref_gc_callback.ptr());
177-
if (!ref) {
178-
entries_.erase(it);
179-
throw nb::python_error();
180-
}
181-
value.weakref = nb::steal<nb::weakref>(ref);
182-
return value.cache;
183-
}
184-
185-
nb::object Call(nb::object weakref_key, nb::args args,
186-
nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
187-
nb::object context = cache_context_fn_();
188-
std::shared_ptr<Cache> cache_ptr = GetCache(WeakrefCacheKey{
189-
weakref_key, static_cast<size_t>(nb::hash(weakref_key))});
217+
nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback);
218+
WeakrefCacheKey wrcache_key{weakref, wrcache_hash};
219+
std::shared_ptr<Cache> cache_ptr = GetCache(wrcache_key);
190220
Cache& cache = *cache_ptr;
191221
++total_queries_;
192222

@@ -206,7 +236,6 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
206236
// released if that happens.
207237
absl::Cleanup unlock = [this]()
208238
ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); };
209-
Key key{context, args, kwargs};
210239
entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) {
211240
inserted = true;
212241
return std::make_shared<CacheEntry>();
@@ -245,8 +274,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
245274
for (const auto& wr_entry : entries_) {
246275
for (const auto& rest : *wr_entry.second.cache) {
247276
nb::tuple result =
248-
nb::make_tuple(*wr_entry.second.weakref, rest.first.context,
249-
rest.first.args, rest.first.kwargs);
277+
nb::make_tuple(*wr_entry.first.ref, rest.first.context(),
278+
rest.first.args(), rest.first.kwargs());
250279
results.push_back(std::move(result));
251280
}
252281
}

xla/python/weakref_lru_cache_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,29 @@ class WRKey:
160160
"WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)",
161161
)
162162

163+
def testGCKeys(self):
164+
class WRKey:
165+
166+
def __init__(self, x):
167+
self.x = x
168+
169+
def __eq__(self, other):
170+
return self.x == other.x
171+
172+
def __hash__(self):
173+
return hash(self.x)
174+
175+
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
176+
keys = [WRKey(i) for i in range(10)]
177+
for i in range(10):
178+
cache(keys[i], i)
179+
180+
# Delete some keys, to exercise the weakref callback behavior.
181+
del keys[::2]
182+
183+
for key in keys:
184+
cache(key, 7)
185+
163186

164187
if __name__ == "__main__":
165188
absltest.main()

0 commit comments

Comments (0)