Skip to content

[SandboxVec][DAG] Update DAG when a new instruction is created #126124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ class MemDGNode final : public DGNode {
void addMemPred(MemDGNode *PredN) {
[[maybe_unused]] auto Inserted = MemPreds.insert(PredN).second;
assert(Inserted && "PredN already exists!");
assert(PredN != this && "Trying to add a dependency to self!");
if (!Scheduled) {
++PredN->UnscheduledSuccs;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ template <typename T> class Interval {
return (Top == I || Top->comesBefore(I)) &&
(I == Bottom || I->comesBefore(Bottom));
}
/// \Returns true if \p Elm is right before the top or right after the bottom.
bool touches(T *Elm) const {
return Top == Elm->getNextNode() || Bottom == Elm->getPrevNode();
}
T *top() const { return Top; }
T *bottom() const { return Bottom; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,13 @@ MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
}

void DependencyGraph::notifyCreateInstr(Instruction *I) {
auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I));
// TODO: Update the dependencies for the new node.
// Nothing to do if the node is not in the focus range of the DAG.
if (!(DAGInterval.contains(I) || DAGInterval.touches(I)))
return;
// Include `I` into the interval.
DAGInterval = DAGInterval.getUnionInterval({I, I});
auto *N = getOrCreateNode(I);
auto *MemN = dyn_cast<MemDGNode>(N);

// Update the MemDGNode chain if this is a memory node.
if (MemN != nullptr) {
Expand All @@ -381,6 +386,21 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) {
NextMemN->PrevMemN = MemN;
MemN->NextMemN = NextMemN;
}

// Add Mem dependencies.
// 1. Scan for deps above `I` for deps to `I`: AboveN->MemN.
if (DAGInterval.top()->comesBefore(I)) {
Interval<Instruction> AboveIntvl(DAGInterval.top(), I->getPrevNode());
auto SrcInterval = MemDGNodeIntervalBuilder::make(AboveIntvl, *this);
scanAndAddDeps(*MemN, SrcInterval);
}
// 2. Scan for deps below `I` for deps from `I`: MemN->BelowN.
if (I->comesBefore(DAGInterval.bottom())) {
Interval<Instruction> BelowIntvl(I->getNextNode(), DAGInterval.bottom());
for (MemDGNode &BelowN :
MemDGNodeIntervalBuilder::make(BelowIntvl, *this))
scanAndAddDeps(BelowN, Interval<MemDGNode>(MemN, MemN));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -832,9 +832,10 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
}
}

// Check that the DAG gets updated when we create a new instruction.
TEST_F(DependencyGraphTest, CreateInstrCallback) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %new1, i8 %new2) {
store i8 %v1, ptr %ptr
store i8 %v2, ptr %ptr
store i8 %v3, ptr %ptr
Expand All @@ -851,42 +852,52 @@ define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

// Check new instruction callback.
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({S1, Ret});
auto *Arg = F->getArg(3);
// Create a DAG spanning S1 to S3.
DAG.extend({S1, S3});
auto *ArgNew1 = F->getArg(4);
auto *ArgNew2 = F->getArg(5);
auto *Ptr = S1->getPointerOperand();

auto *S1MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
sandboxir::MemDGNode *New1MemN = nullptr;
sandboxir::MemDGNode *New2MemN = nullptr;
{
// Create a new store before S3 (within the span of the DAG).
sandboxir::StoreInst *NewS =
sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
sandboxir::StoreInst::create(ArgNew1, Ptr, Align(8), S3->getIterator(),
/*IsVolatile=*/true, Ctx);
auto *NewSN = DAG.getNode(NewS);
EXPECT_TRUE(NewSN != nullptr);

// Check the MemDGNode chain.
auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
auto *NewMemSN = cast<sandboxir::MemDGNode>(NewSN);
auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
EXPECT_EQ(S2MemN->getNextNode(), NewMemSN);
EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN);
EXPECT_EQ(NewMemSN->getNextNode(), S3MemN);
EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN);
New1MemN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
EXPECT_EQ(S2MemN->getNextNode(), New1MemN);
EXPECT_EQ(New1MemN->getPrevNode(), S2MemN);
EXPECT_EQ(New1MemN->getNextNode(), S3MemN);
EXPECT_EQ(S3MemN->getPrevNode(), New1MemN);

// Check dependencies.
EXPECT_TRUE(memDependency(S1MemN, New1MemN));
EXPECT_TRUE(memDependency(S2MemN, New1MemN));
EXPECT_TRUE(memDependency(New1MemN, S3MemN));
}

{
// Also check if new node is at the end of the BB, after Ret.
// Create a new store before Ret (outside the current DAG).
sandboxir::StoreInst *NewS =
sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(),
sandboxir::StoreInst::create(ArgNew2, Ptr, Align(8), Ret->getIterator(),
/*IsVolatile=*/true, Ctx);
// Check the MemDGNode chain.
auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
auto *NewMemSN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
EXPECT_EQ(S3MemN->getNextNode(), NewMemSN);
EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN);
EXPECT_EQ(NewMemSN->getNextNode(), nullptr);
New2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
EXPECT_EQ(S3MemN->getNextNode(), New2MemN);
EXPECT_EQ(New2MemN->getPrevNode(), S3MemN);
EXPECT_EQ(New2MemN->getNextNode(), nullptr);

// Check dependencies.
EXPECT_TRUE(memDependency(S1MemN, New2MemN));
EXPECT_TRUE(memDependency(S2MemN, New2MemN));
EXPECT_TRUE(memDependency(New1MemN, New2MemN));
EXPECT_TRUE(memDependency(S3MemN, New2MemN));
}

// TODO: Check the dependencies to/from NewSN after they land.
}

TEST_F(DependencyGraphTest, EraseInstrCallback) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ define void @foo(i8 %v0) {
EXPECT_FALSE(One.contains(I1));
EXPECT_FALSE(One.contains(I2));
EXPECT_FALSE(One.contains(Ret));
// Check touches().
{
sandboxir::Interval<sandboxir::Instruction> Intvl(I2, I2);
EXPECT_TRUE(Intvl.touches(I1));
EXPECT_TRUE(Intvl.contains(I2));
EXPECT_FALSE(Intvl.touches(I2));
EXPECT_TRUE(Intvl.touches(Ret));
EXPECT_FALSE(Intvl.touches(I0));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a EXPECT_TRUE(Intvl.contains(I2)); here for clarity.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// Check iterator.
auto BBIt = BB->begin();
for (auto &I : Intvl)
Expand Down