Skip to content

Commit d6315af

Browse files
authored
[SandboxVec][InstrMaps] EraseInstr callback (#123256)
This patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets updated when instructions get erased.
1 parent 04383d6 commit d6315af

File tree

5 files changed

+49
-8
lines changed

5 files changed

+49
-8
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
#include "llvm/ADT/DenseMap.h"
1414
#include "llvm/ADT/SmallSet.h"
1515
#include "llvm/ADT/SmallVector.h"
16+
#include "llvm/SandboxIR/Context.h"
17+
#include "llvm/SandboxIR/Instruction.h"
1618
#include "llvm/SandboxIR/Value.h"
1719
#include "llvm/Support/Casting.h"
1820
#include "llvm/Support/raw_ostream.h"
21+
#include <algorithm>
1922

2023
namespace llvm::sandboxir {
2124

@@ -30,8 +33,37 @@ class InstrMaps {
3033
/// with the same lane, as they may be coming from vectorizing different
3134
/// original values.
3235
DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
36+
Context &Ctx;
37+
std::optional<Context::CallbackID> EraseInstrCB;
38+
39+
private:
40+
void notifyEraseInstr(Value *V) {
41+
// We don't know if V is an original or a vector value.
42+
auto It = OrigToVectorMap.find(V);
43+
if (It != OrigToVectorMap.end()) {
44+
// V is an original value.
45+
// Remove it from VectorToOrigLaneMap.
46+
Value *Vec = It->second;
47+
VectorToOrigLaneMap[Vec].erase(V);
48+
// Now erase V from OrigToVectorMap.
49+
OrigToVectorMap.erase(It);
50+
} else {
51+
// V is a vector value.
52+
// Go over the original values it came from and remove them from
53+
// OrigToVectorMap.
54+
for (auto [Orig, Lane] : VectorToOrigLaneMap[V])
55+
OrigToVectorMap.erase(Orig);
56+
// Now erase V from VectorToOrigLaneMap.
57+
VectorToOrigLaneMap.erase(V);
58+
}
59+
}
3360

3461
public:
62+
InstrMaps(Context &Ctx) : Ctx(Ctx) {
63+
EraseInstrCB = Ctx.registerEraseInstrCallback(
64+
[this](Instruction *I) { notifyEraseInstr(I); });
65+
}
66+
~InstrMaps() { Ctx.unregisterEraseInstrCallback(*EraseInstrCB); }
3567
/// \Returns the vector value that we got from vectorizing \p Orig, or
3668
/// nullptr if not found.
3769
Value *getVectorForOrig(Value *Orig) const {

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class BottomUpVec final : public FunctionPass {
2828
std::unique_ptr<LegalityAnalysis> Legality;
2929
DenseSet<Instruction *> DeadInstrCandidates;
3030
/// Maps scalars to vectors.
31-
InstrMaps IMaps;
31+
std::unique_ptr<InstrMaps> IMaps;
3232

3333
/// Creates and returns a vector instruction that replaces the instructions in
3434
/// \p Bndl. \p Operands are the already vectorized operands.

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
161161
auto *VecI = CreateVectorInstr(Bndl, Operands);
162162
if (VecI != nullptr) {
163163
Change = true;
164-
IMaps.registerVector(Bndl, VecI);
164+
IMaps->registerVector(Bndl, VecI);
165165
}
166166
return VecI;
167167
}
@@ -315,10 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
315315
}
316316

317317
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
318-
IMaps.clear();
318+
IMaps = std::make_unique<InstrMaps>(F.getContext());
319319
Legality = std::make_unique<LegalityAnalysis>(
320320
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
321-
F.getContext(), IMaps);
321+
F.getContext(), *IMaps);
322322
Change = false;
323323
const auto &DL = F.getParent()->getDataLayout();
324324
unsigned VecRegBits =

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
5353
auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
5454
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
5555

56-
sandboxir::InstrMaps IMaps;
56+
sandboxir::InstrMaps IMaps(Ctx);
5757
// Check with empty IMaps.
5858
EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
5959
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
@@ -75,4 +75,13 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
7575
#ifndef NDEBUG
7676
EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
7777
#endif // NDEBUG
78+
// Check callbacks: erase original instr.
79+
Add0->eraseFromParent();
80+
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
81+
EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
82+
EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
83+
// Check callbacks: erase vector instr.
84+
VAdd0->eraseFromParent();
85+
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
86+
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
7887
}

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
111111
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
112112
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
113113

114-
llvm::sandboxir::InstrMaps IMaps;
114+
llvm::sandboxir::InstrMaps IMaps(Ctx);
115115
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
116116
const auto &Result =
117117
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
@@ -230,7 +230,7 @@ define void @foo(ptr %ptr) {
230230
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
231231
auto *St1 = cast<sandboxir::StoreInst>(&*It++);
232232

233-
llvm::sandboxir::InstrMaps IMaps;
233+
llvm::sandboxir::InstrMaps IMaps(Ctx);
234234
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
235235
{
236236
// Can vectorize St0,St1.
@@ -266,7 +266,7 @@ define void @foo() {
266266
};
267267

268268
sandboxir::Context Ctx(C);
269-
llvm::sandboxir::InstrMaps IMaps;
269+
llvm::sandboxir::InstrMaps IMaps(Ctx);
270270
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
271271
EXPECT_TRUE(
272272
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));

0 commit comments

Comments
 (0)