Skip to content

Commit fd6260f

Browse files
authored
[EquivClasses] Shorten members_{begin,end} idiom (#134373)
Introduce members() iterator-helper to shorten the members_{begin,end} idiom. A previous attempt of this patch was #130319, which had to be reverted due to unit-test failures when attempting to call members() on the end iterator. In this patch, members() accepts either an ECValue or an ElemTy, which is more intuitive and doesn't suffer from the same issue.
1 parent fb9deab commit fd6260f

File tree

7 files changed

+39
-22
lines changed

7 files changed

+39
-22
lines changed

llvm/include/llvm/ADT/EquivalenceClasses.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define LLVM_ADT_EQUIVALENCECLASSES_H
1717

1818
#include "llvm/ADT/SmallVector.h"
19+
#include "llvm/ADT/iterator_range.h"
1920
#include <cassert>
2021
#include <cstddef>
2122
#include <cstdint>
@@ -184,6 +185,14 @@ class EquivalenceClasses {
184185
return member_iterator(nullptr);
185186
}
186187

188+
iterator_range<member_iterator> members(const ECValue &ECV) const {
189+
return make_range(member_begin(ECV), member_end());
190+
}
191+
192+
iterator_range<member_iterator> members(const ElemTy &V) const {
193+
return make_range(findLeader(V), member_end());
194+
}
195+
187196
/// Returns true if \p V is contained an equivalence class.
188197
bool contains(const ElemTy &V) const {
189198
return TheMapping.find(V) != TheMapping.end();

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,8 @@ void RuntimePointerChecking::groupChecks(
526526
// iteration order within an equivalence class member is only dependent on
527527
// the order in which unions and insertions are performed on the
528528
// equivalence class, the iteration order is deterministic.
529-
for (auto MI = DepCands.findLeader(Access), ME = DepCands.member_end();
530-
MI != ME; ++MI) {
531-
auto PointerI = PositionMap.find(MI->getPointer());
529+
for (auto M : DepCands.members(Access)) {
530+
auto PointerI = PositionMap.find(M.getPointer());
532531
assert(PointerI != PositionMap.end() &&
533532
"pointer in equivalence class not found in PositionMap");
534533
for (unsigned Pointer : PointerI->second) {

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
847847
if (!E->isLeader())
848848
continue;
849849
uint64_t LeaderDemandedBits = 0;
850-
for (Value *M : make_range(ECs.member_begin(*E), ECs.member_end()))
850+
for (Value *M : ECs.members(*E))
851851
LeaderDemandedBits |= DBits[M];
852852

853853
uint64_t MinBW = llvm::bit_width(LeaderDemandedBits);
@@ -859,15 +859,15 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
859859
// indvars.
860860
// If we are required to shrink a PHI, abandon this entire equivalence class.
861861
bool Abort = false;
862-
for (Value *M : make_range(ECs.member_begin(*E), ECs.member_end()))
862+
for (Value *M : ECs.members(*E))
863863
if (isa<PHINode>(M) && MinBW < M->getType()->getScalarSizeInBits()) {
864864
Abort = true;
865865
break;
866866
}
867867
if (Abort)
868868
continue;
869869

870-
for (Value *M : make_range(ECs.member_begin(*E), ECs.member_end())) {
870+
for (Value *M : ECs.members(*E)) {
871871
auto *MI = dyn_cast<Instruction>(M);
872872
if (!MI)
873873
continue;

llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,9 +1021,8 @@ void RecursiveSearchSplitting::setupWorkList() {
10211021
continue;
10221022

10231023
BitVector Cluster = SG.createNodesBitVector();
1024-
for (auto MI = NodeEC.member_begin(*Node); MI != NodeEC.member_end();
1025-
++MI) {
1026-
const SplitGraph::Node &N = SG.getNode(*MI);
1024+
for (unsigned M : NodeEC.members(*Node)) {
1025+
const SplitGraph::Node &N = SG.getNode(M);
10271026
if (N.isGraphEntryPoint())
10281027
N.getDependencies(Cluster);
10291028
}

llvm/lib/Transforms/IPO/LowerTypeTests.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2349,14 +2349,13 @@ bool LowerTypeTestsModule::lower() {
23492349
std::vector<Metadata *> TypeIds;
23502350
std::vector<GlobalTypeMember *> Globals;
23512351
std::vector<ICallBranchFunnel *> ICallBranchFunnels;
2352-
for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(*C);
2353-
MI != GlobalClasses.member_end(); ++MI) {
2354-
if (isa<Metadata *>(*MI))
2355-
TypeIds.push_back(cast<Metadata *>(*MI));
2356-
else if (isa<GlobalTypeMember *>(*MI))
2357-
Globals.push_back(cast<GlobalTypeMember *>(*MI));
2352+
for (auto M : GlobalClasses.members(*C)) {
2353+
if (isa<Metadata *>(M))
2354+
TypeIds.push_back(cast<Metadata *>(M));
2355+
else if (isa<GlobalTypeMember *>(M))
2356+
Globals.push_back(cast<GlobalTypeMember *>(M));
23582357
else
2359-
ICallBranchFunnels.push_back(cast<ICallBranchFunnel *>(*MI));
2358+
ICallBranchFunnels.push_back(cast<ICallBranchFunnel *>(M));
23602359
}
23612360

23622361
// Order type identifiers by unique ID for determinism. This ordering is

llvm/lib/Transforms/Scalar/Float2Int.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,10 +320,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
320320
Type *ConvertedToTy = nullptr;
321321

322322
// For every member of the partition, union all the ranges together.
323-
for (auto MI = ECs.member_begin(*E), ME = ECs.member_end(); MI != ME;
324-
++MI) {
325-
Instruction *I = *MI;
326-
auto SeenI = SeenInsts.find(I);
323+
for (Instruction *I : ECs.members(*E)) {
324+
auto *SeenI = SeenInsts.find(I);
327325
if (SeenI == SeenInsts.end())
328326
continue;
329327

@@ -391,8 +389,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
391389
}
392390
}
393391

394-
for (auto MI = ECs.member_begin(*E), ME = ECs.member_end(); MI != ME; ++MI)
395-
convert(*MI, Ty);
392+
for (Instruction *I : ECs.members(*E))
393+
convert(I, Ty);
396394
MadeChange = true;
397395
}
398396

llvm/unittests/ADT/EquivalenceClassesTest.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "llvm/ADT/EquivalenceClasses.h"
10+
#include "gmock/gmock.h"
1011
#include "gtest/gtest.h"
1112

1213
using namespace llvm;
@@ -75,6 +76,18 @@ TEST(EquivalenceClassesTest, TwoSets) {
7576
EXPECT_FALSE(EqClasses.isEquivalent(i, j));
7677
}
7778

79+
TEST(EquivalenceClassesTest, MembersIterator) {
80+
EquivalenceClasses<int> EC;
81+
EC.unionSets(1, 2);
82+
EC.insert(4);
83+
EC.insert(5);
84+
EC.unionSets(5, 1);
85+
EXPECT_EQ(EC.getNumClasses(), 2u);
86+
87+
EXPECT_THAT(EC.members(4), testing::ElementsAre(4));
88+
EXPECT_THAT(EC.members(1), testing::ElementsAre(5, 1, 2));
89+
}
90+
7891
// Type-parameterized tests: Run the same test cases with different element
7992
// types.
8093
template <typename T> class ParameterizedTest : public testing::Test {};

0 commit comments

Comments
 (0)