Skip to content

Commit 6c4c44b

Browse files
authored
[SetOperations] Support set containers with remove_if (#96613)
The current set_intersect implementation only works for std::set style sets that have a value-erase method that does not invalidate iterators. As such, it cannot be used for set containers like SetVector, which only has iterator-invalidating erase. Support such set containers by calling the remove_if method instead, if it exists. The detection code is adopted from how contains() is detected inside llvm::is_contained().
1 parent 28a3fbb commit 6c4c44b

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

llvm/include/llvm/ADT/SetOperations.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,20 @@
1515
#ifndef LLVM_ADT_SETOPERATIONS_H
1616
#define LLVM_ADT_SETOPERATIONS_H
1717

18+
#include "llvm/ADT/STLExtras.h"
19+
1820
namespace llvm {
1921

22+
namespace detail {
23+
template <typename Set, typename Fn>
24+
using check_has_member_remove_if_t =
25+
decltype(std::declval<Set>().remove_if(std::declval<Fn>()));
26+
27+
template <typename Set, typename Fn>
28+
static constexpr bool HasMemberRemoveIf =
29+
is_detected<check_has_member_remove_if_t, Set, Fn>::value;
30+
} // namespace detail
31+
2032
/// set_union(A, B) - Compute A := A u B, return whether A changed.
2133
///
2234
template <class S1Ty, class S2Ty> bool set_union(S1Ty &S1, const S2Ty &S2) {
@@ -36,11 +48,16 @@ template <class S1Ty, class S2Ty> bool set_union(S1Ty &S1, const S2Ty &S2) {
3648
/// elements that are not contained in S2.
3749
///
3850
template <class S1Ty, class S2Ty> void set_intersect(S1Ty &S1, const S2Ty &S2) {
39-
for (typename S1Ty::iterator I = S1.begin(); I != S1.end();) {
40-
const auto &E = *I;
41-
++I;
42-
if (!S2.count(E))
43-
S1.erase(E); // Erase element if not in S2
51+
auto Pred = [&S2](const auto &E) { return !S2.count(E); };
52+
if constexpr (detail::HasMemberRemoveIf<S1Ty, decltype(Pred)>) {
53+
S1.remove_if(Pred);
54+
} else {
55+
for (typename S1Ty::iterator I = S1.begin(); I != S1.end();) {
56+
const auto &E = *I;
57+
++I;
58+
if (!S2.count(E))
59+
S1.erase(E); // Erase element if not in S2
60+
}
4461
}
4562
}
4663

llvm/unittests/ADT/SetOperationsTest.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "llvm/ADT/SetOperations.h"
10+
#include "llvm/ADT/SetVector.h"
1011
#include "gmock/gmock.h"
1112
#include "gtest/gtest.h"
1213

@@ -65,6 +66,16 @@ TEST(SetOperationsTest, SetIntersect) {
6566
// is empty as they are non-overlapping.
6667
EXPECT_THAT(Set1, IsEmpty());
6768
EXPECT_EQ(ExpectedSet2, Set2);
69+
70+
// Check that set_intersect works on SetVector via remove_if.
71+
SmallSetVector<int, 4> SV;
72+
SV.insert(3);
73+
SV.insert(6);
74+
SV.insert(4);
75+
SV.insert(5);
76+
set_intersect(SV, Set2);
77+
// SV should contain only 6 and 5 now.
78+
EXPECT_THAT(SV, testing::ElementsAre(6, 5));
6879
}
6980

7081
TEST(SetOperationsTest, SetIntersection) {

0 commit comments

Comments
 (0)