Skip to content

[SetOperations] Support set containers with remove_if #96613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 26, 2024

Conversation

nikic
Copy link
Contributor

@nikic nikic commented Jun 25, 2024

The current set_intersect implementation only works for std::set style sets that have a value-erase method that does not invalidate iterators. As such, it cannot be used for set containers like SetVector, which only has iterator-invalidating erase.

Support such set containers by calling the remove_if method instead, if it exists. The detection code is adopted from how contains() is detected inside llvm::is_contained().

The current set_intersect implementation only works for std::set
style sets that have a value-erase method that does not invalidate
iterators. As such, it cannot be used for set containers like
SetVector, which only has iterator-invalidating erase.

Support such set containers by calling the remove_if method instead,
if it exists. The detection code is adopted from how contains()
is detected inside llvm::is_contained().
@llvmbot
Copy link
Member

llvmbot commented Jun 25, 2024

@llvm/pr-subscribers-llvm-adt

Author: Nikita Popov (nikic)

Changes

The current set_intersect implementation only works for std::set style sets that have a value-erase method that does not invalidate iterators. As such, it cannot be used for set containers like SetVector, which only has iterator-invalidating erase.

Support such set containers by calling the remove_if method instead, if it exists. The detection code is adopted from how contains() is detected inside llvm::is_contained().


Full diff: https://github.com/llvm/llvm-project/pull/96613.diff

2 Files Affected:

  • (modified) llvm/include/llvm/ADT/SetOperations.h (+22-5)
  • (modified) llvm/unittests/ADT/SetOperationsTest.cpp (+11)
diff --git a/llvm/include/llvm/ADT/SetOperations.h b/llvm/include/llvm/ADT/SetOperations.h
index 6c04c764e52076..9e125388509bde 100644
--- a/llvm/include/llvm/ADT/SetOperations.h
+++ b/llvm/include/llvm/ADT/SetOperations.h
@@ -15,8 +15,20 @@
 #ifndef LLVM_ADT_SETOPERATIONS_H
 #define LLVM_ADT_SETOPERATIONS_H
 
+#include "llvm/ADT/STLExtras.h"
+
 namespace llvm {
 
+namespace detail {
+template <typename Set, typename Fn>
+using check_has_member_remove_if_t =
+    decltype(std::declval<Set>().remove_if(std::declval<Fn>()));
+
+template <typename Set, typename Fn>
+static constexpr bool HasMemberRemoveIf =
+    is_detected<check_has_member_remove_if_t, Set, Fn>::value;
+} // namespace detail
+
 /// set_union(A, B) - Compute A := A u B, return whether A changed.
 ///
 template <class S1Ty, class S2Ty> bool set_union(S1Ty &S1, const S2Ty &S2) {
@@ -36,11 +48,16 @@ template <class S1Ty, class S2Ty> bool set_union(S1Ty &S1, const S2Ty &S2) {
 /// elements that are not contained in S2.
 ///
 template <class S1Ty, class S2Ty> void set_intersect(S1Ty &S1, const S2Ty &S2) {
-  for (typename S1Ty::iterator I = S1.begin(); I != S1.end();) {
-    const auto &E = *I;
-    ++I;
-    if (!S2.count(E))
-      S1.erase(E); // Erase element if not in S2
+  if constexpr (detail::HasMemberRemoveIf<S1Ty,
+                                          bool (*)(decltype(*S2.begin()))>) {
+    S1.remove_if([S2](const auto &E) { return !S2.count(E); });
+  } else {
+    for (typename S1Ty::iterator I = S1.begin(); I != S1.end();) {
+      const auto &E = *I;
+      ++I;
+      if (!S2.count(E))
+        S1.erase(E); // Erase element if not in S2
+    }
   }
 }
 
diff --git a/llvm/unittests/ADT/SetOperationsTest.cpp b/llvm/unittests/ADT/SetOperationsTest.cpp
index 982ea819fd809e..7bd5189e488216 100644
--- a/llvm/unittests/ADT/SetOperationsTest.cpp
+++ b/llvm/unittests/ADT/SetOperationsTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
@@ -65,6 +66,16 @@ TEST(SetOperationsTest, SetIntersect) {
   // is empty as they are non-overlapping.
   EXPECT_THAT(Set1, IsEmpty());
   EXPECT_EQ(ExpectedSet2, Set2);
+
+  // Check that set_intersect works on SetVector via remove_if.
+  SmallSetVector<int, 4> SV;
+  SV.insert(3);
+  SV.insert(6);
+  SV.insert(4);
+  SV.insert(5);
+  set_intersect(SV, Set2);
+  // SV should contain only 6 and 5 now.
+  EXPECT_EQ(SV.getArrayRef(), ArrayRef({6, 5}));
 }
 
 TEST(SetOperationsTest, SetIntersection) {

@nikic nikic merged commit 6c4c44b into llvm:main Jun 26, 2024
7 checks passed
@nikic nikic deleted the set-intersect-remove-if branch June 26, 2024 07:20
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
The current set_intersect implementation only works for std::set style
sets that have a value-erase method that does not invalidate iterators.
As such, it cannot be used for set containers like SetVector, which only
has iterator-invalidating erase.

Support such set containers by calling the remove_if method instead, if
it exists. The detection code is adopted from how contains() is detected
inside llvm::is_contained().
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants