-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SandboxVec][InstrMaps] EraseInstr callback #123256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets updated when instructions get erased.
@llvm/pr-subscribers-llvm-transforms Author: vporpo (vporpo) ChangesThis patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets updated when instructions get erased. Full diff: https://github.com/llvm/llvm-project/pull/123256.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
index 2c4ba30f6fd052..999fbb0aad9405 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -13,9 +13,12 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
namespace llvm::sandboxir {
@@ -30,8 +33,37 @@ class InstrMaps {
/// with the same lane, as they may be coming from vectorizing different
/// original values.
DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
+ Context &Ctx;
+ std::optional<Context::CallbackID> EraseInstrCB;
+
+private:
+ void notifyEraseInstr(Value *V) {
+ // We don't know if V is an original or a vector value.
+ auto It = OrigToVectorMap.find(V);
+ if (It != OrigToVectorMap.end()) {
+ // V is an original value.
+ // Remove it from VectorToOrigLaneMap.
+ Value *Vec = It->second;
+ VectorToOrigLaneMap[Vec].erase(V);
+ // Now erase V from OrigToVectorMap.
+ OrigToVectorMap.erase(It);
+ } else {
+ // V is a vector value.
+ // Go over the original values it came from and remove them from
+ // OrigToVectorMap.
+ for (auto [Orig, Lane] : VectorToOrigLaneMap[V])
+ OrigToVectorMap.erase(Orig);
+ // Now erase V from VectorToOrigLaneMap.
+ VectorToOrigLaneMap.erase(V);
+ }
+ }
public:
+ InstrMaps(Context &Ctx) : Ctx(Ctx) {
+ EraseInstrCB = Ctx.registerEraseInstrCallback(
+ [this](Instruction *I) { notifyEraseInstr(I); });
+ }
+ ~InstrMaps() { Ctx.unregisterEraseInstrCallback(*EraseInstrCB); }
/// \Returns the vector value that we got from vectorizing \p Orig, or
/// nullptr if not found.
Value *getVectorForOrig(Value *Orig) const {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 69cea3c4c7b53b..dd3012f7c9b556 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -28,7 +28,7 @@ class BottomUpVec final : public FunctionPass {
std::unique_ptr<LegalityAnalysis> Legality;
DenseSet<Instruction *> DeadInstrCandidates;
/// Maps scalars to vectors.
- InstrMaps IMaps;
+ std::unique_ptr<InstrMaps> IMaps;
/// Creates and returns a vector instruction that replaces the instructions in
/// \p Bndl. \p Operands are the already vectorized operands.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 6b2032be535603..b8e2697839a3c2 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -161,7 +161,7 @@ Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
auto *VecI = CreateVectorInstr(Bndl, Operands);
if (VecI != nullptr) {
Change = true;
- IMaps.registerVector(Bndl, VecI);
+ IMaps->registerVector(Bndl, VecI);
}
return VecI;
}
@@ -315,10 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
}
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
- IMaps.clear();
+ IMaps = std::make_unique<InstrMaps>(F.getContext());
Legality = std::make_unique<LegalityAnalysis>(
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
- F.getContext(), IMaps);
+ F.getContext(), *IMaps);
Change = false;
const auto &DL = F.getParent()->getDataLayout();
unsigned VecRegBits =
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
index bcfb8db7f86741..11831b881ca7a8 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -53,7 +53,7 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
- sandboxir::InstrMaps IMaps;
+ sandboxir::InstrMaps IMaps(Ctx);
// Check with empty IMaps.
EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
@@ -75,4 +75,13 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
#ifndef NDEBUG
EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
#endif // NDEBUG
+ // Check callbacks: erase original instr.
+ Add0->eraseFromParent();
+ EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
+ EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
+ EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
+ // Check callbacks: erase vector instr.
+ VAdd0->eraseFromParent();
+ EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
+ EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 2e90462a633c17..069bfdba0a7cdb 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -111,7 +111,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
- llvm::sandboxir::InstrMaps IMaps;
+ llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
const auto &Result =
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
@@ -230,7 +230,7 @@ define void @foo(ptr %ptr) {
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);
- llvm::sandboxir::InstrMaps IMaps;
+ llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
{
// Can vectorize St0,St1.
@@ -266,7 +266,7 @@ define void @foo() {
};
sandboxir::Context Ctx(C);
- llvm::sandboxir::InstrMaps IMaps;
+ llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
|
@llvm/pr-subscribers-vectorizers Author: vporpo (vporpo) ChangesThis patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets updated when instructions get erased. Full diff: https://github.com/llvm/llvm-project/pull/123256.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
index 2c4ba30f6fd052..999fbb0aad9405 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -13,9 +13,12 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
namespace llvm::sandboxir {
@@ -30,8 +33,37 @@ class InstrMaps {
/// with the same lane, as they may be coming from vectorizing different
/// original values.
DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
+ Context &Ctx;
+ std::optional<Context::CallbackID> EraseInstrCB;
+
+private:
+ void notifyEraseInstr(Value *V) {
+ // We don't know if V is an original or a vector value.
+ auto It = OrigToVectorMap.find(V);
+ if (It != OrigToVectorMap.end()) {
+ // V is an original value.
+ // Remove it from VectorToOrigLaneMap.
+ Value *Vec = It->second;
+ VectorToOrigLaneMap[Vec].erase(V);
+ // Now erase V from OrigToVectorMap.
+ OrigToVectorMap.erase(It);
+ } else {
+ // V is a vector value.
+ // Go over the original values it came from and remove them from
+ // OrigToVectorMap.
+ for (auto [Orig, Lane] : VectorToOrigLaneMap[V])
+ OrigToVectorMap.erase(Orig);
+ // Now erase V from VectorToOrigLaneMap.
+ VectorToOrigLaneMap.erase(V);
+ }
+ }
public:
+ InstrMaps(Context &Ctx) : Ctx(Ctx) {
+ EraseInstrCB = Ctx.registerEraseInstrCallback(
+ [this](Instruction *I) { notifyEraseInstr(I); });
+ }
+ ~InstrMaps() { Ctx.unregisterEraseInstrCallback(*EraseInstrCB); }
/// \Returns the vector value that we got from vectorizing \p Orig, or
/// nullptr if not found.
Value *getVectorForOrig(Value *Orig) const {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 69cea3c4c7b53b..dd3012f7c9b556 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -28,7 +28,7 @@ class BottomUpVec final : public FunctionPass {
std::unique_ptr<LegalityAnalysis> Legality;
DenseSet<Instruction *> DeadInstrCandidates;
/// Maps scalars to vectors.
- InstrMaps IMaps;
+ std::unique_ptr<InstrMaps> IMaps;
/// Creates and returns a vector instruction that replaces the instructions in
/// \p Bndl. \p Operands are the already vectorized operands.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 6b2032be535603..b8e2697839a3c2 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -161,7 +161,7 @@ Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
auto *VecI = CreateVectorInstr(Bndl, Operands);
if (VecI != nullptr) {
Change = true;
- IMaps.registerVector(Bndl, VecI);
+ IMaps->registerVector(Bndl, VecI);
}
return VecI;
}
@@ -315,10 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
}
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
- IMaps.clear();
+ IMaps = std::make_unique<InstrMaps>(F.getContext());
Legality = std::make_unique<LegalityAnalysis>(
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
- F.getContext(), IMaps);
+ F.getContext(), *IMaps);
Change = false;
const auto &DL = F.getParent()->getDataLayout();
unsigned VecRegBits =
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
index bcfb8db7f86741..11831b881ca7a8 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -53,7 +53,7 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
- sandboxir::InstrMaps IMaps;
+ sandboxir::InstrMaps IMaps(Ctx);
// Check with empty IMaps.
EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
@@ -75,4 +75,13 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
#ifndef NDEBUG
EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
#endif // NDEBUG
+ // Check callbacks: erase original instr.
+ Add0->eraseFromParent();
+ EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
+ EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
+ EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
+ // Check callbacks: erase vector instr.
+ VAdd0->eraseFromParent();
+ EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
+ EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 2e90462a633c17..069bfdba0a7cdb 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -111,7 +111,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
- llvm::sandboxir::InstrMaps IMaps;
+ llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
const auto &Result =
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
@@ -230,7 +230,7 @@ define void @foo(ptr %ptr) {
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);
- llvm::sandboxir::InstrMaps IMaps;
+ llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
{
// Can vectorize St0,St1.
@@ -266,7 +266,7 @@ define void @foo() {
};
sandboxir::Context Ctx(C);
- llvm::sandboxir::InstrMaps IMaps;
+ llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
|
This patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets updated when instructions get erased.