Skip to content

Commit 1b803fe

Browse files
author
Jeff Niu
authored
[mlir] Optimize ThreadLocalCache by removing atomic bottleneck (#93270)
The ThreadLocalCache implementation is used by the MLIRContext (among other things) to try to manage thread contention in the StorageUniquers. There is a bunch of fancy shared pointer/weak pointer setups that basically keeps everything alive across threads at the right time, but a huge bottleneck is the `weak_ptr::lock` call inside the `::get` method. This is because the `lock` method has to hit the atomic refcount several times, and this is bottlenecking performance across many threads. However, all this is doing is checking whether the storage is initialized. We know that it cannot be an expired weak pointer because the thread local cache object we're calling into owns the memory and is still alive for the method call to be valid. Thus, we can store and extra `Value *` inside the thread local cache for speedy retrieval if the cache is already initialized for the thread, which is the common case. This also tightens the size of the critical section in the same method by scoping the mutex more to just the mutation on `perInstanceState`. Before: <img width="560" alt="image" src="https://github.com/llvm/llvm-project/assets/15016832/f4ea3f32-6649-4c10-88c4-b7522031e8c9"> After: <img width="344" alt="image" src="https://github.com/llvm/llvm-project/assets/15016832/1216db25-3dc1-4b0f-be89-caeff622dd35">
1 parent cd9bab2 commit 1b803fe

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

mlir/include/mlir/Support/ThreadLocalCache.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,12 @@ class ThreadLocalCache {
5858
/// ValueT. We use a weak reference here so that the object can be destroyed
5959
/// without needing to lock access to the cache itself.
6060
struct CacheType
61-
: public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
61+
: public llvm::SmallDenseMap<PerInstanceState *,
62+
std::pair<std::weak_ptr<ValueT>, ValueT *>> {
6263
~CacheType() {
6364
// Remove the values of this cache that haven't already expired.
6465
for (auto &it : *this)
65-
if (std::shared_ptr<ValueT> value = it.second.lock())
66+
if (std::shared_ptr<ValueT> value = it.second.first.lock())
6667
it.first->remove(value.get());
6768
}
6869

@@ -71,7 +72,7 @@ class ThreadLocalCache {
7172
void clearExpiredEntries() {
7273
for (auto it = this->begin(), e = this->end(); it != e;) {
7374
auto curIt = it++;
74-
if (curIt->second.expired())
75+
if (curIt->second.first.expired())
7576
this->erase(curIt);
7677
}
7778
}
@@ -88,22 +89,27 @@ class ThreadLocalCache {
8889
ValueT &get() {
8990
// Check for an already existing instance for this thread.
9091
CacheType &staticCache = getStaticCache();
91-
std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
92-
if (std::shared_ptr<ValueT> value = threadInstance.lock())
92+
std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
93+
staticCache[perInstanceState.get()];
94+
if (ValueT *value = threadInstance.second)
9395
return *value;
9496

9597
// Otherwise, create a new instance for this thread.
96-
llvm::sys::SmartScopedLock<true> threadInstanceLock(
97-
perInstanceState->instanceMutex);
98-
perInstanceState->instances.push_back(std::make_unique<ValueT>());
99-
ValueT *instance = perInstanceState->instances.back().get();
100-
threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
98+
{
99+
llvm::sys::SmartScopedLock<true> threadInstanceLock(
100+
perInstanceState->instanceMutex);
101+
threadInstance.second =
102+
perInstanceState->instances.emplace_back(std::make_unique<ValueT>())
103+
.get();
104+
}
105+
threadInstance.first =
106+
std::shared_ptr<ValueT>(perInstanceState, threadInstance.second);
101107

102108
// Before returning the new instance, take the chance to clear out any used
103109
// entries in the static map. The cache is only cleared within the same
104110
// thread to remove the need to lock the cache itself.
105111
staticCache.clearExpiredEntries();
106-
return *instance;
112+
return *threadInstance.second;
107113
}
108114
ValueT &operator*() { return get(); }
109115
ValueT *operator->() { return &get(); }

0 commit comments

Comments
 (0)