Skip to content

[SandboxVec][DAG] Register callback for erase instr #116742

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 20, 2024
Merged

Conversation

vporpo
Copy link
Contributor

@vporpo vporpo commented Nov 19, 2024

This patch adds the callback registration logic in the DAG's constructor and the corresponding deregistration logic in the destructor. It also implements the code that makes sure that SchedBundle and DGNodes can be safely destroyed in any order.

This patch adds the callback registration logic in the DAG's constructor
and the corresponding deregistration logic in the destructor.
It also implements the code that makes sure that SchedBundle and DGNodes
can be safely destroyed in any order.
@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

This patch adds the callback registration logic in the DAG's constructor and the corresponding deregistration logic in the destructor. It also implements the code that makes sure that SchedBundle and DGNodes can be safely destroyed in any order.


Full diff: https://github.com/llvm/llvm-project/pull/116742.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+12-1)
  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h (+4)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+7)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+28)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 765b65c4971bed..68a2daca1403df 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -117,7 +117,7 @@ class DGNode {
     assert(!isMemDepNodeCandidate(I) && "Expected Non-Mem instruction, ");
   }
   DGNode(const DGNode &Other) = delete;
-  virtual ~DGNode() = default;
+  virtual ~DGNode();
   /// \Returns the number of unscheduled successors.
   unsigned getNumUnscheduledSuccs() const { return UnscheduledSuccs; }
   void decrUnscheduledSuccs() {
@@ -292,6 +292,7 @@ class DependencyGraph {
 
   Context *Ctx = nullptr;
   std::optional<Context::CallbackID> CreateInstrCB;
+  std::optional<Context::CallbackID> EraseInstrCB;
 
   std::unique_ptr<BatchAAResults> BatchAA;
 
@@ -334,6 +335,12 @@ class DependencyGraph {
     // TODO: Update the dependencies for the new node.
     // TODO: Update the MemDGNode chain to include the new node if needed.
   }
+  /// Called by the callbacks when instruction \p I is about to get deleted.
+  void notifyEraseInstr(Instruction *I) {
+    InstrToNodeMap.erase(I);
+    // TODO: Update the dependencies.
+    // TODO: Update the MemDGNode chain to remove the node if needed.
+  }
 
 public:
   /// This constructor also registers callbacks.
@@ -341,10 +348,14 @@ class DependencyGraph {
       : Ctx(&Ctx), BatchAA(std::make_unique<BatchAAResults>(AA)) {
     CreateInstrCB = Ctx.registerCreateInstrCallback(
         [this](Instruction *I) { notifyCreateInstr(I); });
+    EraseInstrCB = Ctx.registerEraseInstrCallback(
+        [this](Instruction *I) { notifyEraseInstr(I); });
   }
   ~DependencyGraph() {
     if (CreateInstrCB)
       Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
+    if (EraseInstrCB)
+      Ctx->unregisterEraseInstrCallback(*EraseInstrCB);
   }
 
   DGNode *getNode(Instruction *I) const {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
index 022fd71df67dc6..3959f84c601e04 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -69,6 +69,10 @@ class SchedBundle {
 private:
   ContainerTy Nodes;
 
+  /// Called by the DGNode destructor to avoid accessing freed memory.
+  void eraseFromBundle(DGNode *N) { Nodes.erase(find(Nodes, N)); }
+  friend DGNode::~DGNode(); // For eraseFromBundle().
+
 public:
   SchedBundle() = default;
   SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 6217c9fecf45dd..4b0e12c28f07b7 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -10,6 +10,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/SandboxIR/Instruction.h"
 #include "llvm/SandboxIR/Utils.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
 
 namespace llvm::sandboxir {
 
@@ -58,6 +59,12 @@ bool PredIterator::operator==(const PredIterator &Other) const {
   return OpIt == Other.OpIt && MemIt == Other.MemIt;
 }
 
+DGNode::~DGNode() {
+  if (SB == nullptr)
+    return;
+  SB->eraseFromBundle(this);
+}
+
 #ifndef NDEBUG
 void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
   OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n";
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 206f6c5b4c1359..e6bb4b4684d262 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -830,3 +830,31 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
   // TODO: Check the dependencies to/from NewSN after they land.
   // TODO: Check the MemDGNode chain.
 }
+
+TEST_F(DependencyGraphTest, EraseInstrCallback) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+  // Check erase instruction callback.
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({S1, S3});
+  S2->eraseFromParent();
+  auto *DeletedN = DAG.getNodeOrNull(S2);
+  EXPECT_TRUE(DeletedN == nullptr);
+  // TODO: Check the dependencies to/from NewSN after they land.
+  // TODO: Check the MemDGNode chain.
+}

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-vectorizers

Author: vporpo (vporpo)

Changes

This patch adds the callback registration logic in the DAG's constructor and the corresponding deregistration logic in the destructor. It also implements the code that makes sure that SchedBundle and DGNodes can be safely destroyed in any order.


Full diff: https://github.com/llvm/llvm-project/pull/116742.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+12-1)
  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h (+4)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+7)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+28)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 765b65c4971bed..68a2daca1403df 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -117,7 +117,7 @@ class DGNode {
     assert(!isMemDepNodeCandidate(I) && "Expected Non-Mem instruction, ");
   }
   DGNode(const DGNode &Other) = delete;
-  virtual ~DGNode() = default;
+  virtual ~DGNode();
   /// \Returns the number of unscheduled successors.
   unsigned getNumUnscheduledSuccs() const { return UnscheduledSuccs; }
   void decrUnscheduledSuccs() {
@@ -292,6 +292,7 @@ class DependencyGraph {
 
   Context *Ctx = nullptr;
   std::optional<Context::CallbackID> CreateInstrCB;
+  std::optional<Context::CallbackID> EraseInstrCB;
 
   std::unique_ptr<BatchAAResults> BatchAA;
 
@@ -334,6 +335,12 @@ class DependencyGraph {
     // TODO: Update the dependencies for the new node.
     // TODO: Update the MemDGNode chain to include the new node if needed.
   }
+  /// Called by the callbacks when instruction \p I is about to get deleted.
+  void notifyEraseInstr(Instruction *I) {
+    InstrToNodeMap.erase(I);
+    // TODO: Update the dependencies.
+    // TODO: Update the MemDGNode chain to remove the node if needed.
+  }
 
 public:
   /// This constructor also registers callbacks.
@@ -341,10 +348,14 @@ class DependencyGraph {
       : Ctx(&Ctx), BatchAA(std::make_unique<BatchAAResults>(AA)) {
     CreateInstrCB = Ctx.registerCreateInstrCallback(
         [this](Instruction *I) { notifyCreateInstr(I); });
+    EraseInstrCB = Ctx.registerEraseInstrCallback(
+        [this](Instruction *I) { notifyEraseInstr(I); });
   }
   ~DependencyGraph() {
     if (CreateInstrCB)
       Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
+    if (EraseInstrCB)
+      Ctx->unregisterEraseInstrCallback(*EraseInstrCB);
   }
 
   DGNode *getNode(Instruction *I) const {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
index 022fd71df67dc6..3959f84c601e04 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -69,6 +69,10 @@ class SchedBundle {
 private:
   ContainerTy Nodes;
 
+  /// Called by the DGNode destructor to avoid accessing freed memory.
+  void eraseFromBundle(DGNode *N) { Nodes.erase(find(Nodes, N)); }
+  friend DGNode::~DGNode(); // For eraseFromBundle().
+
 public:
   SchedBundle() = default;
   SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 6217c9fecf45dd..4b0e12c28f07b7 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -10,6 +10,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/SandboxIR/Instruction.h"
 #include "llvm/SandboxIR/Utils.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
 
 namespace llvm::sandboxir {
 
@@ -58,6 +59,12 @@ bool PredIterator::operator==(const PredIterator &Other) const {
   return OpIt == Other.OpIt && MemIt == Other.MemIt;
 }
 
+DGNode::~DGNode() {
+  if (SB == nullptr)
+    return;
+  SB->eraseFromBundle(this);
+}
+
 #ifndef NDEBUG
 void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
   OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n";
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 206f6c5b4c1359..e6bb4b4684d262 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -830,3 +830,31 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
   // TODO: Check the dependencies to/from NewSN after they land.
   // TODO: Check the MemDGNode chain.
 }
+
+TEST_F(DependencyGraphTest, EraseInstrCallback) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+  // Check erase instruction callback.
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({S1, S3});
+  S2->eraseFromParent();
+  auto *DeletedN = DAG.getNodeOrNull(S2);
+  EXPECT_TRUE(DeletedN == nullptr);
+  // TODO: Check the dependencies to/from NewSN after they land.
+  // TODO: Check the MemDGNode chain.
+}

Copy link
Collaborator

@slackito slackito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me.

@vporpo vporpo merged commit 6e48214 into llvm:main Nov 20, 2024
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants