Skip to content

[SandboxVec][DAG] MemDGNode for memory-dependency candidate nodes #109684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,61 @@

namespace llvm::sandboxir {

class DependencyGraph;
class MemDGNode;

/// SubclassIDs for isa/dyn_cast etc.
/// Each node stores one of these in its SubclassID field; MemDGNode::classof()
/// compares against it to identify the dynamic node type.
enum class DGNodeID {
// ID passed by the public DGNode(Instruction *) constructor.
DGNode,
// ID passed by the MemDGNode constructor.
MemDGNode,
};

/// A DependencyGraph Node that points to an Instruction and contains memory
/// dependency edges.
class DGNode {
protected:
Instruction *I;
// TODO: Use a PointerIntPair for SubclassID and I.
/// For isa/dyn_cast etc.
DGNodeID SubclassID;
/// Memory predecessors.
DenseSet<DGNode *> MemPreds;
/// This is true if this may read/write memory, or if it has some ordering
/// constraints, like with stacksave/stackrestore and alloca/inalloca.
bool IsMem;
DenseSet<MemDGNode *> MemPreds;

DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
friend class MemDGNode; // For constructor.

public:
DGNode(Instruction *I) : I(I) {
IsMem = I->isMemDepCandidate() ||
(isa<AllocaInst>(I) && cast<AllocaInst>(I)->isUsedWithInAlloca()) ||
I->isStackSaveOrRestoreIntrinsic();
DGNode(Instruction *I) : I(I), SubclassID(DGNodeID::DGNode) {
assert(!isMemDepCandidate(I) && "Expected Non-Mem instruction, ");
}
DGNode(const DGNode &Other) = delete;
virtual ~DGNode() = default;
/// \Returns true if this is before \p Other in program order.
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
/// \Returns true if \p I is a memory dependency candidate instruction.
static bool isMemDepCandidate(Instruction *I) {
AllocaInst *Alloca;
return I->isMemDepCandidate() ||
((Alloca = dyn_cast<AllocaInst>(I)) &&
Alloca->isUsedWithInAlloca()) ||
I->isStackSaveOrRestoreIntrinsic();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't these checks be moved into isMemDepCandidate for sandboxir::Instruction?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean in a member function like sandboxir::Instruction::isMemDepCandidate()?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chatted offline and resolving, keeping it here is fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason why isMemDepCandidate() and isStackSaveOrRestoreIntrinsic() are in sandboxir::Instruction is that they need to access LLVM IR, but the combined check does not need to. Also I don't think this check will be used in other places.

}

Instruction *getInstruction() const { return I; }
void addMemPred(DGNode *PredN) { MemPreds.insert(PredN); }
void addMemPred(MemDGNode *PredN) { MemPreds.insert(PredN); }
/// \Returns all memory dependency predecessors.
iterator_range<DenseSet<DGNode *>::const_iterator> memPreds() const {
iterator_range<DenseSet<MemDGNode *>::const_iterator> memPreds() const {
return make_range(MemPreds.begin(), MemPreds.end());
}
/// \Returns true if there is a memory dependency N->this.
bool hasMemPred(DGNode *N) const { return MemPreds.count(N); }
/// \Returns true if this may read/write memory, or if it has some ordering
/// constraints, like with stacksave/stackrestore and alloca/inalloca.
bool isMem() const { return IsMem; }
bool hasMemPred(DGNode *N) const {
  // MemPreds only ever holds MemDGNodes, so any other kind of node can never
  // be a memory predecessor of this one.
  auto *MemN = dyn_cast<MemDGNode>(N);
  return MemN != nullptr && MemPreds.count(MemN);
}

#ifndef NDEBUG
void print(raw_ostream &OS, bool PrintDeps = true) const;
virtual void print(raw_ostream &OS, bool PrintDeps = true) const;
friend raw_ostream &operator<<(DGNode &N, raw_ostream &OS) {
N.print(OS);
return OS;
Expand All @@ -66,9 +92,46 @@ class DGNode {
#endif // NDEBUG
};

/// The DGNode flavor used for memory-dependency candidates: instructions that
/// may read/write memory, or that have some ordering constraints, like with
/// stacksave/stackrestore and alloca/inalloca.
/// Besides the DGNode state, each MemDGNode carries links to its neighboring
/// mem nodes, which are maintained by the DependencyGraph.
class MemDGNode final : public DGNode {
  /// Neighbor links; both stay nullptr unless set through the setters below.
  MemDGNode *PrevMemN = nullptr;
  MemDGNode *NextMemN = nullptr;

  friend class DependencyGraph; // For setNextNode(), setPrevNode().
  void setPrevNode(MemDGNode *Prev) { PrevMemN = Prev; }
  void setNextNode(MemDGNode *Next) { NextMemN = Next; }

public:
  /// Creates the mem node for \p I, which must be a mem-dep candidate.
  MemDGNode(Instruction *I) : DGNode(I, DGNodeID::MemDGNode) {
    assert(isMemDepCandidate(I) && "Expected Mem instruction!");
  }
  /// For isa/dyn_cast etc.
  static bool classof(const DGNode *Other) {
    return DGNodeID::MemDGNode == Other->SubclassID;
  }
  /// \Returns the next Mem DGNode in instruction order.
  MemDGNode *getNextNode() const { return NextMemN; }
  /// \Returns the previous Mem DGNode in instruction order.
  MemDGNode *getPrevNode() const { return PrevMemN; }
};

/// Convenience builders for a MemDGNode interval.
class MemDGNodeIntervalBuilder {
public:
/// Given \p Instrs it finds their closest mem nodes in the interval and
/// returns the corresponding mem range. Note: BotN (or its neighboring mem
/// node) is included in the range.
static Interval<MemDGNode> make(const Interval<Instruction> &Instrs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate on why you need both an Interval and a DAG? The DAG has an interval inside. Are they different intervals, and what is each one used for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The interval describes a range of nodes in program order. This particular interval is going to be used as the range of nodes to be scanned for dependencies as we are building the DAG. The DAG's interval is the range of instructions that are represented by the DAG; it's only used to check if some instruction/node is inside or outside the DAG.

DependencyGraph &DAG);
static Interval<MemDGNode> makeEmpty() { return {}; }
};

class DependencyGraph {
private:
DenseMap<Instruction *, std::unique_ptr<DGNode>> InstrToNodeMap;
/// The DAG spans across all instructions in this interval.
Interval<Instruction> DAGInterval;

public:
DependencyGraph() {}
Expand All @@ -77,10 +140,20 @@ class DependencyGraph {
auto It = InstrToNodeMap.find(I);
return It != InstrToNodeMap.end() ? It->second.get() : nullptr;
}
/// Null-tolerant variant of getNode(): \returns nullptr when \p I is nullptr,
/// otherwise whatever getNode(I) returns.
DGNode *getNodeOrNull(Instruction *I) const {
  return I != nullptr ? getNode(I) : nullptr;
}
DGNode *getOrCreateNode(Instruction *I) {
auto [It, NotInMap] = InstrToNodeMap.try_emplace(I);
if (NotInMap)
It->second = std::make_unique<DGNode>(I);
if (NotInMap) {
if (DGNode::isMemDepCandidate(I))
It->second = std::make_unique<MemDGNode>(I);
else
It->second = std::make_unique<DGNode>(I);
}
return It->second.get();
}
/// Build/extend the dependency graph such that it includes \p Instrs. Returns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ void DGNode::dump() const {
}
#endif // NDEBUG

Interval<MemDGNode>
MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
DependencyGraph &DAG) {
// If top or bottom instructions are not mem-dep candidate nodes we need to
// walk down/up the chain and find the mem-dep ones.
Instruction *MemTopI = Instrs.top();
Instruction *MemBotI = Instrs.bottom();
while (!DGNode::isMemDepCandidate(MemTopI) && MemTopI != MemBotI)
MemTopI = MemTopI->getNextNode();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm missing how getNextNode works for an Instruction here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We just skip the non memory-candidate instructions. getNextNode() gives you the next instruction in the instruction list.

while (!DGNode::isMemDepCandidate(MemBotI) && MemBotI != MemTopI)
MemBotI = MemBotI->getPrevNode();
// If we couldn't find a mem node in range TopN - BotN then it's empty.
if (!DGNode::isMemDepCandidate(MemTopI))
return {};
// Now that we have the mem-dep nodes, create and return the range.
return Interval<MemDGNode>(cast<MemDGNode>(DAG.getNode(MemTopI)),
cast<MemDGNode>(DAG.getNode(MemBotI)));
}

Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
if (Instrs.empty())
return {};
Expand All @@ -39,10 +58,18 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
auto *TopI = Interval.top();
auto *BotI = Interval.bottom();
DGNode *LastN = getOrCreateNode(TopI);
MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
for (Instruction *I = TopI->getNextNode(), *E = BotI->getNextNode(); I != E;
I = I->getNextNode()) {
auto *N = getOrCreateNode(I);
N->addMemPred(LastN);
N->addMemPred(LastMemN);
// Build the Mem node chain.
if (auto *MemN = dyn_cast<MemDGNode>(N)) {
MemN->setPrevNode(LastMemN);
if (LastMemN != nullptr)
LastMemN->setNextNode(MemN);
LastMemN = MemN;
}
LastN = N;
}
return Interval;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct DependencyGraphTest : public testing::Test {
}
};

TEST_F(DependencyGraphTest, DGNode_IsMem) {
TEST_F(DependencyGraphTest, MemDGNode) {
parseIR(C, R"IR(
declare void @llvm.sideeffect()
declare void @llvm.pseudoprobe(i64, i64, i32, i64)
Expand Down Expand Up @@ -66,16 +66,16 @@ define void @foo(i8 %v1, ptr %ptr) {

sandboxir::DependencyGraph DAG;
DAG.extend({&*BB->begin(), BB->getTerminator()});
EXPECT_TRUE(DAG.getNode(Store)->isMem());
EXPECT_TRUE(DAG.getNode(Load)->isMem());
EXPECT_FALSE(DAG.getNode(Add)->isMem());
EXPECT_TRUE(DAG.getNode(StackSave)->isMem());
EXPECT_TRUE(DAG.getNode(StackRestore)->isMem());
EXPECT_FALSE(DAG.getNode(SideEffect)->isMem());
EXPECT_FALSE(DAG.getNode(PseudoProbe)->isMem());
EXPECT_TRUE(DAG.getNode(FakeUse)->isMem());
EXPECT_TRUE(DAG.getNode(Call)->isMem());
EXPECT_FALSE(DAG.getNode(Ret)->isMem());
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Store)));
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Load)));
EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Add)));
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(StackSave)));
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(StackRestore)));
EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(SideEffect)));
EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(PseudoProbe)));
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(FakeUse)));
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Call)));
EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Ret)));
}

TEST_F(DependencyGraphTest, Basic) {
Expand Down Expand Up @@ -115,3 +115,100 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1));
}

// Checks the prev/next links between MemDGNodes after the DAG is extended over
// a block with a non-memory instruction sitting between two stores.
TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) {
  parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
store i8 %v0, ptr %ptr
add i8 %v0, %v0
store i8 %v1, ptr %ptr
ret void
}
)IR");
  sandboxir::Context Ctx(C);
  auto *F = Ctx.createFunction(&*M->getFunction("foo"));
  auto *BB = &*F->begin();

  // Pick up the instructions in program order: store, add, store, ret.
  auto InstIt = BB->begin();
  auto *Store0 = cast<sandboxir::StoreInst>(&*InstIt++);
  [[maybe_unused]] auto *AddI = cast<sandboxir::BinaryOperator>(&*InstIt++);
  auto *Store1 = cast<sandboxir::StoreInst>(&*InstIt++);
  [[maybe_unused]] auto *RetI = cast<sandboxir::ReturnInst>(&*InstIt++);

  // Build the graph over the whole block.
  sandboxir::DependencyGraph DAG;
  DAG.extend({&*BB->begin(), BB->getTerminator()});

  auto *Store0N = cast<sandboxir::MemDGNode>(DAG.getNode(Store0));
  auto *Store1N = cast<sandboxir::MemDGNode>(DAG.getNode(Store1));

  // The first store has no mem node before it and links forward to the second.
  EXPECT_EQ(Store0N->getPrevNode(), nullptr);
  EXPECT_EQ(Store0N->getNextNode(), Store1N);

  // The second store links back to the first and has no mem node after it.
  EXPECT_EQ(Store1N->getPrevNode(), Store0N);
  EXPECT_EQ(Store1N->getNextNode(), nullptr);
}

// Exercises MemDGNodeIntervalBuilder with instruction intervals whose top
// and/or bottom may or may not be memory instructions.
TEST_F(DependencyGraphTest, DGNodeRange) {
  parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
add i8 %v0, %v0
store i8 %v0, ptr %ptr
add i8 %v0, %v0
store i8 %v1, ptr %ptr
ret void
}
)IR");
  sandboxir::Context Ctx(C);
  auto *F = Ctx.createFunction(&*M->getFunction("foo"));
  auto *BB = &*F->begin();

  // Pick up the instructions in program order: add, store, add, store, ret.
  auto InstIt = BB->begin();
  auto *TopAdd = cast<sandboxir::BinaryOperator>(&*InstIt++);
  auto *Store0 = cast<sandboxir::StoreInst>(&*InstIt++);
  auto *MidAdd = cast<sandboxir::BinaryOperator>(&*InstIt++);
  auto *Store1 = cast<sandboxir::StoreInst>(&*InstIt++);
  auto *RetI = cast<sandboxir::ReturnInst>(&*InstIt++);

  // Build the graph over the whole block.
  sandboxir::DependencyGraph DAG;
  DAG.extend({&*BB->begin(), BB->getTerminator()});

  auto *Store0N = cast<sandboxir::MemDGNode>(DAG.getNode(Store0));
  auto *Store1N = cast<sandboxir::MemDGNode>(DAG.getNode(Store1));

  // The empty interval has no elements.
  EXPECT_THAT(sandboxir::MemDGNodeIntervalBuilder::makeEmpty(),
              testing::ElementsAre());

  // Builds the mem interval for instructions [Top, Bot] and returns pointers
  // to its nodes in iteration order.
  auto memNodesIn = [&DAG](sandboxir::Instruction *Top,
                           sandboxir::Instruction *Bot) {
    const auto MemInterval =
        sandboxir::MemDGNodeIntervalBuilder::make({Top, Bot}, DAG);
    SmallVector<const sandboxir::DGNode *> Nodes;
    for (const sandboxir::DGNode &N : MemInterval)
      Nodes.push_back(&N);
    return Nodes;
  };

  // Top and bottom are both memory instructions.
  EXPECT_THAT(memNodesIn(Store0, Store1),
              testing::ElementsAre(Store0N, Store1N));
  // Only the top is a memory instruction.
  EXPECT_THAT(memNodesIn(Store0, RetI),
              testing::ElementsAre(Store0N, Store1N));
  EXPECT_THAT(memNodesIn(Store0, MidAdd), testing::ElementsAre(Store0N));
  // Only the bottom is a memory instruction.
  EXPECT_THAT(memNodesIn(TopAdd, Store1),
              testing::ElementsAre(Store0N, Store1N));
  EXPECT_THAT(memNodesIn(TopAdd, Store0), testing::ElementsAre(Store0N));
  // Neither the top nor the bottom is a memory instruction.
  EXPECT_THAT(memNodesIn(TopAdd, RetI),
              testing::ElementsAre(Store0N, Store1N));
  EXPECT_THAT(memNodesIn(TopAdd, TopAdd), testing::ElementsAre());
}
Loading