Skip to content

[SandboxVec][InstrMaps] EraseInstr callback #123256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2025
Merged

[SandboxVec][InstrMaps] EraseInstr callback #123256

merged 1 commit into from
Jan 17, 2025

Conversation

vporpo
Copy link
Contributor

@vporpo vporpo commented Jan 16, 2025

This patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets updated when instructions get erased.

This patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets
updated when instructions get erased.
@llvmbot
Copy link
Member

llvmbot commented Jan 16, 2025

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

This patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets updated when instructions get erased.


Full diff: https://github.com/llvm/llvm-project/pull/123256.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h (+32)
  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h (+1-1)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (+3-3)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp (+10-1)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp (+3-3)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
index 2c4ba30f6fd052..999fbb0aad9405 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -13,9 +13,12 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Instruction.h"
 #include "llvm/SandboxIR/Value.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
+#include <algorithm>
 
 namespace llvm::sandboxir {
 
@@ -30,8 +33,37 @@ class InstrMaps {
   /// with the same lane, as they may be coming from vectorizing different
   /// original values.
   DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
+  Context &Ctx;
+  std::optional<Context::CallbackID> EraseInstrCB;
+
+private:
+  void notifyEraseInstr(Value *V) {
+    // We don't know if V is an original or a vector value.
+    auto It = OrigToVectorMap.find(V);
+    if (It != OrigToVectorMap.end()) {
+      // V is an original value.
+      // Remove it from VectorToOrigLaneMap.
+      Value *Vec = It->second;
+      VectorToOrigLaneMap[Vec].erase(V);
+      // Now erase V from OrigToVectorMap.
+      OrigToVectorMap.erase(It);
+    } else {
+      // V is a vector value.
+      // Go over the original values it came from and remove them from
+      // OrigToVectorMap.
+      for (auto [Orig, Lane] : VectorToOrigLaneMap[V])
+        OrigToVectorMap.erase(Orig);
+      // Now erase V from VectorToOrigLaneMap.
+      VectorToOrigLaneMap.erase(V);
+    }
+  }
 
 public:
+  InstrMaps(Context &Ctx) : Ctx(Ctx) {
+    EraseInstrCB = Ctx.registerEraseInstrCallback(
+        [this](Instruction *I) { notifyEraseInstr(I); });
+  }
+  ~InstrMaps() { Ctx.unregisterEraseInstrCallback(*EraseInstrCB); }
   /// \Returns the vector value that we got from vectorizing \p Orig, or
   /// nullptr if not found.
   Value *getVectorForOrig(Value *Orig) const {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 69cea3c4c7b53b..dd3012f7c9b556 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -28,7 +28,7 @@ class BottomUpVec final : public FunctionPass {
   std::unique_ptr<LegalityAnalysis> Legality;
   DenseSet<Instruction *> DeadInstrCandidates;
   /// Maps scalars to vectors.
-  InstrMaps IMaps;
+  std::unique_ptr<InstrMaps> IMaps;
 
   /// Creates and returns a vector instruction that replaces the instructions in
   /// \p Bndl. \p Operands are the already vectorized operands.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 6b2032be535603..b8e2697839a3c2 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -161,7 +161,7 @@ Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
   auto *VecI = CreateVectorInstr(Bndl, Operands);
   if (VecI != nullptr) {
     Change = true;
-    IMaps.registerVector(Bndl, VecI);
+    IMaps->registerVector(Bndl, VecI);
   }
   return VecI;
 }
@@ -315,10 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
 }
 
 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
-  IMaps.clear();
+  IMaps = std::make_unique<InstrMaps>(F.getContext());
   Legality = std::make_unique<LegalityAnalysis>(
       A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
-      F.getContext(), IMaps);
+      F.getContext(), *IMaps);
   Change = false;
   const auto &DL = F.getParent()->getDataLayout();
   unsigned VecRegBits =
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
index bcfb8db7f86741..11831b881ca7a8 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -53,7 +53,7 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
   auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
   [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
-  sandboxir::InstrMaps IMaps;
+  sandboxir::InstrMaps IMaps(Ctx);
   // Check with empty IMaps.
   EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
   EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
@@ -75,4 +75,13 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
 #ifndef NDEBUG
   EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
 #endif // NDEBUG
+  // Check callbacks: erase original instr.
+  Add0->eraseFromParent();
+  EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
+  EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
+  // Check callbacks: erase vector instr.
+  VAdd0->eraseFromParent();
+  EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
+  EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
 }
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 2e90462a633c17..069bfdba0a7cdb 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -111,7 +111,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
   auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
 
-  llvm::sandboxir::InstrMaps IMaps;
+  llvm::sandboxir::InstrMaps IMaps(Ctx);
   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   const auto &Result =
       Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
@@ -230,7 +230,7 @@ define void @foo(ptr %ptr) {
   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
 
-  llvm::sandboxir::InstrMaps IMaps;
+  llvm::sandboxir::InstrMaps IMaps(Ctx);
   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   {
     // Can vectorize St0,St1.
@@ -266,7 +266,7 @@ define void @foo() {
   };
 
   sandboxir::Context Ctx(C);
-  llvm::sandboxir::InstrMaps IMaps;
+  llvm::sandboxir::InstrMaps IMaps(Ctx);
   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   EXPECT_TRUE(
       Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));

@llvmbot
Copy link
Member

llvmbot commented Jan 16, 2025

@llvm/pr-subscribers-vectorizers

Author: vporpo (vporpo)

Changes

This patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets updated when instructions get erased.


Full diff: https://github.com/llvm/llvm-project/pull/123256.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h (+32)
  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h (+1-1)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (+3-3)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp (+10-1)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp (+3-3)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
index 2c4ba30f6fd052..999fbb0aad9405 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -13,9 +13,12 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Instruction.h"
 #include "llvm/SandboxIR/Value.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
+#include <algorithm>
 
 namespace llvm::sandboxir {
 
@@ -30,8 +33,37 @@ class InstrMaps {
   /// with the same lane, as they may be coming from vectorizing different
   /// original values.
   DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
+  Context &Ctx;
+  std::optional<Context::CallbackID> EraseInstrCB;
+
+private:
+  void notifyEraseInstr(Value *V) {
+    // We don't know if V is an original or a vector value.
+    auto It = OrigToVectorMap.find(V);
+    if (It != OrigToVectorMap.end()) {
+      // V is an original value.
+      // Remove it from VectorToOrigLaneMap.
+      Value *Vec = It->second;
+      VectorToOrigLaneMap[Vec].erase(V);
+      // Now erase V from OrigToVectorMap.
+      OrigToVectorMap.erase(It);
+    } else {
+      // V is a vector value.
+      // Go over the original values it came from and remove them from
+      // OrigToVectorMap.
+      for (auto [Orig, Lane] : VectorToOrigLaneMap[V])
+        OrigToVectorMap.erase(Orig);
+      // Now erase V from VectorToOrigLaneMap.
+      VectorToOrigLaneMap.erase(V);
+    }
+  }
 
 public:
+  InstrMaps(Context &Ctx) : Ctx(Ctx) {
+    EraseInstrCB = Ctx.registerEraseInstrCallback(
+        [this](Instruction *I) { notifyEraseInstr(I); });
+  }
+  ~InstrMaps() { Ctx.unregisterEraseInstrCallback(*EraseInstrCB); }
   /// \Returns the vector value that we got from vectorizing \p Orig, or
   /// nullptr if not found.
   Value *getVectorForOrig(Value *Orig) const {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 69cea3c4c7b53b..dd3012f7c9b556 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -28,7 +28,7 @@ class BottomUpVec final : public FunctionPass {
   std::unique_ptr<LegalityAnalysis> Legality;
   DenseSet<Instruction *> DeadInstrCandidates;
   /// Maps scalars to vectors.
-  InstrMaps IMaps;
+  std::unique_ptr<InstrMaps> IMaps;
 
   /// Creates and returns a vector instruction that replaces the instructions in
   /// \p Bndl. \p Operands are the already vectorized operands.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 6b2032be535603..b8e2697839a3c2 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -161,7 +161,7 @@ Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
   auto *VecI = CreateVectorInstr(Bndl, Operands);
   if (VecI != nullptr) {
     Change = true;
-    IMaps.registerVector(Bndl, VecI);
+    IMaps->registerVector(Bndl, VecI);
   }
   return VecI;
 }
@@ -315,10 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
 }
 
 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
-  IMaps.clear();
+  IMaps = std::make_unique<InstrMaps>(F.getContext());
   Legality = std::make_unique<LegalityAnalysis>(
       A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
-      F.getContext(), IMaps);
+      F.getContext(), *IMaps);
   Change = false;
   const auto &DL = F.getParent()->getDataLayout();
   unsigned VecRegBits =
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
index bcfb8db7f86741..11831b881ca7a8 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -53,7 +53,7 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
   auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
   [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
-  sandboxir::InstrMaps IMaps;
+  sandboxir::InstrMaps IMaps(Ctx);
   // Check with empty IMaps.
   EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
   EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
@@ -75,4 +75,13 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
 #ifndef NDEBUG
   EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
 #endif // NDEBUG
+  // Check callbacks: erase original instr.
+  Add0->eraseFromParent();
+  EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
+  EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
+  // Check callbacks: erase vector instr.
+  VAdd0->eraseFromParent();
+  EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
+  EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
 }
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 2e90462a633c17..069bfdba0a7cdb 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -111,7 +111,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
   auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
 
-  llvm::sandboxir::InstrMaps IMaps;
+  llvm::sandboxir::InstrMaps IMaps(Ctx);
   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   const auto &Result =
       Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
@@ -230,7 +230,7 @@ define void @foo(ptr %ptr) {
   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
 
-  llvm::sandboxir::InstrMaps IMaps;
+  llvm::sandboxir::InstrMaps IMaps(Ctx);
   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   {
     // Can vectorize St0,St1.
@@ -266,7 +266,7 @@ define void @foo() {
   };
 
   sandboxir::Context Ctx(C);
-  llvm::sandboxir::InstrMaps IMaps;
+  llvm::sandboxir::InstrMaps IMaps(Ctx);
   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   EXPECT_TRUE(
       Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));

@vporpo vporpo merged commit d6315af into llvm:main Jan 17, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants