[ADT] Make DenseMap/DenseSet more resilient against OOM situations #107251

Closed

@marcauberer marcauberer commented Sep 4, 2024

This patch hardens DenseMap and DenseSet against out-of-memory (OOM) scenarios, as outlined in this RFC. Feel free to leave comments, suggestions, or objections there.

As discussed in the RFC, we would eventually like to add malfunction tests to ensure malfunction safety. That is still under internal discussion on our side, so for now we would like to contribute the improvements with only the existing functional tests.
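
To illustrate the core idea for readers of the RFC, here is a minimal sketch of per-slot "constructed" flags. It is not the code in this patch: `Slot`, `Tracked`, and the two helper functions are hypothetical names, and the sketch tracks a single value rather than a key/value bucket. The point it shows is that cleanup consults a flag instead of inspecting the key, so a constructor that throws (for example on allocation failure) never leads to a destructor call on storage that was never constructed.

```cpp
#include <cassert>
#include <new>
#include <utility>

// Hypothetical slot type: raw storage plus a "constructed" flag.
template <typename V> class Slot {
  alignas(V) unsigned char Storage[sizeof(V)];
  bool Constructed = false;

public:
  Slot() = default;
  Slot(const Slot &) = delete;
  Slot &operator=(const Slot &) = delete;
  ~Slot() { reset(); }

  template <typename... Args> void emplace(Args &&...A) {
    assert(!Constructed && "slot already holds a value");
    // If the constructor throws, the flag stays false and cleanup skips it.
    ::new (static_cast<void *>(Storage)) V(std::forward<Args>(A)...);
    Constructed = true;
  }

  bool isConstructed() const { return Constructed; }

  // Destroy the value only if its lifetime actually began.
  void reset() {
    if (Constructed) {
      std::launder(reinterpret_cast<V *>(Storage))->~V();
      Constructed = false;
    }
  }
};

// Hypothetical test type that counts live objects and can fail to construct.
struct Tracked {
  static inline int Live = 0;
  explicit Tracked(bool Fail) {
    if (Fail)
      throw std::bad_alloc(); // simulated OOM inside the constructor
    ++Live;
  }
  Tracked(const Tracked &) = delete;
  ~Tracked() { --Live; }
};

// A failed emplace must leave the slot unconstructed and the count balanced.
int liveAfterFailedEmplace() {
  Slot<Tracked> S;
  try {
    S.emplace(true);
  } catch (const std::bad_alloc &) {
  }
  assert(!S.isConstructed());
  S.reset(); // no-op: nothing was constructed
  return Tracked::Live;
}

// A successful emplace is visible while the slot is alive.
int liveAfterSuccessfulEmplace() {
  Slot<Tracked> S;
  S.emplace(false);
  return Tracked::Live; // read before S is destroyed
}
```

The patch itself stores the flags in the bucket type and, judging by the diff, uses a specialization without flags when both key and value are trivially destructible, so those maps pay no size cost.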

@marcauberer marcauberer force-pushed the llvm/adt/malfunction-safe-densemap branch from ddf904f to 84e90db Compare September 4, 2024 14:57

github-actions bot commented Sep 4, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@marcauberer marcauberer force-pushed the llvm/adt/malfunction-safe-densemap branch 3 times, most recently from 07aaed1 to bb74000 Compare September 10, 2024 14:13
@marcauberer marcauberer marked this pull request as ready for review September 10, 2024 14:53
@llvmbot

llvmbot commented Sep 10, 2024

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-llvm-adt

Author: Marc Auberer (marcauberer)

Changes

This patch hardens DenseMap and DenseSet against out-of-memory (OOM) scenarios, as outlined in this RFC. Feel free to leave comments, suggestions, or objections there.

As discussed in the RFC, we would eventually like to add malfunction tests to ensure malfunction safety. That is still under internal discussion on our side, so for now we would like to contribute the improvements with only the existing functional tests.


Patch is 53.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/107251.diff

4 Files Affected:

  • (modified) llvm/include/llvm/ADT/DenseMap.h (+378-220)
  • (modified) llvm/include/llvm/ADT/DenseSet.h (+91-30)
  • (modified) llvm/include/llvm/TextAPI/SymbolSet.h (+8-6)
  • (modified) mlir/include/mlir/Support/LLVM.h (+7-3)
diff --git a/llvm/include/llvm/ADT/DenseMap.h b/llvm/include/llvm/ADT/DenseMap.h
index 00290c9dd0a585..6cc857f906e910 100644
--- a/llvm/include/llvm/ADT/DenseMap.h
+++ b/llvm/include/llvm/ADT/DenseMap.h
@@ -16,6 +16,7 @@
 
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/EpochTracker.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/AlignOf.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/MathExtras.h"
@@ -36,8 +37,6 @@ namespace llvm {
 
 namespace detail {
 
-// We extend a pair to allow users to override the bucket type with their own
-// implementation without requiring two members.
 template <typename KeyT, typename ValueT>
 struct DenseMapPair : public std::pair<KeyT, ValueT> {
   using std::pair<KeyT, ValueT>::pair;
@@ -48,16 +47,72 @@ struct DenseMapPair : public std::pair<KeyT, ValueT> {
   const ValueT &getSecond() const { return std::pair<KeyT, ValueT>::second; }
 };
 
+// OOM safe DenseMap bucket
+// flags to indicate if key and value are constructed
+// this is necessary since assigning empty key or tombstone key throws for some
+// key types
+
+template <typename KeyT, typename ValueT, typename Enable = void>
+struct DenseMapPairImpl;
+
+// part I: helpers to decide which specialization of DenseMapPairImpl to use
+template <typename KeyT, typename ValueT>
+using EnableIfTriviallyDestructibleMap =
+    typename std::enable_if_t<std::is_trivially_destructible_v<KeyT> &&
+                              std::is_trivially_destructible_v<ValueT>>;
+template <typename KeyT, typename ValueT>
+using EnableIfNotTriviallyDestructibleMap =
+    typename std::enable_if_t<!std::is_trivially_destructible_v<KeyT> ||
+                              !std::is_trivially_destructible_v<ValueT>>;
+
+// OOM safe DenseMap bucket
+// part II: specialization for trivially destructible types
+template <typename KeyT, typename ValueT>
+struct DenseMapPairImpl<KeyT, ValueT,
+                        EnableIfTriviallyDestructibleMap<KeyT, ValueT>>
+    : public DenseMapPair<KeyT, ValueT> {
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const { return false; }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const { return false; }
+};
+
+// OOM safe DenseMap bucket
+// part III: specialization for non-trivially destructible types
+template <typename KeyT, typename ValueT>
+struct DenseMapPairImpl<KeyT, ValueT,
+                        EnableIfNotTriviallyDestructibleMap<KeyT, ValueT>>
+    : DenseMapPair<KeyT, ValueT> {
+private:
+  bool m_keyConstructed;
+  bool m_valueConstructed;
+
+public:
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool constructed) {
+    m_keyConstructed = constructed;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const {
+    return m_keyConstructed;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool constructed) {
+    m_valueConstructed = constructed;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const {
+    return m_valueConstructed;
+  }
+};
+
 } // end namespace detail
 
 template <typename KeyT, typename ValueT,
           typename KeyInfoT = DenseMapInfo<KeyT>,
-          typename Bucket = llvm::detail::DenseMapPair<KeyT, ValueT>,
+          typename BucketT = llvm::detail::DenseMapPairImpl<KeyT, ValueT>,
+          typename BucketBaseT = llvm::detail::DenseMapPair<KeyT, ValueT>,
           bool IsConst = false>
 class DenseMapIterator;
 
 template <typename DerivedT, typename KeyT, typename ValueT, typename KeyInfoT,
-          typename BucketT>
+          typename BucketT, typename BucketBaseT>
 class DenseMapBase : public DebugEpochBase {
   template <typename T>
   using const_arg_type_t = typename const_pointer_or_const_ref<T>::type;
@@ -66,11 +121,12 @@ class DenseMapBase : public DebugEpochBase {
   using size_type = unsigned;
   using key_type = KeyT;
   using mapped_type = ValueT;
-  using value_type = BucketT;
+  using value_type = BucketBaseT;
 
-  using iterator = DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT>;
+  using iterator =
+      DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT>;
   using const_iterator =
-      DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, true>;
+      DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT, true>;
 
   inline iterator begin() {
     // When the map is empty, avoid the overhead of advancing/retreating past
@@ -109,7 +165,8 @@ class DenseMapBase : public DebugEpochBase {
 
   void clear() {
     incrementEpoch();
-    if (getNumEntries() == 0 && getNumTombstones() == 0) return;
+    if (getNumEntries() == 0 && getNumTombstones() == 0)
+      return;
 
     // If the capacity of the array is huge, and the # elements used is small,
     // shrink the array.
@@ -119,24 +176,20 @@ class DenseMapBase : public DebugEpochBase {
     }
 
     const KeyT EmptyKey = getEmptyKey();
-    if (std::is_trivially_destructible<ValueT>::value) {
+    if constexpr (std::is_trivially_destructible_v<ValueT>) {
       // Use a simpler loop when values don't need destruction.
       for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P)
         P->getFirst() = EmptyKey;
     } else {
-      const KeyT TombstoneKey = getTombstoneKey();
-      unsigned NumEntries = getNumEntries();
       for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P) {
         if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey)) {
-          if (!KeyInfoT::isEqual(P->getFirst(), TombstoneKey)) {
+          if (P->isValueConstructed()) {
             P->getSecond().~ValueT();
-            --NumEntries;
+            P->setValueConstructed(false);
           }
           P->getFirst() = EmptyKey;
         }
       }
-      assert(NumEntries == 0 && "Node count imbalance!");
-      (void)NumEntries;
     }
     setNumEntries(0);
     setNumTombstones(0);
@@ -172,15 +225,14 @@ class DenseMapBase : public DebugEpochBase {
   /// The DenseMapInfo is responsible for supplying methods
   /// getHashValue(LookupKeyT) and isEqual(LookupKeyT, KeyT) for each key
   /// type used.
-  template<class LookupKeyT>
-  iterator find_as(const LookupKeyT &Val) {
+  template <class LookupKeyT> iterator find_as(const LookupKeyT &Val) {
     if (BucketT *Bucket = doFind(Val))
       return makeIterator(
           Bucket, shouldReverseIterate<KeyT>() ? getBuckets() : getBucketsEnd(),
           *this, true);
     return end();
   }
-  template<class LookupKeyT>
+  template <class LookupKeyT>
   const_iterator find_as(const LookupKeyT &Val) const {
     if (const BucketT *Bucket = doFind(Val))
       return makeConstIterator(
@@ -223,7 +275,7 @@ class DenseMapBase : public DebugEpochBase {
   // The value is constructed in-place if the key is not in the map, otherwise
   // it is not moved.
   template <typename... Ts>
-  std::pair<iterator, bool> try_emplace(KeyT &&Key, Ts &&... Args) {
+  std::pair<iterator, bool> try_emplace(KeyT &&Key, Ts &&...Args) {
     BucketT *TheBucket;
     if (LookupBucketFor(Key, TheBucket))
       return std::make_pair(makeIterator(TheBucket,
@@ -248,7 +300,7 @@ class DenseMapBase : public DebugEpochBase {
   // The value is constructed in-place if the key is not in the map, otherwise
   // it is not moved.
   template <typename... Ts>
-  std::pair<iterator, bool> try_emplace(const KeyT &Key, Ts &&... Args) {
+  std::pair<iterator, bool> try_emplace(const KeyT &Key, Ts &&...Args) {
     BucketT *TheBucket;
     if (LookupBucketFor(Key, TheBucket))
       return std::make_pair(makeIterator(TheBucket,
@@ -297,8 +349,7 @@ class DenseMapBase : public DebugEpochBase {
   }
 
   /// insert - Range insertion of pairs.
-  template<typename InputIt>
-  void insert(InputIt I, InputIt E) {
+  template <typename InputIt> void insert(InputIt I, InputIt E) {
     for (; I != E; ++I)
       insert(*I);
   }
@@ -341,17 +392,22 @@ class DenseMapBase : public DebugEpochBase {
       return false; // not in map.
 
     TheBucket->getSecond().~ValueT();
+    TheBucket->setValueConstructed(false);
     TheBucket->getFirst() = getTombstoneKey();
     decrementNumEntries();
     incrementNumTombstones();
     return true;
   }
   void erase(iterator I) {
-    BucketT *TheBucket = &*I;
-    TheBucket->getSecond().~ValueT();
-    TheBucket->getFirst() = getTombstoneKey();
-    decrementNumEntries();
-    incrementNumTombstones();
+    BucketT *TheBucket = static_cast<BucketT *>(&*I);
+    // Iterator can point to nullptr in case of memory malfunctions
+    if (TheBucket != nullptr) {
+      TheBucket->getSecond().~ValueT();
+      TheBucket->setValueConstructed(false);
+      TheBucket->getFirst() = getTombstoneKey();
+      decrementNumEntries();
+      incrementNumTombstones();
+    }
   }
 
   LLVM_DEPRECATED("Use [Key] instead", "[Key]")
@@ -404,15 +460,24 @@ class DenseMapBase : public DebugEpochBase {
   DenseMapBase() = default;
 
   void destroyAll() {
-    if (getNumBuckets() == 0) // Nothing to do.
-      return;
-
-    const KeyT EmptyKey = getEmptyKey(), TombstoneKey = getTombstoneKey();
     for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P) {
-      if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey) &&
-          !KeyInfoT::isEqual(P->getFirst(), TombstoneKey))
+      if (P->isValueConstructed()) {
         P->getSecond().~ValueT();
-      P->getFirst().~KeyT();
+        P->setValueConstructed(false);
+      }
+      if (P->isKeyConstructed()) {
+        P->getFirst().~KeyT();
+        P->setKeyConstructed(false);
+      }
+    }
+  }
+
+  void initUnitialized() {
+    BucketT *B = getBuckets();
+    BucketT *E = getBucketsEnd();
+    for (; B != E; ++B) {
+      B->setKeyConstructed(false);
+      B->setValueConstructed(false);
     }
   }
 
@@ -420,11 +485,19 @@ class DenseMapBase : public DebugEpochBase {
     setNumEntries(0);
     setNumTombstones(0);
 
-    assert((getNumBuckets() & (getNumBuckets()-1)) == 0 &&
+    assert((getNumBuckets() & (getNumBuckets() - 1)) == 0 &&
            "# initial buckets must be a power of two!");
     const KeyT EmptyKey = getEmptyKey();
-    for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B)
+#ifndef NDEBUG
+    for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
+      assert(!B->isKeyConstructed());
+      assert(!B->isValueConstructed());
+    }
+#endif
+    for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
       ::new (&B->getFirst()) KeyT(EmptyKey);
+      B->setKeyConstructed(true);
+    }
   }
 
   /// Returns the number of buckets to allocate to ensure that the DenseMap can
@@ -454,18 +527,23 @@ class DenseMapBase : public DebugEpochBase {
         assert(!FoundVal && "Key already in new map?");
         DestBucket->getFirst() = std::move(B->getFirst());
         ::new (&DestBucket->getSecond()) ValueT(std::move(B->getSecond()));
+        DestBucket->setValueConstructed(true);
         incrementNumEntries();
-
-        // Free the value.
+      }
+      if (B->isValueConstructed()) {
         B->getSecond().~ValueT();
+        B->setValueConstructed(false);
+      }
+      if (B->isKeyConstructed()) {
+        B->getFirst().~KeyT();
+        B->setKeyConstructed(false);
       }
-      B->getFirst().~KeyT();
     }
   }
 
   template <typename OtherBaseT>
-  void copyFrom(
-      const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT> &other) {
+  void copyFrom(const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT,
+                                   BucketBaseT> &other) {
     assert(&other != this);
     assert(getNumBuckets() == other.getNumBuckets());
 
@@ -480,10 +558,13 @@ class DenseMapBase : public DebugEpochBase {
       for (size_t i = 0; i < getNumBuckets(); ++i) {
         ::new (&getBuckets()[i].getFirst())
             KeyT(other.getBuckets()[i].getFirst());
+        getBuckets()[i].setKeyConstructed(true);
         if (!KeyInfoT::isEqual(getBuckets()[i].getFirst(), getEmptyKey()) &&
-            !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey()))
+            !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey())) {
           ::new (&getBuckets()[i].getSecond())
               ValueT(other.getBuckets()[i].getSecond());
+          getBuckets()[i].setValueConstructed(true);
+        }
       }
   }
 
@@ -491,7 +572,7 @@ class DenseMapBase : public DebugEpochBase {
     return KeyInfoT::getHashValue(Val);
   }
 
-  template<typename LookupKeyT>
+  template <typename LookupKeyT>
   static unsigned getHashValue(const LookupKeyT &Val) {
     return KeyInfoT::getHashValue(Val);
   }
@@ -502,14 +583,11 @@ class DenseMapBase : public DebugEpochBase {
     return KeyInfoT::getEmptyKey();
   }
 
-  static const KeyT getTombstoneKey() {
-    return KeyInfoT::getTombstoneKey();
-  }
+  static const KeyT getTombstoneKey() { return KeyInfoT::getTombstoneKey(); }
 
 private:
-  iterator makeIterator(BucketT *P, BucketT *E,
-                        DebugEpochBase &Epoch,
-                        bool NoAdvance=false) {
+  iterator makeIterator(BucketT *P, BucketT *E, DebugEpochBase &Epoch,
+                        bool NoAdvance = false) {
     if (shouldReverseIterate<KeyT>()) {
       BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1;
       return iterator(B, E, Epoch, NoAdvance);
@@ -519,7 +597,7 @@ class DenseMapBase : public DebugEpochBase {
 
   const_iterator makeConstIterator(const BucketT *P, const BucketT *E,
                                    const DebugEpochBase &Epoch,
-                                   const bool NoAdvance=false) const {
+                                   const bool NoAdvance = false) const {
     if (shouldReverseIterate<KeyT>()) {
       const BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1;
       return const_iterator(B, E, Epoch, NoAdvance);
@@ -535,13 +613,9 @@ class DenseMapBase : public DebugEpochBase {
     static_cast<DerivedT *>(this)->setNumEntries(Num);
   }
 
-  void incrementNumEntries() {
-    setNumEntries(getNumEntries() + 1);
-  }
+  void incrementNumEntries() { setNumEntries(getNumEntries() + 1); }
 
-  void decrementNumEntries() {
-    setNumEntries(getNumEntries() - 1);
-  }
+  void decrementNumEntries() { setNumEntries(getNumEntries() - 1); }
 
   unsigned getNumTombstones() const {
     return static_cast<const DerivedT *>(this)->getNumTombstones();
@@ -551,65 +625,52 @@ class DenseMapBase : public DebugEpochBase {
     static_cast<DerivedT *>(this)->setNumTombstones(Num);
   }
 
-  void incrementNumTombstones() {
-    setNumTombstones(getNumTombstones() + 1);
-  }
+  void incrementNumTombstones() { setNumTombstones(getNumTombstones() + 1); }
 
-  void decrementNumTombstones() {
-    setNumTombstones(getNumTombstones() - 1);
-  }
+  void decrementNumTombstones() { setNumTombstones(getNumTombstones() - 1); }
 
   const BucketT *getBuckets() const {
     return static_cast<const DerivedT *>(this)->getBuckets();
   }
 
-  BucketT *getBuckets() {
-    return static_cast<DerivedT *>(this)->getBuckets();
-  }
+  BucketT *getBuckets() { return static_cast<DerivedT *>(this)->getBuckets(); }
 
   unsigned getNumBuckets() const {
     return static_cast<const DerivedT *>(this)->getNumBuckets();
   }
 
-  BucketT *getBucketsEnd() {
-    return getBuckets() + getNumBuckets();
-  }
+  BucketT *getBucketsEnd() { return getBuckets() + getNumBuckets(); }
 
   const BucketT *getBucketsEnd() const {
     return getBuckets() + getNumBuckets();
   }
 
-  void grow(unsigned AtLeast) {
-    static_cast<DerivedT *>(this)->grow(AtLeast);
-  }
+  void grow(unsigned AtLeast) { static_cast<DerivedT *>(this)->grow(AtLeast); }
 
-  void shrink_and_clear() {
-    static_cast<DerivedT *>(this)->shrink_and_clear();
-  }
+  void shrink_and_clear() { static_cast<DerivedT *>(this)->shrink_and_clear(); }
 
   template <typename KeyArg, typename... ValueArgs>
   BucketT *InsertIntoBucket(BucketT *TheBucket, KeyArg &&Key,
-                            ValueArgs &&... Values) {
-    TheBucket = InsertIntoBucketImpl(Key, Key, TheBucket);
-
+                            ValueArgs &&...Values) {
+    TheBucket = InsertIntoBucketImpl(Key, TheBucket);
     TheBucket->getFirst() = std::forward<KeyArg>(Key);
     ::new (&TheBucket->getSecond()) ValueT(std::forward<ValueArgs>(Values)...);
+    TheBucket->setValueConstructed(true);
     return TheBucket;
   }
 
   template <typename LookupKeyT>
   BucketT *InsertIntoBucketWithLookup(BucketT *TheBucket, KeyT &&Key,
                                       ValueT &&Value, LookupKeyT &Lookup) {
-    TheBucket = InsertIntoBucketImpl(Key, Lookup, TheBucket);
-
+    TheBucket = InsertIntoBucketImpl(Lookup, TheBucket);
     TheBucket->getFirst() = std::move(Key);
     ::new (&TheBucket->getSecond()) ValueT(std::move(Value));
+    TheBucket->setValueConstructed(true);
     return TheBucket;
   }
 
   template <typename LookupKeyT>
-  BucketT *InsertIntoBucketImpl(const KeyT &Key, const LookupKeyT &Lookup,
-                                BucketT *TheBucket) {
+  BucketT *InsertIntoBucketImpl(const LookupKeyT &Lookup, BucketT *TheBucket) {
     incrementEpoch();
 
     // If the load of the hash table is more than 3/4, or if fewer than 1/8 of
@@ -627,8 +688,9 @@ class DenseMapBase : public DebugEpochBase {
       this->grow(NumBuckets * 2);
       LookupBucketFor(Lookup, TheBucket);
       NumBuckets = getNumBuckets();
-    } else if (LLVM_UNLIKELY(NumBuckets-(NewNumEntries+getNumTombstones()) <=
-                             NumBuckets/8)) {
+    } else if (LLVM_UNLIKELY(NumBuckets -
+                                 (NewNumEntries + getNumTombstones()) <=
+                             NumBuckets / 8)) {
       this->grow(NumBuckets);
       LookupBucketFor(Lookup, TheBucket);
     }
@@ -696,7 +758,7 @@ class DenseMapBase : public DebugEpochBase {
            !KeyInfoT::isEqual(Val, TombstoneKey) &&
            "Empty/Tombstone value shouldn't be inserted into map!");
 
-    unsigned BucketNo = getHashValue(Val) & (NumBuckets-1);
+    unsigned BucketNo = getHashValue(Val) & (NumBuckets - 1);
     unsigned ProbeAmt = 1;
     while (true) {
       BucketT *ThisBucket = BucketsPtr + BucketNo;
@@ -719,23 +781,63 @@ class DenseMapBase : public DebugEpochBase {
       // prefer to return it than something that would require more probing.
       if (KeyInfoT::isEqual(ThisBucket->getFirst(), TombstoneKey) &&
           !FoundTombstone)
-        FoundTombstone = ThisBucket;  // Remember the first tombstone found.
+        FoundTombstone = ThisBucket; // Remember the first tombstone found.
 
       // Otherwise, it's a hash collision or a tombstone, continue quadratic
       // probing.
       BucketNo += ProbeAmt++;
-      BucketNo &= (NumBuckets-1);
+      BucketNo &= (NumBuckets - 1);
     }
   }
 
+protected:
+  // helper class to guarantee deallocation of buffer
+  class ReleaseOldBuffer {
+    BucketT *m_buckets;
+    unsigned m_numBuckets;
+
+  public:
+    ReleaseOldBuffer(BucketT *buckets, unsigned numBuckets)
+        : m_buckets(buckets), m_numBuckets(numBuckets) {}
+    ~ReleaseOldBuffer() {
+#ifndef NDEBUG
+      memset((void *)m_buckets, 0x5a, sizeof(BucketT) * m_numBuckets);
+#endif
+      // Free the old table.
+      size_t const alignment{alignof(BucketT)};
+      deallocate_buffer(static_cast<void *>(m_buckets),
+                        sizeof(BucketT) * m_numBuckets, alignment);
+    }
+  };
+
+  // helper class to guarantee destruction of bucket content
+  class ReleaseOldBuckets {
+    BucketT *m_buckets;
+    unsigned m_numBuckets;
+
+  public:
+    ReleaseOldBuckets(BucketT *buckets, unsigned numBuckets)
+        : m_buckets(buckets), m_numBuckets(numBuckets) {}
+    ~ReleaseOldBuckets() {
+      for (BucketT *B = m_buckets, *E = m_buckets + m_numBuckets; B != E; ++B) {
+        if (B->isValueConstructed()) {
+          B->getSecond().~ValueT();
+          B->setValueConstructed(false);
+        }
+        if (B->isKeyConstructed()) {
+          B->getFirst().~KeyT();
+          B->setKeyConstructed(false);
+        }
+      }
+    }
+  };
+
 public:
   /// Return the approximate size (in bytes) of the actual map.
   /// This is just the raw memory used by DenseMap.
   /// If entries are pointers to objects, the size of the referenced objects
   /// are not included.
-  size_t getMemorySize() const {
-    return getNumBuckets() * sizeof(BucketT);
-  }
+  size...
[truncated]
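
The `ReleaseOldBuffer` helper at the end of the visible diff is described as guaranteeing deallocation of a buffer. Since the rest of the patch is truncated here, the following is only a sketch of how such a guard can be used during a rehash, under stated assumptions: `OldBufferGuard`, `TinyTable`, the `FailDuringMove` switch, and the two helper functions are hypothetical, the buckets are plain `int`s, and a failed move leaves the table empty but valid, which may differ from what the patch does.

```cpp
#include <cstddef>
#include <new>

// Hypothetical guard: frees the old bucket array on scope exit,
// whether the scope is left normally or by an exception.
class OldBufferGuard {
  int *Buffer;

public:
  static inline int Released = 0; // instrumentation for the demo only
  explicit OldBufferGuard(int *B) : Buffer(B) {}
  OldBufferGuard(const OldBufferGuard &) = delete;
  OldBufferGuard &operator=(const OldBufferGuard &) = delete;
  ~OldBufferGuard() {
    if (Buffer) {
      delete[] Buffer;
      ++Released;
    }
  }
};

// Hypothetical table with a single growable bucket array.
struct TinyTable {
  int *Buckets = nullptr;
  std::size_t NumBuckets = 0;

  TinyTable() = default;
  TinyTable(const TinyTable &) = delete;
  TinyTable &operator=(const TinyTable &) = delete;
  ~TinyTable() { delete[] Buckets; }

  // Precondition: NewCount >= NumBuckets.
  void grow(std::size_t NewCount, bool FailDuringMove) {
    // Allocate first: if this throws, the table is untouched.
    int *Fresh = new int[NewCount](); // zero-initialized
    int *Old = Buckets;
    std::size_t OldCount = NumBuckets;
    OldBufferGuard ReleaseOld(Old); // Old is freed on every exit path
    Buckets = Fresh;
    NumBuckets = NewCount;
    if (FailDuringMove)
      throw std::bad_alloc(); // simulated failure while moving entries
    for (std::size_t I = 0; I != OldCount; ++I)
      Fresh[I] = Old[I];
  }
};

// Counts how many old buffers were freed by a grow that fails midway.
int releasedDuringFailedGrow() {
  TinyTable T;
  T.grow(4, false);
  T.Buckets[0] = 7;
  int Before = OldBufferGuard::Released;
  try {
    T.grow(8, true);
  } catch (const std::bad_alloc &) {
  }
  return OldBufferGuard::Released - Before;
}

// A successful grow carries existing entries over to the new array.
int firstAfterSuccessfulGrow() {
  TinyTable T;
  T.grow(4, false);
  T.Buckets[0] = 7;
  T.grow(8, false);
  return T.Buckets[0];
}
```

Tying the release to a destructor means no `catch` block inside `grow` has to remember to free the old array; the patch also adds a companion `ReleaseOldBuckets` helper that destroys bucket contents using the constructed flags.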

@llvmbot

llvmbot commented Sep 10, 2024

@llvm/pr-subscribers-mlir-core

+#ifndef NDEBUG
+    for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
+      assert(!B->isKeyConstructed());
+      assert(!B->isValueConstructed());
+    }
+#endif
+    for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
       ::new (&B->getFirst()) KeyT(EmptyKey);
+      B->setKeyConstructed(true);
+    }
   }
 
   /// Returns the number of buckets to allocate to ensure that the DenseMap can
@@ -454,18 +527,23 @@ class DenseMapBase : public DebugEpochBase {
         assert(!FoundVal && "Key already in new map?");
         DestBucket->getFirst() = std::move(B->getFirst());
         ::new (&DestBucket->getSecond()) ValueT(std::move(B->getSecond()));
+        DestBucket->setValueConstructed(true);
         incrementNumEntries();
-
-        // Free the value.
+      }
+      if (B->isValueConstructed()) {
         B->getSecond().~ValueT();
+        B->setValueConstructed(false);
+      }
+      if (B->isKeyConstructed()) {
+        B->getFirst().~KeyT();
+        B->setKeyConstructed(false);
       }
-      B->getFirst().~KeyT();
     }
   }
 
   template <typename OtherBaseT>
-  void copyFrom(
-      const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT> &other) {
+  void copyFrom(const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT,
+                                   BucketBaseT> &other) {
     assert(&other != this);
     assert(getNumBuckets() == other.getNumBuckets());
 
@@ -480,10 +558,13 @@ class DenseMapBase : public DebugEpochBase {
       for (size_t i = 0; i < getNumBuckets(); ++i) {
         ::new (&getBuckets()[i].getFirst())
             KeyT(other.getBuckets()[i].getFirst());
+        getBuckets()[i].setKeyConstructed(true);
         if (!KeyInfoT::isEqual(getBuckets()[i].getFirst(), getEmptyKey()) &&
-            !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey()))
+            !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey())) {
           ::new (&getBuckets()[i].getSecond())
               ValueT(other.getBuckets()[i].getSecond());
+          getBuckets()[i].setValueConstructed(true);
+        }
       }
   }
 
@@ -491,7 +572,7 @@ class DenseMapBase : public DebugEpochBase {
     return KeyInfoT::getHashValue(Val);
   }
 
-  template<typename LookupKeyT>
+  template <typename LookupKeyT>
   static unsigned getHashValue(const LookupKeyT &Val) {
     return KeyInfoT::getHashValue(Val);
   }
@@ -502,14 +583,11 @@ class DenseMapBase : public DebugEpochBase {
     return KeyInfoT::getEmptyKey();
   }
 
-  static const KeyT getTombstoneKey() {
-    return KeyInfoT::getTombstoneKey();
-  }
+  static const KeyT getTombstoneKey() { return KeyInfoT::getTombstoneKey(); }
 
 private:
-  iterator makeIterator(BucketT *P, BucketT *E,
-                        DebugEpochBase &Epoch,
-                        bool NoAdvance=false) {
+  iterator makeIterator(BucketT *P, BucketT *E, DebugEpochBase &Epoch,
+                        bool NoAdvance = false) {
     if (shouldReverseIterate<KeyT>()) {
       BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1;
       return iterator(B, E, Epoch, NoAdvance);
@@ -519,7 +597,7 @@ class DenseMapBase : public DebugEpochBase {
 
   const_iterator makeConstIterator(const BucketT *P, const BucketT *E,
                                    const DebugEpochBase &Epoch,
-                                   const bool NoAdvance=false) const {
+                                   const bool NoAdvance = false) const {
     if (shouldReverseIterate<KeyT>()) {
       const BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1;
       return const_iterator(B, E, Epoch, NoAdvance);
@@ -535,13 +613,9 @@ class DenseMapBase : public DebugEpochBase {
     static_cast<DerivedT *>(this)->setNumEntries(Num);
   }
 
-  void incrementNumEntries() {
-    setNumEntries(getNumEntries() + 1);
-  }
+  void incrementNumEntries() { setNumEntries(getNumEntries() + 1); }
 
-  void decrementNumEntries() {
-    setNumEntries(getNumEntries() - 1);
-  }
+  void decrementNumEntries() { setNumEntries(getNumEntries() - 1); }
 
   unsigned getNumTombstones() const {
     return static_cast<const DerivedT *>(this)->getNumTombstones();
@@ -551,65 +625,52 @@ class DenseMapBase : public DebugEpochBase {
     static_cast<DerivedT *>(this)->setNumTombstones(Num);
   }
 
-  void incrementNumTombstones() {
-    setNumTombstones(getNumTombstones() + 1);
-  }
+  void incrementNumTombstones() { setNumTombstones(getNumTombstones() + 1); }
 
-  void decrementNumTombstones() {
-    setNumTombstones(getNumTombstones() - 1);
-  }
+  void decrementNumTombstones() { setNumTombstones(getNumTombstones() - 1); }
 
   const BucketT *getBuckets() const {
     return static_cast<const DerivedT *>(this)->getBuckets();
   }
 
-  BucketT *getBuckets() {
-    return static_cast<DerivedT *>(this)->getBuckets();
-  }
+  BucketT *getBuckets() { return static_cast<DerivedT *>(this)->getBuckets(); }
 
   unsigned getNumBuckets() const {
     return static_cast<const DerivedT *>(this)->getNumBuckets();
   }
 
-  BucketT *getBucketsEnd() {
-    return getBuckets() + getNumBuckets();
-  }
+  BucketT *getBucketsEnd() { return getBuckets() + getNumBuckets(); }
 
   const BucketT *getBucketsEnd() const {
     return getBuckets() + getNumBuckets();
   }
 
-  void grow(unsigned AtLeast) {
-    static_cast<DerivedT *>(this)->grow(AtLeast);
-  }
+  void grow(unsigned AtLeast) { static_cast<DerivedT *>(this)->grow(AtLeast); }
 
-  void shrink_and_clear() {
-    static_cast<DerivedT *>(this)->shrink_and_clear();
-  }
+  void shrink_and_clear() { static_cast<DerivedT *>(this)->shrink_and_clear(); }
 
   template <typename KeyArg, typename... ValueArgs>
   BucketT *InsertIntoBucket(BucketT *TheBucket, KeyArg &&Key,
-                            ValueArgs &&... Values) {
-    TheBucket = InsertIntoBucketImpl(Key, Key, TheBucket);
-
+                            ValueArgs &&...Values) {
+    TheBucket = InsertIntoBucketImpl(Key, TheBucket);
     TheBucket->getFirst() = std::forward<KeyArg>(Key);
     ::new (&TheBucket->getSecond()) ValueT(std::forward<ValueArgs>(Values)...);
+    TheBucket->setValueConstructed(true);
     return TheBucket;
   }
 
   template <typename LookupKeyT>
   BucketT *InsertIntoBucketWithLookup(BucketT *TheBucket, KeyT &&Key,
                                       ValueT &&Value, LookupKeyT &Lookup) {
-    TheBucket = InsertIntoBucketImpl(Key, Lookup, TheBucket);
-
+    TheBucket = InsertIntoBucketImpl(Lookup, TheBucket);
     TheBucket->getFirst() = std::move(Key);
     ::new (&TheBucket->getSecond()) ValueT(std::move(Value));
+    TheBucket->setValueConstructed(true);
     return TheBucket;
   }
 
   template <typename LookupKeyT>
-  BucketT *InsertIntoBucketImpl(const KeyT &Key, const LookupKeyT &Lookup,
-                                BucketT *TheBucket) {
+  BucketT *InsertIntoBucketImpl(const LookupKeyT &Lookup, BucketT *TheBucket) {
     incrementEpoch();
 
     // If the load of the hash table is more than 3/4, or if fewer than 1/8 of
@@ -627,8 +688,9 @@ class DenseMapBase : public DebugEpochBase {
       this->grow(NumBuckets * 2);
       LookupBucketFor(Lookup, TheBucket);
       NumBuckets = getNumBuckets();
-    } else if (LLVM_UNLIKELY(NumBuckets-(NewNumEntries+getNumTombstones()) <=
-                             NumBuckets/8)) {
+    } else if (LLVM_UNLIKELY(NumBuckets -
+                                 (NewNumEntries + getNumTombstones()) <=
+                             NumBuckets / 8)) {
       this->grow(NumBuckets);
       LookupBucketFor(Lookup, TheBucket);
     }
@@ -696,7 +758,7 @@ class DenseMapBase : public DebugEpochBase {
            !KeyInfoT::isEqual(Val, TombstoneKey) &&
            "Empty/Tombstone value shouldn't be inserted into map!");
 
-    unsigned BucketNo = getHashValue(Val) & (NumBuckets-1);
+    unsigned BucketNo = getHashValue(Val) & (NumBuckets - 1);
     unsigned ProbeAmt = 1;
     while (true) {
       BucketT *ThisBucket = BucketsPtr + BucketNo;
@@ -719,23 +781,63 @@ class DenseMapBase : public DebugEpochBase {
       // prefer to return it than something that would require more probing.
       if (KeyInfoT::isEqual(ThisBucket->getFirst(), TombstoneKey) &&
           !FoundTombstone)
-        FoundTombstone = ThisBucket;  // Remember the first tombstone found.
+        FoundTombstone = ThisBucket; // Remember the first tombstone found.
 
       // Otherwise, it's a hash collision or a tombstone, continue quadratic
       // probing.
       BucketNo += ProbeAmt++;
-      BucketNo &= (NumBuckets-1);
+      BucketNo &= (NumBuckets - 1);
     }
   }
 
+protected:
+  // helper class to guarantee deallocation of buffer
+  class ReleaseOldBuffer {
+    BucketT *m_buckets;
+    unsigned m_numBuckets;
+
+  public:
+    ReleaseOldBuffer(BucketT *buckets, unsigned numBuckets)
+        : m_buckets(buckets), m_numBuckets(numBuckets) {}
+    ~ReleaseOldBuffer() {
+#ifndef NDEBUG
+      memset((void *)m_buckets, 0x5a, sizeof(BucketT) * m_numBuckets);
+#endif
+      // Free the old table.
+      size_t const alignment{alignof(BucketT)};
+      deallocate_buffer(static_cast<void *>(m_buckets),
+                        sizeof(BucketT) * m_numBuckets, alignment);
+    }
+  };
+
+  // helper class to guarantee destruction of bucket content
+  class ReleaseOldBuckets {
+    BucketT *m_buckets;
+    unsigned m_numBuckets;
+
+  public:
+    ReleaseOldBuckets(BucketT *buckets, unsigned numBuckets)
+        : m_buckets(buckets), m_numBuckets(numBuckets) {}
+    ~ReleaseOldBuckets() {
+      for (BucketT *B = m_buckets, *E = m_buckets + m_numBuckets; B != E; ++B) {
+        if (B->isValueConstructed()) {
+          B->getSecond().~ValueT();
+          B->setValueConstructed(false);
+        }
+        if (B->isKeyConstructed()) {
+          B->getFirst().~KeyT();
+          B->setKeyConstructed(false);
+        }
+      }
+    }
+  };
+
 public:
   /// Return the approximate size (in bytes) of the actual map.
   /// This is just the raw memory used by DenseMap.
   /// If entries are pointers to objects, the size of the referenced objects
   /// are not included.
-  size_t getMemorySize() const {
-    return getNumBuckets() * sizeof(BucketT);
-  }
+  size...
[truncated]
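
The `ReleaseOldBuffer` and `ReleaseOldBuckets` helpers in the diff above appear to follow the scope-guard (RAII) pattern: cleanup is tied to a destructor, so it still runs if rehashing exits early. Below is a minimal standalone sketch of that idea, not the patch's code. `BufferGuard`, its `Released` test hook, and `oldBufferReleasedOnFailure` are hypothetical names; `std::malloc`/`std::free` stand in for LLVM's buffer allocation helpers, and a thrown `std::bad_alloc` stands in for whatever failure the real code has to survive.

```cpp
#include <cassert>
#include <cstdlib>
#include <new>

// Hypothetical stand-in for the patch's ReleaseOldBuffer helper: frees a raw
// buffer when the guard goes out of scope, including during stack unwinding.
class BufferGuard {
  void *Buffer;
  bool *Released; // Test hook (assumption): records that the free happened.

public:
  BufferGuard(void *Buffer, bool *Released)
      : Buffer(Buffer), Released(Released) {
    assert(Released != nullptr && "test hook must be provided");
  }
  BufferGuard(const BufferGuard &) = delete;
  BufferGuard &operator=(const BufferGuard &) = delete;
  ~BufferGuard() {
    std::free(Buffer); // std::free(nullptr) is a no-op, so this is safe.
    *Released = true;
  }
};

// Simulates a grow() step whose rehash fails partway through. Returns true if
// the old buffer was still released after the failure.
inline bool oldBufferReleasedOnFailure() {
  bool Released = false;
  try {
    void *OldBuffer = std::malloc(64);
    BufferGuard Guard(OldBuffer, &Released);
    throw std::bad_alloc(); // Stand-in for an allocation failure mid-rehash.
  } catch (const std::bad_alloc &) {
    // The guard's destructor has already run by the time control gets here.
  }
  return Released;
}
```

Calling `oldBufferReleasedOnFailure()` exercises the failure path; the guard releases the buffer without any explicit cleanup code in the `catch` block.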

@kazutakahirata
Contributor

Is there any way you could split your patch into several smaller ones?

If you are going to work on pretty much all of DenseMap/DenseSet, you might want to post a PR just to clang-format the whole thing. This should separate the interesting changes from the rest.

After that, I would get small obvious changes out of the way. For example, your `if constexpr (std::is_trivially_destructible_v<ValueT>)` change will probably receive a quick LGTM.

Thanks!
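
The `if constexpr` change mentioned above is not visible in the truncated diff, so the following is only a sketch of the general technique under stated assumptions. `destroyRange` and `destroyTwoStrings` are hypothetical names that do not come from the patch; the point is that the destructor loop is discarded at compile time for trivially destructible value types.

```cpp
#include <cstddef>
#include <memory>
#include <new>
#include <string>
#include <type_traits>

// Hypothetical helper (not from the patch): destroys Count objects in place
// and returns how many destructor calls were actually issued.
template <typename ValueT>
std::size_t destroyRange(ValueT *First, std::size_t Count) {
  if constexpr (std::is_trivially_destructible_v<ValueT>) {
    // Nothing to run; this branch replaces the loop at compile time.
    (void)First;
    (void)Count;
    return 0;
  } else {
    for (std::size_t I = 0; I != Count; ++I)
      First[I].~ValueT();
    return Count;
  }
}

// Usage sketch: std::string is not trivially destructible, so both
// destructors are called before the raw storage is handed back.
inline std::size_t destroyTwoStrings() {
  std::allocator<std::string> Alloc;
  std::string *S = Alloc.allocate(2);
  ::new (static_cast<void *>(S)) std::string("a");
  ::new (static_cast<void *>(S + 1)) std::string("b");
  std::size_t Calls = destroyRange(S, 2);
  Alloc.deallocate(S, 2);
  return Calls;
}
```

For an `int` array the same call compiles down to the first branch and issues no destructor calls.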

@marcauberer
Member Author

Yes, agreed.
For the clang-format preparation change, see here: #108162

@nikic
Contributor

nikic commented Sep 11, 2024

This direction looks quite concerning to me. LLVM is explicitly not designed to be exception safe, and this change is not at all free. It introduces significant implementation complexity, compile-time impact and memory usage impact.

The minimum requirement here is going to be that this does not have any compile-time/memory usage impact on people who do not need this functionality -- but I'm not convinced that we should accept this downstream at all, as a matter of policy.

Major compile-time regressions: https://llvm-compile-time-tracker.com/compare.php?from=f4dd1bc8fc625d3938f95b9d06aaaeebd2e90dca&to=1cb62cb5fb8eeca9641992f98f1b24b328410034&stat=instructions:u

Major memory usage regressions: https://llvm-compile-time-tracker.com/compare.php?from=f4dd1bc8fc625d3938f95b9d06aaaeebd2e90dca&to=1cb62cb5fb8eeca9641992f98f1b24b328410034&stat=max-rss
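
One plausible source of the memory impact is the per-bucket lifetime tracking (`isKeyConstructed` / `isValueConstructed`). The excerpt does not show how `BucketBaseT` stores those flags, so the layout below is an assumption made purely for illustration; `PlainBucket`, `FlaggedBucket`, and `ExtraBytesPerBucket` are hypothetical names.

```cpp
#include <cstddef>

// Plain bucket, roughly what a pointer-to-pointer map stores per slot.
struct PlainBucket {
  void *Key;
  void *Value;
};

// Assumed layout (not taken from the patch): the same bucket plus two
// per-bucket lifetime flags.
struct FlaggedBucket {
  void *Key;
  void *Value;
  bool KeyConstructed;
  bool ValueConstructed;
};

// Extra bytes per bucket, including any alignment padding the flags cause.
constexpr std::size_t ExtraBytesPerBucket =
    sizeof(FlaggedBucket) - sizeof(PlainBucket);

static_assert(ExtraBytesPerBucket >= 2, "two flags need at least two bytes");
```

Because the struct keeps pointer alignment, the two flag bytes are usually rounded up to a full pointer-sized slot, so the per-bucket overhead under this assumed layout is larger than the flags themselves.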

@kuhar
Member

kuhar commented Sep 11, 2024

+1 to what @nikic said. As an ADT contributor, the thing I really value about our implementation is that it's much simpler than alternative hash maps that have to worry about exceptions. I think it's safe to say that LLVM developers at large are not used to / trained to think about exception safety, and this puts a burden on subsequent contributors who may not care about handling these specific OOM conditions.

@marcauberer marcauberer force-pushed the llvm/adt/malfunction-safe-densemap branch from bb74000 to 6c9cc94 Compare September 11, 2024 19:17
@marcauberer marcauberer force-pushed the llvm/adt/malfunction-safe-densemap branch from 6c9cc94 to 7f437b1 Compare September 13, 2024 09:13
@marcauberer marcauberer changed the title [ADT] Make DenseMap/DenseSet more resilient agains OOM situations [ADT] Make DenseMap/DenseSet more resilient against OOM situations Sep 19, 2024