Skip to content

Commit 6ccd59e

Browse files
kazutakahiratayuxuanchen1997
authored andcommitted
[ADT] Make set_subtract more efficient when subtrahend is larger (NFC) (#99401)
Summary: This patch is based on: commit fffe272 Author: Teresa Johnson <[email protected]> Date: Wed Jul 17 13:53:10 2024 -0700 This iteration comes with a couple of improvements: - We now accommodate S2Ty being SmallPtrSet, which has remove_if(pred) but not erase(iterator). (Lack of this code path broke the mlir build.) - The code path for erase(iterator) now pre-increments the iterator to avoid problems with iterator invalidation. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251011
1 parent 9ed22ac commit 6ccd59e

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

llvm/include/llvm/ADT/SetOperations.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ using check_has_member_remove_if_t =
2727
template <typename Set, typename Fn>
2828
static constexpr bool HasMemberRemoveIf =
2929
is_detected<check_has_member_remove_if_t, Set, Fn>::value;
30+
31+
template <typename Set>
32+
using check_has_member_erase_iter_t =
33+
decltype(std::declval<Set>().erase(std::declval<Set>().begin()));
34+
35+
template <typename Set>
36+
static constexpr bool HasMemberEraseIter =
37+
is_detected<check_has_member_erase_iter_t, Set>::value;
38+
3039
} // namespace detail
3140

3241
/// set_union(A, B) - Compute A := A u B, return whether A changed.
@@ -94,7 +103,35 @@ S1Ty set_difference(const S1Ty &S1, const S2Ty &S2) {
94103

95104
/// set_subtract(A, B) - Compute A := A - B
96105
///
106+
/// Selects the set to iterate based on the relative sizes of A and B for better
107+
/// efficiency.
108+
///
97109
template <class S1Ty, class S2Ty> void set_subtract(S1Ty &S1, const S2Ty &S2) {
110+
// If S1 is smaller than S2, iterate on S1 provided that S2 supports efficient
111+
// lookups via contains(). Note that a couple callers pass a vector for S2,
112+
// which doesn't support contains(), and wouldn't be efficient if it did.
113+
using ElemTy = decltype(*S1.begin());
114+
if constexpr (detail::HasMemberContains<S2Ty, ElemTy>) {
115+
auto Pred = [&S2](const auto &E) { return S2.contains(E); };
116+
if constexpr (detail::HasMemberRemoveIf<S1Ty, decltype(Pred)>) {
117+
if (S1.size() < S2.size()) {
118+
S1.remove_if(Pred);
119+
return;
120+
}
121+
} else if constexpr (detail::HasMemberEraseIter<S1Ty>) {
122+
if (S1.size() < S2.size()) {
123+
typename S1Ty::iterator Next;
124+
for (typename S1Ty::iterator SI = S1.begin(), SE = S1.end(); SI != SE;
125+
SI = Next) {
126+
Next = std::next(SI);
127+
if (S2.contains(*SI))
128+
S1.erase(SI);
129+
}
130+
return;
131+
}
132+
}
133+
}
134+
98135
for (typename S2Ty::const_iterator SI = S2.begin(), SE = S2.end(); SI != SE;
99136
++SI)
100137
S1.erase(*SI);

0 commit comments

Comments
 (0)