Skip to content

[SandboxVec][InstrMaps] EraseInstr callback #123256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/SandboxIR/Context.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>

namespace llvm::sandboxir {

Expand All @@ -30,8 +33,37 @@ class InstrMaps {
/// with the same lane, as they may be coming from vectorizing different
/// original values.
DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
Context &Ctx;
std::optional<Context::CallbackID> EraseInstrCB;

private:
void notifyEraseInstr(Value *V) {
// We don't know if V is an original or a vector value.
auto It = OrigToVectorMap.find(V);
if (It != OrigToVectorMap.end()) {
// V is an original value.
// Remove it from VectorToOrigLaneMap.
Value *Vec = It->second;
VectorToOrigLaneMap[Vec].erase(V);
// Now erase V from OrigToVectorMap.
OrigToVectorMap.erase(It);
} else {
// V is a vector value.
// Go over the original values it came from and remove them from
// OrigToVectorMap.
for (auto [Orig, Lane] : VectorToOrigLaneMap[V])
OrigToVectorMap.erase(Orig);
// Now erase V from VectorToOrigLaneMap.
VectorToOrigLaneMap.erase(V);
}
}

public:
InstrMaps(Context &Ctx) : Ctx(Ctx) {
EraseInstrCB = Ctx.registerEraseInstrCallback(
[this](Instruction *I) { notifyEraseInstr(I); });
}
~InstrMaps() { Ctx.unregisterEraseInstrCallback(*EraseInstrCB); }
/// \Returns the vector value that we got from vectorizing \p Orig, or
/// nullptr if not found.
Value *getVectorForOrig(Value *Orig) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class BottomUpVec final : public FunctionPass {
std::unique_ptr<LegalityAnalysis> Legality;
DenseSet<Instruction *> DeadInstrCandidates;
/// Maps scalars to vectors.
InstrMaps IMaps;
std::unique_ptr<InstrMaps> IMaps;

/// Creates and returns a vector instruction that replaces the instructions in
/// \p Bndl. \p Operands are the already vectorized operands.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
auto *VecI = CreateVectorInstr(Bndl, Operands);
if (VecI != nullptr) {
Change = true;
IMaps.registerVector(Bndl, VecI);
IMaps->registerVector(Bndl, VecI);
}
return VecI;
}
Expand Down Expand Up @@ -315,10 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
}

bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
IMaps.clear();
IMaps = std::make_unique<InstrMaps>(F.getContext());
Legality = std::make_unique<LegalityAnalysis>(
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
F.getContext(), IMaps);
F.getContext(), *IMaps);
Change = false;
const auto &DL = F.getParent()->getDataLayout();
unsigned VecRegBits =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::InstrMaps IMaps;
sandboxir::InstrMaps IMaps(Ctx);
// Check with empty IMaps.
EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
Expand All @@ -75,4 +75,13 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
#ifndef NDEBUG
EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
#endif // NDEBUG
// Check callbacks: erase original instr.
Add0->eraseFromParent();
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
// Check callbacks: erase vector instr.
VAdd0->eraseFromParent();
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);

llvm::sandboxir::InstrMaps IMaps;
llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
const auto &Result =
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
Expand Down Expand Up @@ -230,7 +230,7 @@ define void @foo(ptr %ptr) {
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);

llvm::sandboxir::InstrMaps IMaps;
llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
{
// Can vectorize St0,St1.
Expand Down Expand Up @@ -266,7 +266,7 @@ define void @foo() {
};

sandboxir::Context Ctx(C);
llvm::sandboxir::InstrMaps IMaps;
llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
Expand Down
Loading