Skip to content

[SandboxVec][DAG] Refactoring: Outline code that looks for mem nodes #111750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 10, 2024

Conversation

vporpo
Copy link
Contributor

@vporpo vporpo commented Oct 9, 2024

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Oct 9, 2024

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/111750.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+10)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+28-11)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+19)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 134adc4b21ab12..050e119040c281 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -154,6 +154,16 @@ class MemDGNode final : public DGNode {
 /// Convenience builders for a MemDGNode interval.
 class MemDGNodeIntervalBuilder {
 public:
+  /// Scans the instruction chain after \p I until \p BeforeI, looking for
+  /// a mem dependency candidate and return the corresponding MemDGNode, or
+  /// nullptr if not found.
+  static MemDGNode *getMemDGNodeAfter(Instruction *I, Instruction *BeforeI,
+                                      const DependencyGraph &DAG);
+  /// Scans the instruction chain before \p I until \p AfterI, looking for
+  /// a mem dependency candidate and return the corresponding MemDGNode, or
+  /// nullptr if not found.
+  static MemDGNode *getMemDGNodeBefore(Instruction *I, Instruction *AfterI,
+                                       const DependencyGraph &DAG);
   /// Given \p Instrs it finds their closest mem nodes in the interval and
   /// returns the corresponding mem range. Note: BotN (or its neighboring mem
   /// node) is included in the range.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 82f253d4c63231..6266b4155dc253 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -32,23 +32,40 @@ void DGNode::dump() const {
 }
 #endif // NDEBUG
 
+MemDGNode *MemDGNodeIntervalBuilder::getMemDGNodeAfter(
+    Instruction *I, Instruction *BeforeI, const DependencyGraph &DAG) {
+  assert((I == BeforeI || I->comesBefore(BeforeI)) &&
+         "Expected I before BeforeI");
+  // Walk down the chain looking for a mem-dep candidate instruction.
+  while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
+    I = I->getNextNode();
+  if (!DGNode::isMemDepNodeCandidate(I))
+    return nullptr;
+  return cast<MemDGNode>(DAG.getNode(I));
+}
+
+MemDGNode *MemDGNodeIntervalBuilder::getMemDGNodeBefore(
+    Instruction *I, Instruction *AfterI, const DependencyGraph &DAG) {
+  assert((I == AfterI || AfterI->comesBefore(I)) && "Expected AfterI before I");
+  // Walk up the chain looking for a mem-dep candidate instruction.
+  while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
+    I = I->getPrevNode();
+  if (!DGNode::isMemDepNodeCandidate(I))
+    return nullptr;
+  return cast<MemDGNode>(DAG.getNode(I));
+}
+
 Interval<MemDGNode>
 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
                                DependencyGraph &DAG) {
-  // If top or bottom instructions are not mem-dep candidate nodes we need to
-  // walk down/up the chain and find the mem-dep ones.
-  Instruction *MemTopI = Instrs.top();
-  Instruction *MemBotI = Instrs.bottom();
-  while (!DGNode::isMemDepNodeCandidate(MemTopI) && MemTopI != MemBotI)
-    MemTopI = MemTopI->getNextNode();
-  while (!DGNode::isMemDepNodeCandidate(MemBotI) && MemBotI != MemTopI)
-    MemBotI = MemBotI->getPrevNode();
+  auto *TopMemN = getMemDGNodeAfter(Instrs.top(), Instrs.bottom(), DAG);
   // If we couldn't find a mem node in range TopN - BotN then it's empty.
-  if (!DGNode::isMemDepNodeCandidate(MemTopI))
+  if (TopMemN == nullptr)
     return {};
+  auto *BotMemN = getMemDGNodeBefore(Instrs.bottom(), Instrs.top(), DAG);
+  assert(BotMemN != nullptr && "TopMemN should be null too!");
   // Now that we have the mem-dep nodes, create and return the range.
-  return Interval<MemDGNode>(cast<MemDGNode>(DAG.getNode(MemTopI)),
-                             cast<MemDGNode>(DAG.getNode(MemBotI)));
+  return Interval<MemDGNode>(TopMemN, BotMemN);
 }
 
 DependencyGraph::DependencyType
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index e2f16919a5cddd..3d14da7b9358ec 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -305,6 +305,25 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
   auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
 
+  // Check getMemDGNodeAfter().
+  using B = sandboxir::MemDGNodeIntervalBuilder;
+  EXPECT_EQ(B::getMemDGNodeAfter(S0, S0, DAG), S0N);
+  EXPECT_EQ(B::getMemDGNodeAfter(S0, Ret, DAG), S0N);
+#ifndef NDEBUG
+  EXPECT_DEATH(B::getMemDGNodeAfter(S0, Add0, DAG), ".*before.*");
+#endif
+  EXPECT_EQ(B::getMemDGNodeAfter(Add0, Add1, DAG), S0N);
+  EXPECT_EQ(B::getMemDGNodeAfter(Add0, Add0, DAG), nullptr);
+
+  // Check getMemDGNodeBefore().
+  EXPECT_EQ(B::getMemDGNodeBefore(S1, S1, DAG), S1N);
+  EXPECT_EQ(B::getMemDGNodeBefore(S1, Add0, DAG), S1N);
+#ifndef NDEBUG
+  EXPECT_DEATH(B::getMemDGNodeBefore(S1, Ret, DAG), ".*before.*");
+#endif
+  EXPECT_EQ(B::getMemDGNodeBefore(Ret, Add0, DAG), S1N);
+  EXPECT_EQ(B::getMemDGNodeBefore(Ret, Ret, DAG), nullptr);
+
   // Check empty range.
   EXPECT_THAT(sandboxir::MemDGNodeIntervalBuilder::makeEmpty(),
               testing::ElementsAre());

@vporpo
Copy link
Contributor Author

vporpo commented Oct 9, 2024

Changed the function arguments from individual instructions to an instruction interval.

@vporpo vporpo merged commit 69c0067 into llvm:main Oct 10, 2024
8 checks passed
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants