Skip to content

[SandboxVec][DAG] Refactoring: Move MemPreds from DGNode to MemDGNode #111897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- DependencyGraph.h ----------------------------------*- C++ -*-===//
//===- DependencyGraph.h ----------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down Expand Up @@ -96,9 +96,6 @@ class DGNode {
// TODO: Use a PointerIntPair for SubclassID and I.
/// For isa/dyn_cast etc.
DGNodeID SubclassID;
// TODO: Move MemPreds to MemDGNode.
/// Memory predecessors.
DenseSet<MemDGNode *> MemPreds;

DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
friend class MemDGNode; // For constructor.
Expand Down Expand Up @@ -170,17 +167,6 @@ class DGNode {
}

Instruction *getInstruction() const { return I; }
void addMemPred(MemDGNode *PredN) { MemPreds.insert(PredN); }
/// \Returns all memory dependency predecessors.
iterator_range<DenseSet<MemDGNode *>::const_iterator> memPreds() const {
return make_range(MemPreds.begin(), MemPreds.end());
}
/// \Returns true if there is a memory dependency N->this.
bool hasMemPred(DGNode *N) const {
if (auto *MN = dyn_cast<MemDGNode>(N))
return MemPreds.count(MN);
return false;
}

#ifndef NDEBUG
virtual void print(raw_ostream &OS, bool PrintDeps = true) const;
Expand All @@ -198,6 +184,9 @@ class DGNode {
class MemDGNode final : public DGNode {
MemDGNode *PrevMemN = nullptr;
MemDGNode *NextMemN = nullptr;
/// Memory predecessors.
DenseSet<MemDGNode *> MemPreds;
friend class PredIterator; // For MemPreds.

void setNextNode(MemDGNode *N) { NextMemN = N; }
void setPrevNode(MemDGNode *N) { PrevMemN = N; }
Expand All @@ -222,6 +211,21 @@ class MemDGNode final : public DGNode {
MemDGNode *getPrevNode() const { return PrevMemN; }
/// \Returns the next Mem DGNode in instruction order.
MemDGNode *getNextNode() const { return NextMemN; }
/// Adds the mem dependency edge PredN->this.
void addMemPred(MemDGNode *PredN) { MemPreds.insert(PredN); }
/// \Returns true if there is a memory dependency N->this.
bool hasMemPred(DGNode *N) const {
if (auto *MN = dyn_cast<MemDGNode>(N))
return MemPreds.count(MN);
return false;
}
/// \Returns all memory dependency predecessors. Used by tests.
iterator_range<DenseSet<MemDGNode *>::const_iterator> memPreds() const {
return make_range(MemPreds.begin(), MemPreds.end());
}
#ifndef NDEBUG
virtual void print(raw_ostream &OS, bool PrintDeps = true) const override;
#endif // NDEBUG
};

/// Convenience builders for a MemDGNode interval.
Expand Down Expand Up @@ -266,7 +270,7 @@ class DependencyGraph {

/// Go through all mem nodes in \p SrcScanRange and try to add dependencies to
/// \p DstN.
void scanAndAddDeps(DGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
void scanAndAddDeps(MemDGNode &DstN, const Interval<MemDGNode> &SrcScanRange);

public:
DependencyGraph(AAResults &AA)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ PredIterator::value_type PredIterator::operator*() {
// or a mem predecessor.
if (OpIt != OpItE)
return DAG->getNode(cast<Instruction>((Value *)*OpIt));
assert(MemIt != cast<MemDGNode>(N)->memPreds().end() &&
// It's a MemDGNode with OpIt == end, so we need to use MemIt.
assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
"Cant' dereference end iterator!");
return *MemIt;
}
Expand All @@ -45,7 +46,8 @@ PredIterator &PredIterator::operator++() {
OpIt = skipNonInstr(OpIt, OpItE);
return *this;
}
assert(MemIt != cast<MemDGNode>(N)->memPreds().end() && "Already at end!");
// It's a MemDGNode with OpIt == end, so we need to increment MemIt.
assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
++MemIt;
return *this;
}
Expand All @@ -57,10 +59,14 @@ bool PredIterator::operator==(const PredIterator &Other) const {
}

#ifndef NDEBUG
void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
void DGNode::print(raw_ostream &OS, bool PrintDeps) const { I->dumpOS(OS); }
void DGNode::dump() const {
print(dbgs());
dbgs() << "\n";
}
void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
I->dumpOS(OS);
if (PrintDeps) {
OS << "\n";
// Print memory preds.
static constexpr const unsigned Indent = 4;
for (auto *Pred : MemPreds) {
Expand All @@ -70,10 +76,6 @@ void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
}
}
}
void DGNode::dump() const {
print(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

Interval<MemDGNode>
Expand Down Expand Up @@ -179,7 +181,7 @@ bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
llvm_unreachable("Unknown DependencyType enum");
}

void DependencyGraph::scanAndAddDeps(DGNode &DstN,
void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
const Interval<MemDGNode> &SrcScanRange) {
assert(isa<MemDGNode>(DstN) &&
"DstN is the mem dep destination, so it must be mem");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ struct DependencyGraphTest : public testing::Test {
return *AA;
}
/// \Returns true if there is a dependency: SrcN->DstN.
bool dependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) {
const auto &Preds = DstN->memPreds();
auto It = find(Preds, SrcN);
return It != Preds.end();
bool memDependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) {
if (auto *MemDstN = dyn_cast<sandboxir::MemDGNode>(DstN))
return MemDstN->hasMemPred(SrcN);
return false;
}
};

Expand Down Expand Up @@ -230,9 +230,10 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
EXPECT_EQ(Span.top(), &*BB->begin());
EXPECT_EQ(Span.bottom(), BB->getTerminator());

sandboxir::DGNode *N0 = DAG.getNode(S0);
sandboxir::DGNode *N1 = DAG.getNode(S1);
sandboxir::DGNode *N2 = DAG.getNode(Ret);
auto *N0 = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
auto *N1 = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
auto *N2 = DAG.getNode(Ret);

// Check getInstruction().
EXPECT_EQ(N0->getInstruction(), S0);
EXPECT_EQ(N1->getInstruction(), S1);
Expand All @@ -247,7 +248,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
// Check memPreds().
EXPECT_TRUE(N0->memPreds().empty());
EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
EXPECT_TRUE(N2->memPreds().empty());
EXPECT_TRUE(N2->preds(DAG).empty());
}

TEST_F(DependencyGraphTest, Preds) {
Expand Down Expand Up @@ -399,12 +400,14 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
auto *Store0N = cast<sandboxir::MemDGNode>(
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
auto *Store1N = cast<sandboxir::MemDGNode>(
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
EXPECT_TRUE(Store0N->memPreds().empty());
EXPECT_THAT(Store1N->memPreds(), testing::ElementsAre(Store0N));
EXPECT_TRUE(RetN->memPreds().empty());
EXPECT_TRUE(RetN->preds(DAG).empty());
}

TEST_F(DependencyGraphTest, NonAliasingStores) {
Expand All @@ -422,13 +425,15 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) {
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
auto *Store0N = cast<sandboxir::MemDGNode>(
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
auto *Store1N = cast<sandboxir::MemDGNode>(
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
// We expect no dependencies because the stores don't alias.
EXPECT_TRUE(Store0N->memPreds().empty());
EXPECT_TRUE(Store1N->memPreds().empty());
EXPECT_TRUE(RetN->memPreds().empty());
EXPECT_TRUE(RetN->preds(DAG).empty());
}

TEST_F(DependencyGraphTest, VolatileLoads) {
Expand All @@ -446,12 +451,14 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Ld0N = DAG.getNode(cast<sandboxir::LoadInst>(&*It++));
auto *Ld1N = DAG.getNode(cast<sandboxir::LoadInst>(&*It++));
auto *Ld0N = cast<sandboxir::MemDGNode>(
DAG.getNode(cast<sandboxir::LoadInst>(&*It++)));
auto *Ld1N = cast<sandboxir::MemDGNode>(
DAG.getNode(cast<sandboxir::LoadInst>(&*It++)));
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
EXPECT_TRUE(Ld0N->memPreds().empty());
EXPECT_THAT(Ld1N->memPreds(), testing::ElementsAre(Ld0N));
EXPECT_TRUE(RetN->memPreds().empty());
EXPECT_TRUE(RetN->preds(DAG).empty());
}

TEST_F(DependencyGraphTest, VolatileSotres) {
Expand All @@ -469,12 +476,14 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v) {
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
auto *Store0N = cast<sandboxir::MemDGNode>(
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
auto *Store1N = cast<sandboxir::MemDGNode>(
DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
EXPECT_TRUE(Store0N->memPreds().empty());
EXPECT_THAT(Store1N->memPreds(), testing::ElementsAre(Store0N));
EXPECT_TRUE(RetN->memPreds().empty());
EXPECT_TRUE(RetN->preds(DAG).empty());
}

TEST_F(DependencyGraphTest, Call) {
Expand All @@ -498,12 +507,12 @@ define void @foo(float %v1, float %v2) {
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});

auto It = BB->begin();
auto *Call1N = DAG.getNode(&*It++);
auto *Call1N = cast<sandboxir::MemDGNode>(DAG.getNode(&*It++));
auto *AddN = DAG.getNode(&*It++);
auto *Call2N = DAG.getNode(&*It++);
auto *Call2N = cast<sandboxir::MemDGNode>(DAG.getNode(&*It++));

EXPECT_THAT(Call1N->memPreds(), testing::ElementsAre());
EXPECT_THAT(AddN->memPreds(), testing::ElementsAre());
EXPECT_THAT(AddN->preds(DAG), testing::ElementsAre());
EXPECT_THAT(Call2N->memPreds(), testing::ElementsAre(Call1N));
}

Expand Down Expand Up @@ -534,8 +543,8 @@ define void @foo() {
auto *AllocaN = DAG.getNode(&*It++);
auto *StackRestoreN = DAG.getNode(&*It++);

EXPECT_TRUE(dependency(AllocaN, StackRestoreN));
EXPECT_TRUE(dependency(StackSaveN, AllocaN));
EXPECT_TRUE(memDependency(AllocaN, StackRestoreN));
EXPECT_TRUE(memDependency(StackSaveN, AllocaN));
}

// Checks that stacksave and stackrestore depend on other mem instrs.
Expand Down Expand Up @@ -567,9 +576,9 @@ define void @foo(i8 %v0, i8 %v1, ptr %ptr) {
auto *StackRestoreN = DAG.getNode(&*It++);
auto *Store1N = DAG.getNode(&*It++);

EXPECT_TRUE(dependency(Store0N, StackSaveN));
EXPECT_TRUE(dependency(StackSaveN, StackRestoreN));
EXPECT_TRUE(dependency(StackRestoreN, Store1N));
EXPECT_TRUE(memDependency(Store0N, StackSaveN));
EXPECT_TRUE(memDependency(StackSaveN, StackRestoreN));
EXPECT_TRUE(memDependency(StackRestoreN, Store1N));
}

// Make sure there is a dependency between a stackrestore and an alloca.
Expand All @@ -596,7 +605,7 @@ define void @foo(ptr %ptr) {
auto *StackRestoreN = DAG.getNode(&*It++);
auto *AllocaN = DAG.getNode(&*It++);

EXPECT_TRUE(dependency(StackRestoreN, AllocaN));
EXPECT_TRUE(memDependency(StackRestoreN, AllocaN));
}

// Make sure there is a dependency between the alloca and stacksave
Expand All @@ -623,7 +632,7 @@ define void @foo(ptr %ptr) {
auto *AllocaN = DAG.getNode(&*It++);
auto *StackSaveN = DAG.getNode(&*It++);

EXPECT_TRUE(dependency(AllocaN, StackSaveN));
EXPECT_TRUE(memDependency(AllocaN, StackSaveN));
}

// A non-InAlloca in a stacksave-stackrestore region does not need extra
Expand Down Expand Up @@ -655,6 +664,6 @@ define void @foo() {
auto *AllocaN = DAG.getNode(&*It++);
auto *StackRestoreN = DAG.getNode(&*It++);

EXPECT_FALSE(dependency(StackSaveN, AllocaN));
EXPECT_FALSE(dependency(AllocaN, StackRestoreN));
EXPECT_FALSE(memDependency(StackSaveN, AllocaN));
EXPECT_FALSE(memDependency(AllocaN, StackRestoreN));
}
Loading