Skip to content

Commit 1c69a3f

Browse files
committed
[SandboxVec][DAG] Refactoring: Outline code that looks for mem nodes
1 parent 102c384 commit 1c69a3f

File tree

3 files changed

+57
-11
lines changed

3 files changed

+57
-11
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,16 @@ class MemDGNode final : public DGNode {
154154
/// Convenience builders for a MemDGNode interval.
155155
class MemDGNodeIntervalBuilder {
156156
public:
157+
/// Scans the instruction chain after \p I until \p BeforeI, looking for
158+
/// a mem dependency candidate and return the corresponding MemDGNode, or
159+
/// nullptr if not found.
160+
static MemDGNode *getMemDGNodeAfter(Instruction *I, Instruction *BeforeI,
161+
const DependencyGraph &DAG);
162+
/// Scans the instruction chain before \p I until \p AfterI, looking for
163+
/// a mem dependency candidate and return the corresponding MemDGNode, or
164+
/// nullptr if not found.
165+
static MemDGNode *getMemDGNodeBefore(Instruction *I, Instruction *AfterI,
166+
const DependencyGraph &DAG);
157167
/// Given \p Instrs it finds their closest mem nodes in the interval and
158168
/// returns the corresponding mem range. Note: BotN (or its neighboring mem
159169
/// node) is included in the range.

llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,40 @@ void DGNode::dump() const {
3232
}
3333
#endif // NDEBUG
3434

35+
MemDGNode *MemDGNodeIntervalBuilder::getMemDGNodeAfter(
36+
Instruction *I, Instruction *BeforeI, const DependencyGraph &DAG) {
37+
assert((I == BeforeI || I->comesBefore(BeforeI)) &&
38+
"Expected I before BeforeI");
39+
// Walk down the chain looking for a mem-dep candidate instruction.
40+
while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
41+
I = I->getNextNode();
42+
if (!DGNode::isMemDepNodeCandidate(I))
43+
return nullptr;
44+
return cast<MemDGNode>(DAG.getNode(I));
45+
}
46+
47+
MemDGNode *MemDGNodeIntervalBuilder::getMemDGNodeBefore(
48+
Instruction *I, Instruction *AfterI, const DependencyGraph &DAG) {
49+
assert((I == AfterI || AfterI->comesBefore(I)) && "Expected AfterI before I");
50+
// Walk up the chain looking for a mem-dep candidate instruction.
51+
while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
52+
I = I->getPrevNode();
53+
if (!DGNode::isMemDepNodeCandidate(I))
54+
return nullptr;
55+
return cast<MemDGNode>(DAG.getNode(I));
56+
}
57+
3558
Interval<MemDGNode>
3659
MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
3760
DependencyGraph &DAG) {
38-
// If top or bottom instructions are not mem-dep candidate nodes we need to
39-
// walk down/up the chain and find the mem-dep ones.
40-
Instruction *MemTopI = Instrs.top();
41-
Instruction *MemBotI = Instrs.bottom();
42-
while (!DGNode::isMemDepNodeCandidate(MemTopI) && MemTopI != MemBotI)
43-
MemTopI = MemTopI->getNextNode();
44-
while (!DGNode::isMemDepNodeCandidate(MemBotI) && MemBotI != MemTopI)
45-
MemBotI = MemBotI->getPrevNode();
61+
auto *TopMemN = getMemDGNodeAfter(Instrs.top(), Instrs.bottom(), DAG);
4662
// If we couldn't find a mem node in range TopN - BotN then it's empty.
47-
if (!DGNode::isMemDepNodeCandidate(MemTopI))
63+
if (TopMemN == nullptr)
4864
return {};
65+
auto *BotMemN = getMemDGNodeBefore(Instrs.bottom(), Instrs.top(), DAG);
66+
assert(BotMemN != nullptr && "TopMemN should be null too!");
4967
// Now that we have the mem-dep nodes, create and return the range.
50-
return Interval<MemDGNode>(cast<MemDGNode>(DAG.getNode(MemTopI)),
51-
cast<MemDGNode>(DAG.getNode(MemBotI)));
68+
return Interval<MemDGNode>(TopMemN, BotMemN);
5269
}
5370

5471
DependencyGraph::DependencyType

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,25 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
305305
auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
306306
auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
307307

308+
// Check getMemDGNodeAfter().
309+
using B = sandboxir::MemDGNodeIntervalBuilder;
310+
EXPECT_EQ(B::getMemDGNodeAfter(S0, S0, DAG), S0N);
311+
EXPECT_EQ(B::getMemDGNodeAfter(S0, Ret, DAG), S0N);
312+
#ifndef NDEBUG
313+
EXPECT_DEATH(B::getMemDGNodeAfter(S0, Add0, DAG), ".*before.*");
314+
#endif
315+
EXPECT_EQ(B::getMemDGNodeAfter(Add0, Add1, DAG), S0N);
316+
EXPECT_EQ(B::getMemDGNodeAfter(Add0, Add0, DAG), nullptr);
317+
318+
// Check getMemDGNodeBefore().
319+
EXPECT_EQ(B::getMemDGNodeBefore(S1, S1, DAG), S1N);
320+
EXPECT_EQ(B::getMemDGNodeBefore(S1, Add0, DAG), S1N);
321+
#ifndef NDEBUG
322+
EXPECT_DEATH(B::getMemDGNodeBefore(S1, Ret, DAG), ".*before.*");
323+
#endif
324+
EXPECT_EQ(B::getMemDGNodeBefore(Ret, Add0, DAG), S1N);
325+
EXPECT_EQ(B::getMemDGNodeBefore(Ret, Ret, DAG), nullptr);
326+
308327
// Check empty range.
309328
EXPECT_THAT(sandboxir::MemDGNodeIntervalBuilder::makeEmpty(),
310329
testing::ElementsAre());

0 commit comments

Comments
 (0)