Skip to content

Commit a4916d2

Browse files
authored
[SandboxVec][DAG] Refactoring: Move MemPreds from DGNode to MemDGNode (#111897)
1 parent 07892aa commit a4916d2

File tree

3 files changed

+72
-57
lines changed

3 files changed

+72
-57
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- DependencyGraph.h ----------------------------------*- C++ -*-===//
1+
//===- DependencyGraph.h ----------------------------------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -96,9 +96,6 @@ class DGNode {
9696
// TODO: Use a PointerIntPair for SubclassID and I.
9797
/// For isa/dyn_cast etc.
9898
DGNodeID SubclassID;
99-
// TODO: Move MemPreds to MemDGNode.
100-
/// Memory predecessors.
101-
DenseSet<MemDGNode *> MemPreds;
10299

103100
DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
104101
friend class MemDGNode; // For constructor.
@@ -170,17 +167,6 @@ class DGNode {
170167
}
171168

172169
Instruction *getInstruction() const { return I; }
173-
void addMemPred(MemDGNode *PredN) { MemPreds.insert(PredN); }
174-
/// \Returns all memory dependency predecessors.
175-
iterator_range<DenseSet<MemDGNode *>::const_iterator> memPreds() const {
176-
return make_range(MemPreds.begin(), MemPreds.end());
177-
}
178-
/// \Returns true if there is a memory dependency N->this.
179-
bool hasMemPred(DGNode *N) const {
180-
if (auto *MN = dyn_cast<MemDGNode>(N))
181-
return MemPreds.count(MN);
182-
return false;
183-
}
184170

185171
#ifndef NDEBUG
186172
virtual void print(raw_ostream &OS, bool PrintDeps = true) const;
@@ -198,6 +184,9 @@ class DGNode {
198184
class MemDGNode final : public DGNode {
199185
MemDGNode *PrevMemN = nullptr;
200186
MemDGNode *NextMemN = nullptr;
187+
/// Memory predecessors.
188+
DenseSet<MemDGNode *> MemPreds;
189+
friend class PredIterator; // For MemPreds.
201190

202191
void setNextNode(MemDGNode *N) { NextMemN = N; }
203192
void setPrevNode(MemDGNode *N) { PrevMemN = N; }
@@ -222,6 +211,21 @@ class MemDGNode final : public DGNode {
222211
MemDGNode *getPrevNode() const { return PrevMemN; }
223212
/// \Returns the next Mem DGNode in instruction order.
224213
MemDGNode *getNextNode() const { return NextMemN; }
214+
/// Adds the mem dependency edge PredN->this.
215+
void addMemPred(MemDGNode *PredN) { MemPreds.insert(PredN); }
216+
/// \Returns true if there is a memory dependency N->this.
217+
bool hasMemPred(DGNode *N) const {
218+
if (auto *MN = dyn_cast<MemDGNode>(N))
219+
return MemPreds.count(MN);
220+
return false;
221+
}
222+
/// \Returns all memory dependency predecessors. Used by tests.
223+
iterator_range<DenseSet<MemDGNode *>::const_iterator> memPreds() const {
224+
return make_range(MemPreds.begin(), MemPreds.end());
225+
}
226+
#ifndef NDEBUG
227+
virtual void print(raw_ostream &OS, bool PrintDeps = true) const override;
228+
#endif // NDEBUG
225229
};
226230

227231
/// Convenience builders for a MemDGNode interval.
@@ -266,7 +270,7 @@ class DependencyGraph {
266270

267271
/// Go through all mem nodes in \p SrcScanRange and try to add dependencies to
268272
/// \p DstN.
269-
void scanAndAddDeps(DGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
273+
void scanAndAddDeps(MemDGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
270274

271275
public:
272276
DependencyGraph(AAResults &AA)

llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ PredIterator::value_type PredIterator::operator*() {
2323
// or a mem predecessor.
2424
if (OpIt != OpItE)
2525
return DAG->getNode(cast<Instruction>((Value *)*OpIt));
26-
assert(MemIt != cast<MemDGNode>(N)->memPreds().end() &&
26+
// It's a MemDGNode with OpIt == end, so we need to use MemIt.
27+
assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
2728
"Cant' dereference end iterator!");
2829
return *MemIt;
2930
}
@@ -45,7 +46,8 @@ PredIterator &PredIterator::operator++() {
4546
OpIt = skipNonInstr(OpIt, OpItE);
4647
return *this;
4748
}
48-
assert(MemIt != cast<MemDGNode>(N)->memPreds().end() && "Already at end!");
49+
// It's a MemDGNode with OpIt == end, so we need to increment MemIt.
50+
assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
4951
++MemIt;
5052
return *this;
5153
}
@@ -57,10 +59,14 @@ bool PredIterator::operator==(const PredIterator &Other) const {
5759
}
5860

5961
#ifndef NDEBUG
60-
void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
62+
void DGNode::print(raw_ostream &OS, bool PrintDeps) const { I->dumpOS(OS); }
63+
void DGNode::dump() const {
64+
print(dbgs());
65+
dbgs() << "\n";
66+
}
67+
void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
6168
I->dumpOS(OS);
6269
if (PrintDeps) {
63-
OS << "\n";
6470
// Print memory preds.
6571
static constexpr const unsigned Indent = 4;
6672
for (auto *Pred : MemPreds) {
@@ -70,10 +76,6 @@ void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
7076
}
7177
}
7278
}
73-
void DGNode::dump() const {
74-
print(dbgs());
75-
dbgs() << "\n";
76-
}
7779
#endif // NDEBUG
7880

7981
Interval<MemDGNode>
@@ -179,7 +181,7 @@ bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
179181
llvm_unreachable("Unknown DependencyType enum");
180182
}
181183

182-
void DependencyGraph::scanAndAddDeps(DGNode &DstN,
184+
void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
183185
const Interval<MemDGNode> &SrcScanRange) {
184186
assert(isa<MemDGNode>(DstN) &&
185187
"DstN is the mem dep destination, so it must be mem");

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ struct DependencyGraphTest : public testing::Test {
5050
return *AA;
5151
}
5252
/// \Returns true if there is a dependency: SrcN->DstN.
53-
bool dependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) {
54-
const auto &Preds = DstN->memPreds();
55-
auto It = find(Preds, SrcN);
56-
return It != Preds.end();
53+
bool memDependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) {
54+
if (auto *MemDstN = dyn_cast<sandboxir::MemDGNode>(DstN))
55+
return MemDstN->hasMemPred(SrcN);
56+
return false;
5757
}
5858
};
5959

@@ -230,9 +230,10 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
230230
EXPECT_EQ(Span.top(), &*BB->begin());
231231
EXPECT_EQ(Span.bottom(), BB->getTerminator());
232232

233-
sandboxir::DGNode *N0 = DAG.getNode(S0);
234-
sandboxir::DGNode *N1 = DAG.getNode(S1);
235-
sandboxir::DGNode *N2 = DAG.getNode(Ret);
233+
auto *N0 = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
234+
auto *N1 = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
235+
auto *N2 = DAG.getNode(Ret);
236+
236237
// Check getInstruction().
237238
EXPECT_EQ(N0->getInstruction(), S0);
238239
EXPECT_EQ(N1->getInstruction(), S1);
@@ -247,7 +248,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
247248
// Check memPreds().
248249
EXPECT_TRUE(N0->memPreds().empty());
249250
EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
250-
EXPECT_TRUE(N2->memPreds().empty());
251+
EXPECT_TRUE(N2->preds(DAG).empty());
251252
}
252253

253254
TEST_F(DependencyGraphTest, Preds) {
@@ -399,12 +400,14 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
399400
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
400401
DAG.extend({&*BB->begin(), BB->getTerminator()});
401402
auto It = BB->begin();
402-
auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
403-
auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
403+
auto *Store0N = cast<sandboxir::MemDGNode>(
404+
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
405+
auto *Store1N = cast<sandboxir::MemDGNode>(
406+
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
404407
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
405408
EXPECT_TRUE(Store0N->memPreds().empty());
406409
EXPECT_THAT(Store1N->memPreds(), testing::ElementsAre(Store0N));
407-
EXPECT_TRUE(RetN->memPreds().empty());
410+
EXPECT_TRUE(RetN->preds(DAG).empty());
408411
}
409412

410413
TEST_F(DependencyGraphTest, NonAliasingStores) {
@@ -422,13 +425,15 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) {
422425
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
423426
DAG.extend({&*BB->begin(), BB->getTerminator()});
424427
auto It = BB->begin();
425-
auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
426-
auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
428+
auto *Store0N = cast<sandboxir::MemDGNode>(
429+
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
430+
auto *Store1N = cast<sandboxir::MemDGNode>(
431+
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
427432
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
428433
// We expect no dependencies because the stores don't alias.
429434
EXPECT_TRUE(Store0N->memPreds().empty());
430435
EXPECT_TRUE(Store1N->memPreds().empty());
431-
EXPECT_TRUE(RetN->memPreds().empty());
436+
EXPECT_TRUE(RetN->preds(DAG).empty());
432437
}
433438

434439
TEST_F(DependencyGraphTest, VolatileLoads) {
@@ -446,12 +451,14 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
446451
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
447452
DAG.extend({&*BB->begin(), BB->getTerminator()});
448453
auto It = BB->begin();
449-
auto *Ld0N = DAG.getNode(cast<sandboxir::LoadInst>(&*It++));
450-
auto *Ld1N = DAG.getNode(cast<sandboxir::LoadInst>(&*It++));
454+
auto *Ld0N = cast<sandboxir::MemDGNode>(
455+
DAG.getNode(cast<sandboxir::LoadInst>(&*It++)));
456+
auto *Ld1N = cast<sandboxir::MemDGNode>(
457+
DAG.getNode(cast<sandboxir::LoadInst>(&*It++)));
451458
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
452459
EXPECT_TRUE(Ld0N->memPreds().empty());
453460
EXPECT_THAT(Ld1N->memPreds(), testing::ElementsAre(Ld0N));
454-
EXPECT_TRUE(RetN->memPreds().empty());
461+
EXPECT_TRUE(RetN->preds(DAG).empty());
455462
}
456463

457464
TEST_F(DependencyGraphTest, VolatileSotres) {
@@ -469,12 +476,14 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v) {
469476
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
470477
DAG.extend({&*BB->begin(), BB->getTerminator()});
471478
auto It = BB->begin();
472-
auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
473-
auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
479+
auto *Store0N = cast<sandboxir::MemDGNode>(
480+
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
481+
auto *Store1N = cast<sandboxir::MemDGNode>(
482+
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
474483
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
475484
EXPECT_TRUE(Store0N->memPreds().empty());
476485
EXPECT_THAT(Store1N->memPreds(), testing::ElementsAre(Store0N));
477-
EXPECT_TRUE(RetN->memPreds().empty());
486+
EXPECT_TRUE(RetN->preds(DAG).empty());
478487
}
479488

480489
TEST_F(DependencyGraphTest, Call) {
@@ -498,12 +507,12 @@ define void @foo(float %v1, float %v2) {
498507
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
499508

500509
auto It = BB->begin();
501-
auto *Call1N = DAG.getNode(&*It++);
510+
auto *Call1N = cast<sandboxir::MemDGNode>(DAG.getNode(&*It++));
502511
auto *AddN = DAG.getNode(&*It++);
503-
auto *Call2N = DAG.getNode(&*It++);
512+
auto *Call2N = cast<sandboxir::MemDGNode>(DAG.getNode(&*It++));
504513

505514
EXPECT_THAT(Call1N->memPreds(), testing::ElementsAre());
506-
EXPECT_THAT(AddN->memPreds(), testing::ElementsAre());
515+
EXPECT_THAT(AddN->preds(DAG), testing::ElementsAre());
507516
EXPECT_THAT(Call2N->memPreds(), testing::ElementsAre(Call1N));
508517
}
509518

@@ -534,8 +543,8 @@ define void @foo() {
534543
auto *AllocaN = DAG.getNode(&*It++);
535544
auto *StackRestoreN = DAG.getNode(&*It++);
536545

537-
EXPECT_TRUE(dependency(AllocaN, StackRestoreN));
538-
EXPECT_TRUE(dependency(StackSaveN, AllocaN));
546+
EXPECT_TRUE(memDependency(AllocaN, StackRestoreN));
547+
EXPECT_TRUE(memDependency(StackSaveN, AllocaN));
539548
}
540549

541550
// Checks that stacksave and stackrestore depend on other mem instrs.
@@ -567,9 +576,9 @@ define void @foo(i8 %v0, i8 %v1, ptr %ptr) {
567576
auto *StackRestoreN = DAG.getNode(&*It++);
568577
auto *Store1N = DAG.getNode(&*It++);
569578

570-
EXPECT_TRUE(dependency(Store0N, StackSaveN));
571-
EXPECT_TRUE(dependency(StackSaveN, StackRestoreN));
572-
EXPECT_TRUE(dependency(StackRestoreN, Store1N));
579+
EXPECT_TRUE(memDependency(Store0N, StackSaveN));
580+
EXPECT_TRUE(memDependency(StackSaveN, StackRestoreN));
581+
EXPECT_TRUE(memDependency(StackRestoreN, Store1N));
573582
}
574583

575584
// Make sure there is a dependency between a stackrestore and an alloca.
@@ -596,7 +605,7 @@ define void @foo(ptr %ptr) {
596605
auto *StackRestoreN = DAG.getNode(&*It++);
597606
auto *AllocaN = DAG.getNode(&*It++);
598607

599-
EXPECT_TRUE(dependency(StackRestoreN, AllocaN));
608+
EXPECT_TRUE(memDependency(StackRestoreN, AllocaN));
600609
}
601610

602611
// Make sure there is a dependency between the alloca and stacksave
@@ -623,7 +632,7 @@ define void @foo(ptr %ptr) {
623632
auto *AllocaN = DAG.getNode(&*It++);
624633
auto *StackSaveN = DAG.getNode(&*It++);
625634

626-
EXPECT_TRUE(dependency(AllocaN, StackSaveN));
635+
EXPECT_TRUE(memDependency(AllocaN, StackSaveN));
627636
}
628637

629638
// A non-InAlloca in a stacksave-stackrestore region does not need extra
@@ -655,6 +664,6 @@ define void @foo() {
655664
auto *AllocaN = DAG.getNode(&*It++);
656665
auto *StackRestoreN = DAG.getNode(&*It++);
657666

658-
EXPECT_FALSE(dependency(StackSaveN, AllocaN));
659-
EXPECT_FALSE(dependency(AllocaN, StackRestoreN));
667+
EXPECT_FALSE(memDependency(StackSaveN, AllocaN));
668+
EXPECT_FALSE(memDependency(AllocaN, StackRestoreN));
660669
}

0 commit comments

Comments
 (0)