Skip to content

[SandboxIR][Tracker] Track eraseFromParent() #99431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ class Instruction : public sandboxir::User {
/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
virtual SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const = 0;
friend class EraseFromParent; // For getLLVMInstrs().

public:
static const char *getOpcodeName(Opcode Opc);
Expand Down Expand Up @@ -658,6 +659,7 @@ class Context {
friend void Instruction::eraseFromParent(); // For detach().
/// Take ownership of VPtr and store it in `LLVMValueToValueMap`.
Value *registerValue(std::unique_ptr<Value> &&VPtr);
friend class EraseFromParent; // For registerValue().
/// This is the actual function that creates sandboxir values for \p V,
/// and among others handles all instruction types.
Value *getOrCreateValueInternal(llvm::Value *V, llvm::User *U = nullptr);
Expand All @@ -682,7 +684,7 @@ class Context {
friend class BasicBlock; // For getOrCreateValue().

public:
Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx) {}
Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx), IRTracker(*this) {}

Tracker &getTracker() { return IRTracker; }
/// Convenience function for `getTracker().save()`
Expand Down
40 changes: 39 additions & 1 deletion llvm/include/llvm/SandboxIR/Tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#ifndef LLVM_SANDBOXIR_TRACKER_H
#define LLVM_SANDBOXIR_TRACKER_H

#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
Expand Down Expand Up @@ -99,6 +100,41 @@ class UseSet : public IRChangeBase {
#endif
};

class EraseFromParent : public IRChangeBase {
/// Contains all the data we need to restore an "erased" (i.e., detached)
/// instruction: the instruction itself and its operands in order.
struct InstrAndOperands {
/// The operands that got dropped.
SmallVector<llvm::Value *> Operands;
/// The instruction that got "erased".
llvm::Instruction *LLVMI;
};
/// The instruction data is in reverse program order, which helps create the
/// original program order during revert().
SmallVector<InstrAndOperands> InstrData;
/// This is either the next Instruction in the stream, or the parent
/// BasicBlock if at the end of the BB.
PointerUnion<llvm::Instruction *, llvm::BasicBlock *> NextLLVMIOrBB;
/// We take ownership of the "erased" instruction.
std::unique_ptr<sandboxir::Value> ErasedIPtr;

public:
EraseFromParent(std::unique_ptr<sandboxir::Value> &&IPtr, Tracker &Tracker);
void revert() final;
void accept() final;
#ifndef NDEBUG
void dump(raw_ostream &OS) const final {
dumpCommon(OS);
OS << "EraseFromParent";
}
LLVM_DUMP_METHOD void dump() const final;
friend raw_ostream &operator<<(raw_ostream &OS, const EraseFromParent &C) {
C.dump(OS);
return OS;
}
#endif
};

/// The tracker collects all the change objects and implements the main API for
/// saving / reverting / accepting.
class Tracker {
Expand All @@ -116,6 +152,7 @@ class Tracker {
#endif
/// The current state of the tracker.
TrackerState State = TrackerState::Disabled;
Context &Ctx;

public:
#ifndef NDEBUG
Expand All @@ -124,8 +161,9 @@ class Tracker {
bool InMiddleOfCreatingChange = false;
#endif // NDEBUG

Tracker() = default;
explicit Tracker(Context &Ctx) : Ctx(Ctx) {}
~Tracker();
Context &getContext() const { return Ctx; }
/// Record \p Change and take ownership. This is the main function used to
/// track Sandbox IR changes.
void track(std::unique_ptr<IRChangeBase> &&Change);
Expand Down
25 changes: 20 additions & 5 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,11 +341,26 @@ void Instruction::removeFromParent() {
void Instruction::eraseFromParent() {
assert(users().empty() && "Still connected to users, can't erase!");
std::unique_ptr<Value> Detached = Ctx.detach(this);
// We don't have Tracking yet, so just erase the LLVM IR instructions.
// Erase in reverse to avoid erasing nstructions with attached uses.
auto Instrs = getLLVMInstrs();
for (llvm::Instruction *I : reverse(Instrs))
I->eraseFromParent();
auto LLVMInstrs = getLLVMInstrs();

auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking()) {
Tracker.track(
std::make_unique<EraseFromParent>(std::move(Detached), Tracker));
// We don't actually delete the IR instruction, because then it would be
// impossible to bring it back from the dead at the same memory location.
// Instead we remove it from its BB and track its current location.
for (llvm::Instruction *I : LLVMInstrs)
I->removeFromParent();
// TODO: Multi-instructions need special treatment because some of the
// references are internal to the instruction.
for (llvm::Instruction *I : LLVMInstrs)
I->dropAllReferences();
} else {
// Erase in reverse to avoid erasing nstructions with attached uses.
for (llvm::Instruction *I : reverse(LLVMInstrs))
I->eraseFromParent();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imo putting an else here rather than the return is a bit clearer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) {
Expand Down
59 changes: 59 additions & 0 deletions llvm/lib/SandboxIR/Tracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,65 @@ Tracker::~Tracker() {
assert(Changes.empty() && "You must accept or revert changes!");
}

EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr,
Tracker &Tracker)
: IRChangeBase(Tracker), ErasedIPtr(std::move(ErasedIPtr)) {
auto *I = cast<Instruction>(this->ErasedIPtr.get());
auto LLVMInstrs = I->getLLVMInstrs();
// Iterate in reverse program order.
for (auto *LLVMI : reverse(LLVMInstrs)) {
SmallVector<llvm::Value *> Operands;
Operands.reserve(LLVMI->getNumOperands());
for (auto [OpNum, Use] : enumerate(LLVMI->operands()))
Operands.push_back(Use.get());
InstrData.push_back({Operands, LLVMI});
}
assert(is_sorted(InstrData,
[](const auto &D0, const auto &D1) {
return D0.LLVMI->comesBefore(D1.LLVMI);
}) &&
"Expected reverse program order!");
auto *BotLLVMI = cast<llvm::Instruction>(I->Val);
if (BotLLVMI->getNextNode() != nullptr)
NextLLVMIOrBB = BotLLVMI->getNextNode();
else
NextLLVMIOrBB = BotLLVMI->getParent();
}

void EraseFromParent::accept() {
for (const auto &IData : InstrData)
IData.LLVMI->deleteValue();
}

void EraseFromParent::revert() {
// Place the bottom-most instruction first.
auto [Operands, BotLLVMI] = InstrData[0];
if (auto *NextLLVMI = NextLLVMIOrBB.dyn_cast<llvm::Instruction *>()) {
BotLLVMI->insertBefore(NextLLVMI);
} else {
auto *LLVMBB = NextLLVMIOrBB.get<llvm::BasicBlock *>();
BotLLVMI->insertInto(LLVMBB, LLVMBB->end());
}
for (auto [OpNum, Op] : enumerate(Operands))
BotLLVMI->setOperand(OpNum, Op);

// Go over the rest of the instructions and stack them on top.
for (auto [Operands, LLVMI] : drop_begin(InstrData)) {
LLVMI->insertBefore(BotLLVMI);
for (auto [OpNum, Op] : enumerate(Operands))
LLVMI->setOperand(OpNum, Op);
BotLLVMI = LLVMI;
}
Parent.getContext().registerValue(std::move(ErasedIPtr));
}

#ifndef NDEBUG
void EraseFromParent::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif

void Tracker::track(std::unique_ptr<IRChangeBase> &&Change) {
assert(State == TrackerState::Record && "The tracker should be tracking!");
Changes.push_back(std::move(Change));
Expand Down
52 changes: 52 additions & 0 deletions llvm/unittests/SandboxIR/TrackerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,55 @@ define void @foo(ptr %ptr) {
Ctx.accept();
EXPECT_EQ(St0->getOperand(0), Ld1);
}

// TODO: Test multi-instruction patterns.
TEST_F(TrackerTest, EraseFromParent) {
parseIR(C, R"IR(
define void @foo(i32 %v1) {
%add0 = add i32 %v1, %v1
%add1 = add i32 %add0, %v1
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto *F = Ctx.createFunction(&LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
sandboxir::Instruction *Add0 = &*It++;
sandboxir::Instruction *Add1 = &*It++;
sandboxir::Instruction *Ret = &*It++;

Ctx.save();
// Check erase.
Add1->eraseFromParent();
It = BB->begin();
EXPECT_EQ(&*It++, Add0);
EXPECT_EQ(&*It++, Ret);
EXPECT_EQ(It, BB->end());
EXPECT_EQ(Add0->getNumUses(), 0u);

// Check revert().
Ctx.revert();
It = BB->begin();
EXPECT_EQ(&*It++, Add0);
EXPECT_EQ(&*It++, Add1);
EXPECT_EQ(&*It++, Ret);
EXPECT_EQ(It, BB->end());
EXPECT_EQ(Add1->getOperand(0), Add0);

// Same for the last instruction in the block.
Ctx.save();
Ret->eraseFromParent();
It = BB->begin();
EXPECT_EQ(&*It++, Add0);
EXPECT_EQ(&*It++, Add1);
EXPECT_EQ(It, BB->end());
Ctx.revert();
It = BB->begin();
EXPECT_EQ(&*It++, Add0);
EXPECT_EQ(&*It++, Add1);
EXPECT_EQ(&*It++, Ret);
EXPECT_EQ(It, BB->end());
}
Loading