-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[ADT] Make DenseMap/DenseSet more resilient against OOM situations #107251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ADT] Make DenseMap/DenseSet more resilient against OOM situations #107251
Conversation
ddf904f
to
84e90db
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
07aaed1
to
bb74000
Compare
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-llvm-adt Author: Marc Auberer (marcauberer) ChangesThis patch hardens DenseMap and DenseSet against OOM scenarios, outlined in this RFC. Feel free to comment any comments/suggestions/objections there. As discussed in the RFC, we would eventually like to have malfunction tests to ensure malfunction-safety. This is still in internal discussion on our side. Therefore we would like to contribute the improvements with the existing functional tests for now. Patch is 53.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/107251.diff 4 Files Affected:
diff --git a/llvm/include/llvm/ADT/DenseMap.h b/llvm/include/llvm/ADT/DenseMap.h
index 00290c9dd0a585..6cc857f906e910 100644
--- a/llvm/include/llvm/ADT/DenseMap.h
+++ b/llvm/include/llvm/ADT/DenseMap.h
@@ -16,6 +16,7 @@
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/EpochTracker.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/AlignOf.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/MathExtras.h"
@@ -36,8 +37,6 @@ namespace llvm {
namespace detail {
-// We extend a pair to allow users to override the bucket type with their own
-// implementation without requiring two members.
template <typename KeyT, typename ValueT>
struct DenseMapPair : public std::pair<KeyT, ValueT> {
using std::pair<KeyT, ValueT>::pair;
@@ -48,16 +47,72 @@ struct DenseMapPair : public std::pair<KeyT, ValueT> {
const ValueT &getSecond() const { return std::pair<KeyT, ValueT>::second; }
};
+// OOM safe DenseMap bucket
+// flags to indicate if key and value are constructed
+// this is necessary since assigning empty key or tombstone key throws for some
+// key types
+
+template <typename KeyT, typename ValueT, typename Enable = void>
+struct DenseMapPairImpl;
+
+// part I: helpers to decide which specialization of DenseMapPairImpl to use
+template <typename KeyT, typename ValueT>
+using EnableIfTriviallyDestructibleMap =
+ typename std::enable_if_t<std::is_trivially_destructible_v<KeyT> &&
+ std::is_trivially_destructible_v<ValueT>>;
+template <typename KeyT, typename ValueT>
+using EnableIfNotTriviallyDestructibleMap =
+ typename std::enable_if_t<!std::is_trivially_destructible_v<KeyT> ||
+ !std::is_trivially_destructible_v<ValueT>>;
+
+// OOM safe DenseMap bucket
+// part II: specialization for trivially destructible types
+template <typename KeyT, typename ValueT>
+struct DenseMapPairImpl<KeyT, ValueT,
+ EnableIfTriviallyDestructibleMap<KeyT, ValueT>>
+ : public DenseMapPair<KeyT, ValueT> {
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool) {}
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const { return false; }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool) {}
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const { return false; }
+};
+
+// OOM safe DenseMap bucket
+// part III: specialization for non-trivially destructible types
+template <typename KeyT, typename ValueT>
+struct DenseMapPairImpl<KeyT, ValueT,
+ EnableIfNotTriviallyDestructibleMap<KeyT, ValueT>>
+ : DenseMapPair<KeyT, ValueT> {
+private:
+ bool m_keyConstructed;
+ bool m_valueConstructed;
+
+public:
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool constructed) {
+ m_keyConstructed = constructed;
+ }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const {
+ return m_keyConstructed;
+ }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool constructed) {
+ m_valueConstructed = constructed;
+ }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const {
+ return m_valueConstructed;
+ }
+};
+
} // end namespace detail
template <typename KeyT, typename ValueT,
typename KeyInfoT = DenseMapInfo<KeyT>,
- typename Bucket = llvm::detail::DenseMapPair<KeyT, ValueT>,
+ typename BucketT = llvm::detail::DenseMapPairImpl<KeyT, ValueT>,
+ typename BucketBaseT = llvm::detail::DenseMapPair<KeyT, ValueT>,
bool IsConst = false>
class DenseMapIterator;
template <typename DerivedT, typename KeyT, typename ValueT, typename KeyInfoT,
- typename BucketT>
+ typename BucketT, typename BucketBaseT>
class DenseMapBase : public DebugEpochBase {
template <typename T>
using const_arg_type_t = typename const_pointer_or_const_ref<T>::type;
@@ -66,11 +121,12 @@ class DenseMapBase : public DebugEpochBase {
using size_type = unsigned;
using key_type = KeyT;
using mapped_type = ValueT;
- using value_type = BucketT;
+ using value_type = BucketBaseT;
- using iterator = DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT>;
+ using iterator =
+ DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT>;
using const_iterator =
- DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, true>;
+ DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT, true>;
inline iterator begin() {
// When the map is empty, avoid the overhead of advancing/retreating past
@@ -109,7 +165,8 @@ class DenseMapBase : public DebugEpochBase {
void clear() {
incrementEpoch();
- if (getNumEntries() == 0 && getNumTombstones() == 0) return;
+ if (getNumEntries() == 0 && getNumTombstones() == 0)
+ return;
// If the capacity of the array is huge, and the # elements used is small,
// shrink the array.
@@ -119,24 +176,20 @@ class DenseMapBase : public DebugEpochBase {
}
const KeyT EmptyKey = getEmptyKey();
- if (std::is_trivially_destructible<ValueT>::value) {
+ if constexpr (std::is_trivially_destructible_v<ValueT>) {
// Use a simpler loop when values don't need destruction.
for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P)
P->getFirst() = EmptyKey;
} else {
- const KeyT TombstoneKey = getTombstoneKey();
- unsigned NumEntries = getNumEntries();
for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P) {
if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey)) {
- if (!KeyInfoT::isEqual(P->getFirst(), TombstoneKey)) {
+ if (P->isValueConstructed()) {
P->getSecond().~ValueT();
- --NumEntries;
+ P->setValueConstructed(false);
}
P->getFirst() = EmptyKey;
}
}
- assert(NumEntries == 0 && "Node count imbalance!");
- (void)NumEntries;
}
setNumEntries(0);
setNumTombstones(0);
@@ -172,15 +225,14 @@ class DenseMapBase : public DebugEpochBase {
/// The DenseMapInfo is responsible for supplying methods
/// getHashValue(LookupKeyT) and isEqual(LookupKeyT, KeyT) for each key
/// type used.
- template<class LookupKeyT>
- iterator find_as(const LookupKeyT &Val) {
+ template <class LookupKeyT> iterator find_as(const LookupKeyT &Val) {
if (BucketT *Bucket = doFind(Val))
return makeIterator(
Bucket, shouldReverseIterate<KeyT>() ? getBuckets() : getBucketsEnd(),
*this, true);
return end();
}
- template<class LookupKeyT>
+ template <class LookupKeyT>
const_iterator find_as(const LookupKeyT &Val) const {
if (const BucketT *Bucket = doFind(Val))
return makeConstIterator(
@@ -223,7 +275,7 @@ class DenseMapBase : public DebugEpochBase {
// The value is constructed in-place if the key is not in the map, otherwise
// it is not moved.
template <typename... Ts>
- std::pair<iterator, bool> try_emplace(KeyT &&Key, Ts &&... Args) {
+ std::pair<iterator, bool> try_emplace(KeyT &&Key, Ts &&...Args) {
BucketT *TheBucket;
if (LookupBucketFor(Key, TheBucket))
return std::make_pair(makeIterator(TheBucket,
@@ -248,7 +300,7 @@ class DenseMapBase : public DebugEpochBase {
// The value is constructed in-place if the key is not in the map, otherwise
// it is not moved.
template <typename... Ts>
- std::pair<iterator, bool> try_emplace(const KeyT &Key, Ts &&... Args) {
+ std::pair<iterator, bool> try_emplace(const KeyT &Key, Ts &&...Args) {
BucketT *TheBucket;
if (LookupBucketFor(Key, TheBucket))
return std::make_pair(makeIterator(TheBucket,
@@ -297,8 +349,7 @@ class DenseMapBase : public DebugEpochBase {
}
/// insert - Range insertion of pairs.
- template<typename InputIt>
- void insert(InputIt I, InputIt E) {
+ template <typename InputIt> void insert(InputIt I, InputIt E) {
for (; I != E; ++I)
insert(*I);
}
@@ -341,17 +392,22 @@ class DenseMapBase : public DebugEpochBase {
return false; // not in map.
TheBucket->getSecond().~ValueT();
+ TheBucket->setValueConstructed(false);
TheBucket->getFirst() = getTombstoneKey();
decrementNumEntries();
incrementNumTombstones();
return true;
}
void erase(iterator I) {
- BucketT *TheBucket = &*I;
- TheBucket->getSecond().~ValueT();
- TheBucket->getFirst() = getTombstoneKey();
- decrementNumEntries();
- incrementNumTombstones();
+ BucketT *TheBucket = static_cast<BucketT *>(&*I);
+ // Iterator can point to nullptr in case of memory malfunctions
+ if (TheBucket != nullptr) {
+ TheBucket->getSecond().~ValueT();
+ TheBucket->setValueConstructed(false);
+ TheBucket->getFirst() = getTombstoneKey();
+ decrementNumEntries();
+ incrementNumTombstones();
+ }
}
LLVM_DEPRECATED("Use [Key] instead", "[Key]")
@@ -404,15 +460,24 @@ class DenseMapBase : public DebugEpochBase {
DenseMapBase() = default;
void destroyAll() {
- if (getNumBuckets() == 0) // Nothing to do.
- return;
-
- const KeyT EmptyKey = getEmptyKey(), TombstoneKey = getTombstoneKey();
for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P) {
- if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey) &&
- !KeyInfoT::isEqual(P->getFirst(), TombstoneKey))
+ if (P->isValueConstructed()) {
P->getSecond().~ValueT();
- P->getFirst().~KeyT();
+ P->setValueConstructed(false);
+ }
+ if (P->isKeyConstructed()) {
+ P->getFirst().~KeyT();
+ P->setKeyConstructed(false);
+ }
+ }
+ }
+
+ void initUnitialized() {
+ BucketT *B = getBuckets();
+ BucketT *E = getBucketsEnd();
+ for (; B != E; ++B) {
+ B->setKeyConstructed(false);
+ B->setValueConstructed(false);
}
}
@@ -420,11 +485,19 @@ class DenseMapBase : public DebugEpochBase {
setNumEntries(0);
setNumTombstones(0);
- assert((getNumBuckets() & (getNumBuckets()-1)) == 0 &&
+ assert((getNumBuckets() & (getNumBuckets() - 1)) == 0 &&
"# initial buckets must be a power of two!");
const KeyT EmptyKey = getEmptyKey();
- for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B)
+#ifndef NDEBUG
+ for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
+ assert(!B->isKeyConstructed());
+ assert(!B->isValueConstructed());
+ }
+#endif
+ for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
::new (&B->getFirst()) KeyT(EmptyKey);
+ B->setKeyConstructed(true);
+ }
}
/// Returns the number of buckets to allocate to ensure that the DenseMap can
@@ -454,18 +527,23 @@ class DenseMapBase : public DebugEpochBase {
assert(!FoundVal && "Key already in new map?");
DestBucket->getFirst() = std::move(B->getFirst());
::new (&DestBucket->getSecond()) ValueT(std::move(B->getSecond()));
+ DestBucket->setValueConstructed(true);
incrementNumEntries();
-
- // Free the value.
+ }
+ if (B->isValueConstructed()) {
B->getSecond().~ValueT();
+ B->setValueConstructed(false);
+ }
+ if (B->isKeyConstructed()) {
+ B->getFirst().~KeyT();
+ B->setKeyConstructed(false);
}
- B->getFirst().~KeyT();
}
}
template <typename OtherBaseT>
- void copyFrom(
- const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT> &other) {
+ void copyFrom(const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT,
+ BucketBaseT> &other) {
assert(&other != this);
assert(getNumBuckets() == other.getNumBuckets());
@@ -480,10 +558,13 @@ class DenseMapBase : public DebugEpochBase {
for (size_t i = 0; i < getNumBuckets(); ++i) {
::new (&getBuckets()[i].getFirst())
KeyT(other.getBuckets()[i].getFirst());
+ getBuckets()[i].setKeyConstructed(true);
if (!KeyInfoT::isEqual(getBuckets()[i].getFirst(), getEmptyKey()) &&
- !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey()))
+ !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey())) {
::new (&getBuckets()[i].getSecond())
ValueT(other.getBuckets()[i].getSecond());
+ getBuckets()[i].setValueConstructed(true);
+ }
}
}
@@ -491,7 +572,7 @@ class DenseMapBase : public DebugEpochBase {
return KeyInfoT::getHashValue(Val);
}
- template<typename LookupKeyT>
+ template <typename LookupKeyT>
static unsigned getHashValue(const LookupKeyT &Val) {
return KeyInfoT::getHashValue(Val);
}
@@ -502,14 +583,11 @@ class DenseMapBase : public DebugEpochBase {
return KeyInfoT::getEmptyKey();
}
- static const KeyT getTombstoneKey() {
- return KeyInfoT::getTombstoneKey();
- }
+ static const KeyT getTombstoneKey() { return KeyInfoT::getTombstoneKey(); }
private:
- iterator makeIterator(BucketT *P, BucketT *E,
- DebugEpochBase &Epoch,
- bool NoAdvance=false) {
+ iterator makeIterator(BucketT *P, BucketT *E, DebugEpochBase &Epoch,
+ bool NoAdvance = false) {
if (shouldReverseIterate<KeyT>()) {
BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1;
return iterator(B, E, Epoch, NoAdvance);
@@ -519,7 +597,7 @@ class DenseMapBase : public DebugEpochBase {
const_iterator makeConstIterator(const BucketT *P, const BucketT *E,
const DebugEpochBase &Epoch,
- const bool NoAdvance=false) const {
+ const bool NoAdvance = false) const {
if (shouldReverseIterate<KeyT>()) {
const BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1;
return const_iterator(B, E, Epoch, NoAdvance);
@@ -535,13 +613,9 @@ class DenseMapBase : public DebugEpochBase {
static_cast<DerivedT *>(this)->setNumEntries(Num);
}
- void incrementNumEntries() {
- setNumEntries(getNumEntries() + 1);
- }
+ void incrementNumEntries() { setNumEntries(getNumEntries() + 1); }
- void decrementNumEntries() {
- setNumEntries(getNumEntries() - 1);
- }
+ void decrementNumEntries() { setNumEntries(getNumEntries() - 1); }
unsigned getNumTombstones() const {
return static_cast<const DerivedT *>(this)->getNumTombstones();
@@ -551,65 +625,52 @@ class DenseMapBase : public DebugEpochBase {
static_cast<DerivedT *>(this)->setNumTombstones(Num);
}
- void incrementNumTombstones() {
- setNumTombstones(getNumTombstones() + 1);
- }
+ void incrementNumTombstones() { setNumTombstones(getNumTombstones() + 1); }
- void decrementNumTombstones() {
- setNumTombstones(getNumTombstones() - 1);
- }
+ void decrementNumTombstones() { setNumTombstones(getNumTombstones() - 1); }
const BucketT *getBuckets() const {
return static_cast<const DerivedT *>(this)->getBuckets();
}
- BucketT *getBuckets() {
- return static_cast<DerivedT *>(this)->getBuckets();
- }
+ BucketT *getBuckets() { return static_cast<DerivedT *>(this)->getBuckets(); }
unsigned getNumBuckets() const {
return static_cast<const DerivedT *>(this)->getNumBuckets();
}
- BucketT *getBucketsEnd() {
- return getBuckets() + getNumBuckets();
- }
+ BucketT *getBucketsEnd() { return getBuckets() + getNumBuckets(); }
const BucketT *getBucketsEnd() const {
return getBuckets() + getNumBuckets();
}
- void grow(unsigned AtLeast) {
- static_cast<DerivedT *>(this)->grow(AtLeast);
- }
+ void grow(unsigned AtLeast) { static_cast<DerivedT *>(this)->grow(AtLeast); }
- void shrink_and_clear() {
- static_cast<DerivedT *>(this)->shrink_and_clear();
- }
+ void shrink_and_clear() { static_cast<DerivedT *>(this)->shrink_and_clear(); }
template <typename KeyArg, typename... ValueArgs>
BucketT *InsertIntoBucket(BucketT *TheBucket, KeyArg &&Key,
- ValueArgs &&... Values) {
- TheBucket = InsertIntoBucketImpl(Key, Key, TheBucket);
-
+ ValueArgs &&...Values) {
+ TheBucket = InsertIntoBucketImpl(Key, TheBucket);
TheBucket->getFirst() = std::forward<KeyArg>(Key);
::new (&TheBucket->getSecond()) ValueT(std::forward<ValueArgs>(Values)...);
+ TheBucket->setValueConstructed(true);
return TheBucket;
}
template <typename LookupKeyT>
BucketT *InsertIntoBucketWithLookup(BucketT *TheBucket, KeyT &&Key,
ValueT &&Value, LookupKeyT &Lookup) {
- TheBucket = InsertIntoBucketImpl(Key, Lookup, TheBucket);
-
+ TheBucket = InsertIntoBucketImpl(Lookup, TheBucket);
TheBucket->getFirst() = std::move(Key);
::new (&TheBucket->getSecond()) ValueT(std::move(Value));
+ TheBucket->setValueConstructed(true);
return TheBucket;
}
template <typename LookupKeyT>
- BucketT *InsertIntoBucketImpl(const KeyT &Key, const LookupKeyT &Lookup,
- BucketT *TheBucket) {
+ BucketT *InsertIntoBucketImpl(const LookupKeyT &Lookup, BucketT *TheBucket) {
incrementEpoch();
// If the load of the hash table is more than 3/4, or if fewer than 1/8 of
@@ -627,8 +688,9 @@ class DenseMapBase : public DebugEpochBase {
this->grow(NumBuckets * 2);
LookupBucketFor(Lookup, TheBucket);
NumBuckets = getNumBuckets();
- } else if (LLVM_UNLIKELY(NumBuckets-(NewNumEntries+getNumTombstones()) <=
- NumBuckets/8)) {
+ } else if (LLVM_UNLIKELY(NumBuckets -
+ (NewNumEntries + getNumTombstones()) <=
+ NumBuckets / 8)) {
this->grow(NumBuckets);
LookupBucketFor(Lookup, TheBucket);
}
@@ -696,7 +758,7 @@ class DenseMapBase : public DebugEpochBase {
!KeyInfoT::isEqual(Val, TombstoneKey) &&
"Empty/Tombstone value shouldn't be inserted into map!");
- unsigned BucketNo = getHashValue(Val) & (NumBuckets-1);
+ unsigned BucketNo = getHashValue(Val) & (NumBuckets - 1);
unsigned ProbeAmt = 1;
while (true) {
BucketT *ThisBucket = BucketsPtr + BucketNo;
@@ -719,23 +781,63 @@ class DenseMapBase : public DebugEpochBase {
// prefer to return it than something that would require more probing.
if (KeyInfoT::isEqual(ThisBucket->getFirst(), TombstoneKey) &&
!FoundTombstone)
- FoundTombstone = ThisBucket; // Remember the first tombstone found.
+ FoundTombstone = ThisBucket; // Remember the first tombstone found.
// Otherwise, it's a hash collision or a tombstone, continue quadratic
// probing.
BucketNo += ProbeAmt++;
- BucketNo &= (NumBuckets-1);
+ BucketNo &= (NumBuckets - 1);
}
}
+protected:
+ // helper class to guarantee deallocation of buffer
+ class ReleaseOldBuffer {
+ BucketT *m_buckets;
+ unsigned m_numBuckets;
+
+ public:
+ ReleaseOldBuffer(BucketT *buckets, unsigned numBuckets)
+ : m_buckets(buckets), m_numBuckets(numBuckets) {}
+ ~ReleaseOldBuffer() {
+#ifndef NDEBUG
+ memset((void *)m_buckets, 0x5a, sizeof(BucketT) * m_numBuckets);
+#endif
+ // Free the old table.
+ size_t const alignment{alignof(BucketT)};
+ deallocate_buffer(static_cast<void *>(m_buckets),
+ sizeof(BucketT) * m_numBuckets, alignment);
+ }
+ };
+
+ // helper class to guarantee destruction of bucket content
+ class ReleaseOldBuckets {
+ BucketT *m_buckets;
+ unsigned m_numBuckets;
+
+ public:
+ ReleaseOldBuckets(BucketT *buckets, unsigned numBuckets)
+ : m_buckets(buckets), m_numBuckets(numBuckets) {}
+ ~ReleaseOldBuckets() {
+ for (BucketT *B = m_buckets, *E = m_buckets + m_numBuckets; B != E; ++B) {
+ if (B->isValueConstructed()) {
+ B->getSecond().~ValueT();
+ B->setValueConstructed(false);
+ }
+ if (B->isKeyConstructed()) {
+ B->getFirst().~KeyT();
+ B->setKeyConstructed(false);
+ }
+ }
+ }
+ };
+
public:
/// Return the approximate size (in bytes) of the actual map.
/// This is just the raw memory used by DenseMap.
/// If entries are pointers to objects, the size of the referenced objects
/// are not included.
- size_t getMemorySize() const {
- return getNumBuckets() * sizeof(BucketT);
- }
+ size...
[truncated]
|
@llvm/pr-subscribers-mlir-core Author: Marc Auberer (marcauberer) ChangesThis patch hardens DenseMap and DenseSet against OOM scenarios, outlined in this RFC. Feel free to comment any comments/suggestions/objections there. As discussed in the RFC, we would eventually like to have malfunction tests to ensure malfunction-safety. This is still in internal discussion on our side. Therefore we would like to contribute the improvements with the existing functional tests for now. Patch is 53.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/107251.diff 4 Files Affected:
diff --git a/llvm/include/llvm/ADT/DenseMap.h b/llvm/include/llvm/ADT/DenseMap.h
index 00290c9dd0a585..6cc857f906e910 100644
--- a/llvm/include/llvm/ADT/DenseMap.h
+++ b/llvm/include/llvm/ADT/DenseMap.h
@@ -16,6 +16,7 @@
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/EpochTracker.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/AlignOf.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/MathExtras.h"
@@ -36,8 +37,6 @@ namespace llvm {
namespace detail {
-// We extend a pair to allow users to override the bucket type with their own
-// implementation without requiring two members.
template <typename KeyT, typename ValueT>
struct DenseMapPair : public std::pair<KeyT, ValueT> {
using std::pair<KeyT, ValueT>::pair;
@@ -48,16 +47,72 @@ struct DenseMapPair : public std::pair<KeyT, ValueT> {
const ValueT &getSecond() const { return std::pair<KeyT, ValueT>::second; }
};
+// OOM safe DenseMap bucket
+// flags to indicate if key and value are constructed
+// this is necessary since assigning empty key or tombstone key throws for some
+// key types
+
+template <typename KeyT, typename ValueT, typename Enable = void>
+struct DenseMapPairImpl;
+
+// part I: helpers to decide which specialization of DenseMapPairImpl to use
+template <typename KeyT, typename ValueT>
+using EnableIfTriviallyDestructibleMap =
+ typename std::enable_if_t<std::is_trivially_destructible_v<KeyT> &&
+ std::is_trivially_destructible_v<ValueT>>;
+template <typename KeyT, typename ValueT>
+using EnableIfNotTriviallyDestructibleMap =
+ typename std::enable_if_t<!std::is_trivially_destructible_v<KeyT> ||
+ !std::is_trivially_destructible_v<ValueT>>;
+
+// OOM safe DenseMap bucket
+// part II: specialization for trivially destructible types
+template <typename KeyT, typename ValueT>
+struct DenseMapPairImpl<KeyT, ValueT,
+ EnableIfTriviallyDestructibleMap<KeyT, ValueT>>
+ : public DenseMapPair<KeyT, ValueT> {
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool) {}
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const { return false; }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool) {}
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const { return false; }
+};
+
+// OOM safe DenseMap bucket
+// part III: specialization for non-trivially destructible types
+template <typename KeyT, typename ValueT>
+struct DenseMapPairImpl<KeyT, ValueT,
+ EnableIfNotTriviallyDestructibleMap<KeyT, ValueT>>
+ : DenseMapPair<KeyT, ValueT> {
+private:
+ bool m_keyConstructed;
+ bool m_valueConstructed;
+
+public:
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool constructed) {
+ m_keyConstructed = constructed;
+ }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const {
+ return m_keyConstructed;
+ }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool constructed) {
+ m_valueConstructed = constructed;
+ }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const {
+ return m_valueConstructed;
+ }
+};
+
} // end namespace detail
template <typename KeyT, typename ValueT,
typename KeyInfoT = DenseMapInfo<KeyT>,
- typename Bucket = llvm::detail::DenseMapPair<KeyT, ValueT>,
+ typename BucketT = llvm::detail::DenseMapPairImpl<KeyT, ValueT>,
+ typename BucketBaseT = llvm::detail::DenseMapPair<KeyT, ValueT>,
bool IsConst = false>
class DenseMapIterator;
template <typename DerivedT, typename KeyT, typename ValueT, typename KeyInfoT,
- typename BucketT>
+ typename BucketT, typename BucketBaseT>
class DenseMapBase : public DebugEpochBase {
template <typename T>
using const_arg_type_t = typename const_pointer_or_const_ref<T>::type;
@@ -66,11 +121,12 @@ class DenseMapBase : public DebugEpochBase {
using size_type = unsigned;
using key_type = KeyT;
using mapped_type = ValueT;
- using value_type = BucketT;
+ using value_type = BucketBaseT;
- using iterator = DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT>;
+ using iterator =
+ DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT>;
using const_iterator =
- DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, true>;
+ DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT, true>;
inline iterator begin() {
// When the map is empty, avoid the overhead of advancing/retreating past
@@ -109,7 +165,8 @@ class DenseMapBase : public DebugEpochBase {
void clear() {
incrementEpoch();
- if (getNumEntries() == 0 && getNumTombstones() == 0) return;
+ if (getNumEntries() == 0 && getNumTombstones() == 0)
+ return;
// If the capacity of the array is huge, and the # elements used is small,
// shrink the array.
@@ -119,24 +176,20 @@ class DenseMapBase : public DebugEpochBase {
}
const KeyT EmptyKey = getEmptyKey();
- if (std::is_trivially_destructible<ValueT>::value) {
+ if constexpr (std::is_trivially_destructible_v<ValueT>) {
// Use a simpler loop when values don't need destruction.
for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P)
P->getFirst() = EmptyKey;
} else {
- const KeyT TombstoneKey = getTombstoneKey();
- unsigned NumEntries = getNumEntries();
for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P) {
if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey)) {
- if (!KeyInfoT::isEqual(P->getFirst(), TombstoneKey)) {
+ if (P->isValueConstructed()) {
P->getSecond().~ValueT();
- --NumEntries;
+ P->setValueConstructed(false);
}
P->getFirst() = EmptyKey;
}
}
- assert(NumEntries == 0 && "Node count imbalance!");
- (void)NumEntries;
}
setNumEntries(0);
setNumTombstones(0);
@@ -172,15 +225,14 @@ class DenseMapBase : public DebugEpochBase {
/// The DenseMapInfo is responsible for supplying methods
/// getHashValue(LookupKeyT) and isEqual(LookupKeyT, KeyT) for each key
/// type used.
- template<class LookupKeyT>
- iterator find_as(const LookupKeyT &Val) {
+ template <class LookupKeyT> iterator find_as(const LookupKeyT &Val) {
if (BucketT *Bucket = doFind(Val))
return makeIterator(
Bucket, shouldReverseIterate<KeyT>() ? getBuckets() : getBucketsEnd(),
*this, true);
return end();
}
- template<class LookupKeyT>
+ template <class LookupKeyT>
const_iterator find_as(const LookupKeyT &Val) const {
if (const BucketT *Bucket = doFind(Val))
return makeConstIterator(
@@ -223,7 +275,7 @@ class DenseMapBase : public DebugEpochBase {
// The value is constructed in-place if the key is not in the map, otherwise
// it is not moved.
template <typename... Ts>
- std::pair<iterator, bool> try_emplace(KeyT &&Key, Ts &&... Args) {
+ std::pair<iterator, bool> try_emplace(KeyT &&Key, Ts &&...Args) {
BucketT *TheBucket;
if (LookupBucketFor(Key, TheBucket))
return std::make_pair(makeIterator(TheBucket,
@@ -248,7 +300,7 @@ class DenseMapBase : public DebugEpochBase {
// The value is constructed in-place if the key is not in the map, otherwise
// it is not moved.
template <typename... Ts>
- std::pair<iterator, bool> try_emplace(const KeyT &Key, Ts &&... Args) {
+ std::pair<iterator, bool> try_emplace(const KeyT &Key, Ts &&...Args) {
BucketT *TheBucket;
if (LookupBucketFor(Key, TheBucket))
return std::make_pair(makeIterator(TheBucket,
@@ -297,8 +349,7 @@ class DenseMapBase : public DebugEpochBase {
}
/// insert - Range insertion of pairs.
- template<typename InputIt>
- void insert(InputIt I, InputIt E) {
+ template <typename InputIt> void insert(InputIt I, InputIt E) {
for (; I != E; ++I)
insert(*I);
}
@@ -341,17 +392,22 @@ class DenseMapBase : public DebugEpochBase {
return false; // not in map.
TheBucket->getSecond().~ValueT();
+ TheBucket->setValueConstructed(false);
TheBucket->getFirst() = getTombstoneKey();
decrementNumEntries();
incrementNumTombstones();
return true;
}
void erase(iterator I) {
- BucketT *TheBucket = &*I;
- TheBucket->getSecond().~ValueT();
- TheBucket->getFirst() = getTombstoneKey();
- decrementNumEntries();
- incrementNumTombstones();
+ BucketT *TheBucket = static_cast<BucketT *>(&*I);
+ // Iterator can point to nullptr in case of memory malfunctions
+ if (TheBucket != nullptr) {
+ TheBucket->getSecond().~ValueT();
+ TheBucket->setValueConstructed(false);
+ TheBucket->getFirst() = getTombstoneKey();
+ decrementNumEntries();
+ incrementNumTombstones();
+ }
}
LLVM_DEPRECATED("Use [Key] instead", "[Key]")
@@ -404,15 +460,24 @@ class DenseMapBase : public DebugEpochBase {
DenseMapBase() = default;
void destroyAll() {
- if (getNumBuckets() == 0) // Nothing to do.
- return;
-
- const KeyT EmptyKey = getEmptyKey(), TombstoneKey = getTombstoneKey();
for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P) {
- if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey) &&
- !KeyInfoT::isEqual(P->getFirst(), TombstoneKey))
+ if (P->isValueConstructed()) {
P->getSecond().~ValueT();
- P->getFirst().~KeyT();
+ P->setValueConstructed(false);
+ }
+ if (P->isKeyConstructed()) {
+ P->getFirst().~KeyT();
+ P->setKeyConstructed(false);
+ }
+ }
+ }
+
+ void initUnitialized() {
+ BucketT *B = getBuckets();
+ BucketT *E = getBucketsEnd();
+ for (; B != E; ++B) {
+ B->setKeyConstructed(false);
+ B->setValueConstructed(false);
}
}
@@ -420,11 +485,19 @@ class DenseMapBase : public DebugEpochBase {
setNumEntries(0);
setNumTombstones(0);
- assert((getNumBuckets() & (getNumBuckets()-1)) == 0 &&
+ assert((getNumBuckets() & (getNumBuckets() - 1)) == 0 &&
"# initial buckets must be a power of two!");
const KeyT EmptyKey = getEmptyKey();
- for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B)
+#ifndef NDEBUG
+ for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
+ assert(!B->isKeyConstructed());
+ assert(!B->isValueConstructed());
+ }
+#endif
+ for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
::new (&B->getFirst()) KeyT(EmptyKey);
+ B->setKeyConstructed(true);
+ }
}
/// Returns the number of buckets to allocate to ensure that the DenseMap can
@@ -454,18 +527,23 @@ class DenseMapBase : public DebugEpochBase {
assert(!FoundVal && "Key already in new map?");
DestBucket->getFirst() = std::move(B->getFirst());
::new (&DestBucket->getSecond()) ValueT(std::move(B->getSecond()));
+ DestBucket->setValueConstructed(true);
incrementNumEntries();
-
- // Free the value.
+ }
+ if (B->isValueConstructed()) {
B->getSecond().~ValueT();
+ B->setValueConstructed(false);
+ }
+ if (B->isKeyConstructed()) {
+ B->getFirst().~KeyT();
+ B->setKeyConstructed(false);
}
- B->getFirst().~KeyT();
}
}
template <typename OtherBaseT>
- void copyFrom(
- const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT> &other) {
+ void copyFrom(const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT,
+ BucketBaseT> &other) {
assert(&other != this);
assert(getNumBuckets() == other.getNumBuckets());
@@ -480,10 +558,13 @@ class DenseMapBase : public DebugEpochBase {
for (size_t i = 0; i < getNumBuckets(); ++i) {
::new (&getBuckets()[i].getFirst())
KeyT(other.getBuckets()[i].getFirst());
+ getBuckets()[i].setKeyConstructed(true);
if (!KeyInfoT::isEqual(getBuckets()[i].getFirst(), getEmptyKey()) &&
- !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey()))
+ !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey())) {
::new (&getBuckets()[i].getSecond())
ValueT(other.getBuckets()[i].getSecond());
+ getBuckets()[i].setValueConstructed(true);
+ }
}
}
@@ -491,7 +572,7 @@ class DenseMapBase : public DebugEpochBase {
return KeyInfoT::getHashValue(Val);
}
- template<typename LookupKeyT>
+ template <typename LookupKeyT>
static unsigned getHashValue(const LookupKeyT &Val) {
return KeyInfoT::getHashValue(Val);
}
@@ -502,14 +583,11 @@ class DenseMapBase : public DebugEpochBase {
return KeyInfoT::getEmptyKey();
}
- static const KeyT getTombstoneKey() {
- return KeyInfoT::getTombstoneKey();
- }
+ static const KeyT getTombstoneKey() { return KeyInfoT::getTombstoneKey(); }
private:
- iterator makeIterator(BucketT *P, BucketT *E,
- DebugEpochBase &Epoch,
- bool NoAdvance=false) {
+ iterator makeIterator(BucketT *P, BucketT *E, DebugEpochBase &Epoch,
+ bool NoAdvance = false) {
if (shouldReverseIterate<KeyT>()) {
BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1;
return iterator(B, E, Epoch, NoAdvance);
@@ -519,7 +597,7 @@ class DenseMapBase : public DebugEpochBase {
const_iterator makeConstIterator(const BucketT *P, const BucketT *E,
const DebugEpochBase &Epoch,
- const bool NoAdvance=false) const {
+ const bool NoAdvance = false) const {
if (shouldReverseIterate<KeyT>()) {
const BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1;
return const_iterator(B, E, Epoch, NoAdvance);
@@ -535,13 +613,9 @@ class DenseMapBase : public DebugEpochBase {
static_cast<DerivedT *>(this)->setNumEntries(Num);
}
- void incrementNumEntries() {
- setNumEntries(getNumEntries() + 1);
- }
+ void incrementNumEntries() { setNumEntries(getNumEntries() + 1); }
- void decrementNumEntries() {
- setNumEntries(getNumEntries() - 1);
- }
+ void decrementNumEntries() { setNumEntries(getNumEntries() - 1); }
unsigned getNumTombstones() const {
return static_cast<const DerivedT *>(this)->getNumTombstones();
@@ -551,65 +625,52 @@ class DenseMapBase : public DebugEpochBase {
static_cast<DerivedT *>(this)->setNumTombstones(Num);
}
- void incrementNumTombstones() {
- setNumTombstones(getNumTombstones() + 1);
- }
+ void incrementNumTombstones() { setNumTombstones(getNumTombstones() + 1); }
- void decrementNumTombstones() {
- setNumTombstones(getNumTombstones() - 1);
- }
+ void decrementNumTombstones() { setNumTombstones(getNumTombstones() - 1); }
const BucketT *getBuckets() const {
return static_cast<const DerivedT *>(this)->getBuckets();
}
- BucketT *getBuckets() {
- return static_cast<DerivedT *>(this)->getBuckets();
- }
+ BucketT *getBuckets() { return static_cast<DerivedT *>(this)->getBuckets(); }
unsigned getNumBuckets() const {
return static_cast<const DerivedT *>(this)->getNumBuckets();
}
- BucketT *getBucketsEnd() {
- return getBuckets() + getNumBuckets();
- }
+ BucketT *getBucketsEnd() { return getBuckets() + getNumBuckets(); }
const BucketT *getBucketsEnd() const {
return getBuckets() + getNumBuckets();
}
- void grow(unsigned AtLeast) {
- static_cast<DerivedT *>(this)->grow(AtLeast);
- }
+ void grow(unsigned AtLeast) { static_cast<DerivedT *>(this)->grow(AtLeast); }
- void shrink_and_clear() {
- static_cast<DerivedT *>(this)->shrink_and_clear();
- }
+ void shrink_and_clear() { static_cast<DerivedT *>(this)->shrink_and_clear(); }
template <typename KeyArg, typename... ValueArgs>
BucketT *InsertIntoBucket(BucketT *TheBucket, KeyArg &&Key,
- ValueArgs &&... Values) {
- TheBucket = InsertIntoBucketImpl(Key, Key, TheBucket);
-
+ ValueArgs &&...Values) {
+ TheBucket = InsertIntoBucketImpl(Key, TheBucket);
TheBucket->getFirst() = std::forward<KeyArg>(Key);
::new (&TheBucket->getSecond()) ValueT(std::forward<ValueArgs>(Values)...);
+ TheBucket->setValueConstructed(true);
return TheBucket;
}
template <typename LookupKeyT>
BucketT *InsertIntoBucketWithLookup(BucketT *TheBucket, KeyT &&Key,
ValueT &&Value, LookupKeyT &Lookup) {
- TheBucket = InsertIntoBucketImpl(Key, Lookup, TheBucket);
-
+ TheBucket = InsertIntoBucketImpl(Lookup, TheBucket);
TheBucket->getFirst() = std::move(Key);
::new (&TheBucket->getSecond()) ValueT(std::move(Value));
+ TheBucket->setValueConstructed(true);
return TheBucket;
}
template <typename LookupKeyT>
- BucketT *InsertIntoBucketImpl(const KeyT &Key, const LookupKeyT &Lookup,
- BucketT *TheBucket) {
+ BucketT *InsertIntoBucketImpl(const LookupKeyT &Lookup, BucketT *TheBucket) {
incrementEpoch();
// If the load of the hash table is more than 3/4, or if fewer than 1/8 of
@@ -627,8 +688,9 @@ class DenseMapBase : public DebugEpochBase {
this->grow(NumBuckets * 2);
LookupBucketFor(Lookup, TheBucket);
NumBuckets = getNumBuckets();
- } else if (LLVM_UNLIKELY(NumBuckets-(NewNumEntries+getNumTombstones()) <=
- NumBuckets/8)) {
+ } else if (LLVM_UNLIKELY(NumBuckets -
+ (NewNumEntries + getNumTombstones()) <=
+ NumBuckets / 8)) {
this->grow(NumBuckets);
LookupBucketFor(Lookup, TheBucket);
}
@@ -696,7 +758,7 @@ class DenseMapBase : public DebugEpochBase {
!KeyInfoT::isEqual(Val, TombstoneKey) &&
"Empty/Tombstone value shouldn't be inserted into map!");
- unsigned BucketNo = getHashValue(Val) & (NumBuckets-1);
+ unsigned BucketNo = getHashValue(Val) & (NumBuckets - 1);
unsigned ProbeAmt = 1;
while (true) {
BucketT *ThisBucket = BucketsPtr + BucketNo;
@@ -719,23 +781,63 @@ class DenseMapBase : public DebugEpochBase {
// prefer to return it than something that would require more probing.
if (KeyInfoT::isEqual(ThisBucket->getFirst(), TombstoneKey) &&
!FoundTombstone)
- FoundTombstone = ThisBucket; // Remember the first tombstone found.
+ FoundTombstone = ThisBucket; // Remember the first tombstone found.
// Otherwise, it's a hash collision or a tombstone, continue quadratic
// probing.
BucketNo += ProbeAmt++;
- BucketNo &= (NumBuckets-1);
+ BucketNo &= (NumBuckets - 1);
}
}
+protected:
+ // helper class to guarantee deallocation of buffer
+ class ReleaseOldBuffer {
+ BucketT *m_buckets;
+ unsigned m_numBuckets;
+
+ public:
+ ReleaseOldBuffer(BucketT *buckets, unsigned numBuckets)
+ : m_buckets(buckets), m_numBuckets(numBuckets) {}
+ ~ReleaseOldBuffer() {
+#ifndef NDEBUG
+ memset((void *)m_buckets, 0x5a, sizeof(BucketT) * m_numBuckets);
+#endif
+ // Free the old table.
+ size_t const alignment{alignof(BucketT)};
+ deallocate_buffer(static_cast<void *>(m_buckets),
+ sizeof(BucketT) * m_numBuckets, alignment);
+ }
+ };
+
+ // helper class to guarantee destruction of bucket content
+ class ReleaseOldBuckets {
+ BucketT *m_buckets;
+ unsigned m_numBuckets;
+
+ public:
+ ReleaseOldBuckets(BucketT *buckets, unsigned numBuckets)
+ : m_buckets(buckets), m_numBuckets(numBuckets) {}
+ ~ReleaseOldBuckets() {
+ for (BucketT *B = m_buckets, *E = m_buckets + m_numBuckets; B != E; ++B) {
+ if (B->isValueConstructed()) {
+ B->getSecond().~ValueT();
+ B->setValueConstructed(false);
+ }
+ if (B->isKeyConstructed()) {
+ B->getFirst().~KeyT();
+ B->setKeyConstructed(false);
+ }
+ }
+ }
+ };
+
public:
/// Return the approximate size (in bytes) of the actual map.
/// This is just the raw memory used by DenseMap.
/// If entries are pointers to objects, the size of the referenced objects
/// are not included.
- size_t getMemorySize() const {
- return getNumBuckets() * sizeof(BucketT);
- }
+ size...
[truncated]
|
Is there any way you could split your patch into several smaller ones? If you are going to work on pretty much all of After that, I would get small obvious changes out of the way. For example, your Thanks! |
Yes, agreed. |
This direction looks quite concerning to me. LLVM is explicitly not designed to be exception safe, and this change is not at all free. It introduces significant implementation complexity, compile-time impact and memory usage impact. The minimum requirement here is going to be that this does not have any compile-time/memory usage impact on people who do not need this functionality -- but I'm not convinced that we should accept this downstream at all, as a matter of policy. Major compile-time regressions: https://llvm-compile-time-tracker.com/compare.php?from=f4dd1bc8fc625d3938f95b9d06aaaeebd2e90dca&to=1cb62cb5fb8eeca9641992f98f1b24b328410034&stat=instructions:u Major memory usage regressions: https://llvm-compile-time-tracker.com/compare.php?from=f4dd1bc8fc625d3938f95b9d06aaaeebd2e90dca&to=1cb62cb5fb8eeca9641992f98f1b24b328410034&stat=max-rss |
+1 to what @nikic said. As an ADT contributor, the thing I really value about our implementation is that it's much simpler than alternative hash maps that have to worry about exceptions. I think it's safe to say that LLVM developers at large are not used to / trained to think about exception safety, and this puts burden on subsequent contributors who may not care about handling these specific OOM conditions. |
bb74000
to
6c9cc94
Compare
6c9cc94
to
7f437b1
Compare
This patch hardens DenseMap and DenseSet against OOM scenarios, outlined in this RFC. Feel free to comment any comments/suggestions/objections there.
As discussed in the RFC, we would eventually like to have malfunction tests to ensure malfunction-safety. This is still in internal discussion on our side. Therefore we would like to contribute the improvements with the existing functional tests for now.