Skip to content

[SandboxVec][DAG] Refactoring: Outline code that looks for mem nodes #111750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ class MemDGNode final : public DGNode {
/// Convenience builders for a MemDGNode interval.
class MemDGNodeIntervalBuilder {
public:
/// Scans the instruction chain in \p Intvl top-down, returning the top-most
/// MemDGNode, or nullptr.
static MemDGNode *getTopMemDGNode(const Interval<Instruction> &Intvl,
const DependencyGraph &DAG);
/// Scans the instruction chain in \p Intvl bottom-up, returning the
/// bottom-most MemDGNode, or nullptr.
static MemDGNode *getBotMemDGNode(const Interval<Instruction> &Intvl,
const DependencyGraph &DAG);
/// Given \p Instrs it finds their closest mem nodes in the interval and
/// returns the corresponding mem range. Note: BotN (or its neighboring mem
/// node) is included in the range.
Expand Down
42 changes: 31 additions & 11 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,43 @@ void DGNode::dump() const {
}
#endif // NDEBUG

MemDGNode *
MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl,
const DependencyGraph &DAG) {
Instruction *I = Intvl.top();
Instruction *BeforeI = Intvl.bottom();
// Walk down the chain looking for a mem-dep candidate instruction.
while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
I = I->getNextNode();
if (!DGNode::isMemDepNodeCandidate(I))
return nullptr;
return cast<MemDGNode>(DAG.getNode(I));
}

MemDGNode *
MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
const DependencyGraph &DAG) {
Instruction *I = Intvl.bottom();
Instruction *AfterI = Intvl.top();
// Walk up the chain looking for a mem-dep candidate instruction.
while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
I = I->getPrevNode();
if (!DGNode::isMemDepNodeCandidate(I))
return nullptr;
return cast<MemDGNode>(DAG.getNode(I));
}

Interval<MemDGNode>
MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
DependencyGraph &DAG) {
// If top or bottom instructions are not mem-dep candidate nodes we need to
// walk down/up the chain and find the mem-dep ones.
Instruction *MemTopI = Instrs.top();
Instruction *MemBotI = Instrs.bottom();
while (!DGNode::isMemDepNodeCandidate(MemTopI) && MemTopI != MemBotI)
MemTopI = MemTopI->getNextNode();
while (!DGNode::isMemDepNodeCandidate(MemBotI) && MemBotI != MemTopI)
MemBotI = MemBotI->getPrevNode();
auto *TopMemN = getTopMemDGNode(Instrs, DAG);
// If we couldn't find a mem node in range TopN - BotN then it's empty.
if (!DGNode::isMemDepNodeCandidate(MemTopI))
if (TopMemN == nullptr)
return {};
auto *BotMemN = getBotMemDGNode(Instrs, DAG);
assert(BotMemN != nullptr && "TopMemN should be null too!");
// Now that we have the mem-dep nodes, create and return the range.
return Interval<MemDGNode>(cast<MemDGNode>(DAG.getNode(MemTopI)),
cast<MemDGNode>(DAG.getNode(MemBotI)));
return Interval<MemDGNode>(TopMemN, BotMemN);
}

DependencyGraph::DependencyType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,20 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));

// Check getTopMemDGNode().
using B = sandboxir::MemDGNodeIntervalBuilder;
using InstrInterval = sandboxir::Interval<sandboxir::Instruction>;
EXPECT_EQ(B::getTopMemDGNode(InstrInterval(S0, S0), DAG), S0N);
EXPECT_EQ(B::getTopMemDGNode(InstrInterval(S0, Ret), DAG), S0N);
EXPECT_EQ(B::getTopMemDGNode(InstrInterval(Add0, Add1), DAG), S0N);
EXPECT_EQ(B::getTopMemDGNode(InstrInterval(Add0, Add0), DAG), nullptr);

// Check getBotMemDGNode().
EXPECT_EQ(B::getBotMemDGNode(InstrInterval(S1, S1), DAG), S1N);
EXPECT_EQ(B::getBotMemDGNode(InstrInterval(Add0, S1), DAG), S1N);
EXPECT_EQ(B::getBotMemDGNode(InstrInterval(Add0, Ret), DAG), S1N);
EXPECT_EQ(B::getBotMemDGNode(InstrInterval(Ret, Ret), DAG), nullptr);

// Check empty range.
EXPECT_THAT(sandboxir::MemDGNodeIntervalBuilder::makeEmpty(),
testing::ElementsAre());
Expand Down
Loading