Skip to content

Commit 4df71ab

Browse files
authored
[SandboxIR] Add callbacks for instruction insert/remove/move ops (#112965)
1 parent a9c417c commit 4df71ab

File tree

4 files changed

+233
-7
lines changed

4 files changed

+233
-7
lines changed

llvm/include/llvm/SandboxIR/Context.h

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,39 @@
99
#ifndef LLVM_SANDBOXIR_CONTEXT_H
1010
#define LLVM_SANDBOXIR_CONTEXT_H
1111

12+
#include "llvm/ADT/DenseMap.h"
13+
#include "llvm/ADT/MapVector.h"
14+
#include "llvm/ADT/SmallVector.h"
1215
#include "llvm/IR/LLVMContext.h"
1316
#include "llvm/SandboxIR/Tracker.h"
1417
#include "llvm/SandboxIR/Type.h"
1518

19+
#include <cstdint>
20+
1621
namespace llvm::sandboxir {
1722

18-
class Module;
19-
class Value;
2023
class Argument;
24+
class BBIterator;
2125
class Constant;
26+
class Module;
27+
class Value;
2228

2329
class Context {
30+
public:
31+
// An EraseInstrCallback receives the instruction about to be erased.
32+
using EraseInstrCallback = std::function<void(Instruction *)>;
33+
// A CreateInstrCallback receives the instruction that was just created.
34+
using CreateInstrCallback = std::function<void(Instruction *)>;
35+
// A MoveInstrCallback receives the instruction about to be moved and an
36+
// iterator pointing to the insertion position in the destination BB.
37+
using MoveInstrCallback =
38+
std::function<void(Instruction *, const BBIterator &)>;
39+
40+
/// An ID for a registered callback. Used for deregistration. Using a 64-bit
41+
/// integer so we don't have to worry about the unlikely case of overflowing
42+
/// a 32-bit counter.
43+
using CallbackID = uint64_t;
44+
2445
protected:
2546
LLVMContext &LLVMCtx;
2647
friend class Type; // For LLVMCtx.
@@ -48,6 +69,21 @@ class Context {
4869
/// Type objects.
4970
DenseMap<llvm::Type *, std::unique_ptr<Type, TypeDeleter>> LLVMTypeToTypeMap;
5071

72+
/// Callbacks called when an IR instruction is about to get erased. Keys are
73+
/// used as IDs for deregistration.
74+
MapVector<CallbackID, EraseInstrCallback> EraseInstrCallbacks;
75+
/// Callbacks called right after an IR instruction is created. Keys are
76+
/// used as IDs for deregistration.
77+
MapVector<CallbackID, CreateInstrCallback> CreateInstrCallbacks;
78+
/// Callbacks called when an IR instruction is about to get moved. Keys are
79+
/// used as IDs for deregistration.
80+
MapVector<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
81+
82+
/// A counter used for assigning callback IDs during registration. The same
83+
/// counter is used for all kinds of callbacks so we can detect mismatched
84+
/// registration/deregistration.
85+
CallbackID NextCallbackID = 0;
86+
5187
/// Remove \p V from the maps and returns the unique_ptr.
5288
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
5389
/// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively
@@ -70,6 +106,10 @@ class Context {
70106
Constant *getOrCreateConstant(llvm::Constant *LLVMC);
71107
friend class Utils; // For getMemoryBase
72108

109+
void runEraseInstrCallbacks(Instruction *I);
110+
void runCreateInstrCallbacks(Instruction *I);
111+
void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);
112+
73113
// Friends for getOrCreateConstant().
74114
#define DEF_CONST(ID, CLASS) friend class CLASS;
75115
#include "llvm/SandboxIR/Values.def"
@@ -198,6 +238,28 @@ class Context {
198238

199239
/// \Returns the number of values registered with Context.
200240
size_t getNumValues() const { return LLVMValueToValueMap.size(); }
241+
242+
/// Register a callback that gets called when a SandboxIR instruction is about
243+
/// to be removed from its parent. Note that this will also be called when
244+
/// reverting the creation of an instruction.
245+
/// \Returns a callback ID for later deregistration.
246+
CallbackID registerEraseInstrCallback(EraseInstrCallback CB);
247+
void unregisterEraseInstrCallback(CallbackID ID);
248+
249+
/// Register a callback that gets called right after a SandboxIR instruction
250+
/// is created. Note that this will also be called when reverting the removal
251+
/// of an instruction.
252+
/// \Returns a callback ID for later deregistration.
253+
CallbackID registerCreateInstrCallback(CreateInstrCallback CB);
254+
void unregisterCreateInstrCallback(CallbackID ID);
255+
256+
/// Register a callback that gets called when a SandboxIR instruction is about
257+
/// to be moved. Note that this will also be called when reverting a move.
258+
/// \Returns a callback ID for later deregistration.
259+
CallbackID registerMoveInstrCallback(MoveInstrCallback CB);
260+
void unregisterMoveInstrCallback(CallbackID ID);
261+
262+
// TODO: Add callbacks for instructions inserted/removed if needed.
201263
};
202264

203265
} // namespace llvm::sandboxir

llvm/lib/SandboxIR/Context.cpp

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,20 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
3535
assert(VPtr->getSubclassID() != Value::ClassID::User &&
3636
"Can't register a user!");
3737

38+
Value *V = VPtr.get();
39+
[[maybe_unused]] auto Pair =
40+
LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
41+
assert(Pair.second && "Already exists!");
42+
3843
// Track creation of instructions.
3944
// Please note that we don't allow the creation of detached instructions,
4045
// meaning that the instructions need to be inserted into a block upon
4146
// creation. This is why the tracker class combines creation and insertion.
42-
if (auto *I = dyn_cast<Instruction>(VPtr.get()))
47+
if (auto *I = dyn_cast<Instruction>(V)) {
4348
getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
49+
runCreateInstrCallbacks(I);
50+
}
4451

45-
Value *V = VPtr.get();
46-
[[maybe_unused]] auto Pair =
47-
LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
48-
assert(Pair.second && "Already exists!");
4952
return V;
5053
}
5154

@@ -660,4 +663,64 @@ Module *Context::createModule(llvm::Module *LLVMM) {
660663
return M;
661664
}
662665

666+
/// Invokes every callback stored in EraseInstrCallbacks on \p I, following the
/// container's iteration order.
void Context::runEraseInstrCallbacks(Instruction *I) {
  for (const auto &IDAndCallback : EraseInstrCallbacks) {
    const EraseInstrCallback &Callback = IDAndCallback.second;
    Callback(I);
  }
}
670+
671+
/// Invokes every callback stored in CreateInstrCallbacks on \p I, following
/// the container's iteration order.
void Context::runCreateInstrCallbacks(Instruction *I) {
  for (const auto &IDAndCallback : CreateInstrCallbacks) {
    const CreateInstrCallback &Callback = IDAndCallback.second;
    Callback(I);
  }
}
675+
676+
/// Invokes every callback stored in MoveInstrCallbacks, handing each one the
/// instruction \p I and the position iterator \p WhereIt. Callbacks run in the
/// container's iteration order.
void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
  for (const auto &IDAndCallback : MoveInstrCallbacks) {
    const MoveInstrCallback &Callback = IDAndCallback.second;
    Callback(I, WhereIt);
  }
}
680+
681+
// An arbitrary limit, to check for accidental misuse. We expect a small number
// of callbacks to be registered at a time, but we can increase this number if
// we discover we need more. The limit is unsigned because it is only ever
// compared against the (unsigned) size() of the callback containers.
static constexpr unsigned MaxRegisteredCallbacks = 16;
685+
686+
/// Stores \p CB in EraseInstrCallbacks under a freshly allocated ID so that
/// runEraseInstrCallbacks() will invoke it.
/// \Returns the ID to pass to unregisterEraseInstrCallback() later.
Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
  // The check runs before the insertion, so there must be room for one more
  // entry; `<=` here would let the container grow one past the limit.
  assert(EraseInstrCallbacks.size() < MaxRegisteredCallbacks &&
         "EraseInstrCallbacks size limit exceeded");
  CallbackID ID = NextCallbackID++;
  // CB is a by-value sink parameter: move it into the map instead of copying.
  EraseInstrCallbacks[ID] = std::move(CB);
  return ID;
}
693+
/// Removes the erase callback stored under \p ID. Asserts if
/// EraseInstrCallbacks holds no entry for that ID.
void Context::unregisterEraseInstrCallback(CallbackID ID) {
  [[maybe_unused]] bool WasRegistered = EraseInstrCallbacks.erase(ID);
  assert(WasRegistered &&
         "Callback ID not found in EraseInstrCallbacks during deregistration");
}
698+
699+
/// Stores \p CB in CreateInstrCallbacks under a freshly allocated ID so that
/// runCreateInstrCallbacks() will invoke it.
/// \Returns the ID to pass to unregisterCreateInstrCallback() later.
Context::CallbackID
Context::registerCreateInstrCallback(CreateInstrCallback CB) {
  // The check runs before the insertion, so there must be room for one more
  // entry; `<=` here would let the container grow one past the limit.
  assert(CreateInstrCallbacks.size() < MaxRegisteredCallbacks &&
         "CreateInstrCallbacks size limit exceeded");
  CallbackID ID = NextCallbackID++;
  // CB is a by-value sink parameter: move it into the map instead of copying.
  CreateInstrCallbacks[ID] = std::move(CB);
  return ID;
}
707+
/// Removes the create callback stored under \p ID. Asserts if
/// CreateInstrCallbacks holds no entry for that ID.
void Context::unregisterCreateInstrCallback(CallbackID ID) {
  [[maybe_unused]] bool WasRegistered = CreateInstrCallbacks.erase(ID);
  assert(WasRegistered &&
         "Callback ID not found in CreateInstrCallbacks during deregistration");
}
712+
713+
/// Stores \p CB in MoveInstrCallbacks under a freshly allocated ID so that
/// runMoveInstrCallbacks() will invoke it.
/// \Returns the ID to pass to unregisterMoveInstrCallback() later.
Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
  // The check runs before the insertion, so there must be room for one more
  // entry; `<=` here would let the container grow one past the limit.
  assert(MoveInstrCallbacks.size() < MaxRegisteredCallbacks &&
         "MoveInstrCallbacks size limit exceeded");
  CallbackID ID = NextCallbackID++;
  // CB is a by-value sink parameter: move it into the map instead of copying.
  MoveInstrCallbacks[ID] = std::move(CB);
  return ID;
}
720+
/// Removes the move callback stored under \p ID. Asserts if
/// MoveInstrCallbacks holds no entry for that ID.
void Context::unregisterMoveInstrCallback(CallbackID ID) {
  [[maybe_unused]] bool WasRegistered = MoveInstrCallbacks.erase(ID);
  assert(WasRegistered &&
         "Callback ID not found in MoveInstrCallbacks during deregistration");
}
725+
663726
} // namespace llvm::sandboxir

llvm/lib/SandboxIR/Instruction.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ void Instruction::removeFromParent() {
7373

7474
void Instruction::eraseFromParent() {
7575
assert(users().empty() && "Still connected to users, can't erase!");
76+
77+
Ctx.runEraseInstrCallbacks(this);
7678
std::unique_ptr<Value> Detached = Ctx.detach(this);
7779
auto LLVMInstrs = getLLVMInstrs();
7880

@@ -100,6 +102,7 @@ void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) {
100102
// Destination is same as origin, nothing to do.
101103
return;
102104

105+
Ctx.runMoveInstrCallbacks(this, WhereIt);
103106
Ctx.getTracker().emplaceIfTracking<MoveInstr>(this);
104107

105108
auto *LLVMBB = cast<llvm::BasicBlock>(BB.Val);

llvm/unittests/SandboxIR/SandboxIRTest.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/SandboxIR/Value.h"
2323
#include "llvm/Support/SourceMgr.h"
2424
#include "gmock/gmock-matchers.h"
25+
#include "gmock/gmock-more-matchers.h"
2526
#include "gtest/gtest.h"
2627

2728
using namespace llvm;
@@ -5962,3 +5963,100 @@ TEST_F(SandboxIRTest, CheckClassof) {
59625963
EXPECT_NE(&sandboxir::CLASS::classof, &sandboxir::Instruction::classof);
59635964
#include "llvm/SandboxIR/Values.def"
59645965
}
5966+
5967+
// Exercises the Context instruction callbacks: registers create, erase and
// move callbacks that record their arguments, then checks what was recorded
// after creating, moving and erasing instructions, after Ctx.revert(), and
// finally after all callbacks have been deregistered.
TEST_F(SandboxIRTest, InstructionCallbacks) {
5968+
parseIR(C, R"IR(
5969+
define void @foo(ptr %ptr, i8 %val) {
5970+
ret void
5971+
}
5972+
)IR");
5973+
Function &LLVMF = *M->getFunction("foo");
5974+
sandboxir::Context Ctx(C);
5975+
5976+
auto &F = *Ctx.createFunction(&LLVMF);
5977+
auto &BB = *F.begin();
5978+
sandboxir::Argument *Ptr = F.getArg(0);
5979+
sandboxir::Argument *Val = F.getArg(1);
5980+
sandboxir::Instruction *Ret = &BB.front();
5981+
5982+
// Records every instruction passed to the create callback, in call order.
SmallVector<sandboxir::Instruction *> Inserted;
5983+
auto InsertCbId = Ctx.registerCreateInstrCallback(
5984+
[&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); });
5985+
5986+
// Records every instruction passed to the erase callback, in call order.
SmallVector<sandboxir::Instruction *> Removed;
5987+
auto RemoveCbId = Ctx.registerEraseInstrCallback(
5988+
[&Removed](sandboxir::Instruction *I) { Removed.push_back(I); });
5989+
5990+
// Keep the moved instruction and the instruction pointed by the Where
5991+
// iterator so we can check both callback arguments work as expected.
5992+
SmallVector<std::pair<sandboxir::Instruction *, sandboxir::Instruction *>>
5993+
Moved;
5994+
auto MoveCbId = Ctx.registerMoveInstrCallback(
5995+
[&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) {
5996+
// Use a nullptr to signal "move to end" to keep it simple. We only
5997+
// have a basic block in this test case anyway.
5998+
if (Where == Where.getNodeParent()->end())
5999+
Moved.push_back(std::make_pair(I, nullptr));
6000+
else
6001+
Moved.push_back(std::make_pair(I, &*Where));
6002+
});
6003+
6004+
// Two more insertion callbacks, to check that they're called in registration
6005+
// order.
6006+
SmallVector<int> Order;
6007+
auto CheckOrderInsertCbId1 = Ctx.registerCreateInstrCallback(
6008+
[&Order](sandboxir::Instruction *I) { Order.push_back(1); });
6009+
6010+
auto CheckOrderInsertCbId2 = Ctx.registerCreateInstrCallback(
6011+
[&Order](sandboxir::Instruction *I) { Order.push_back(2); });
6012+
6013+
Ctx.save();
6014+
// Creating an instruction must trigger only the create callbacks.
auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
6015+
Ret->getIterator(), Ctx);
6016+
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
6017+
EXPECT_THAT(Removed, testing::IsEmpty());
6018+
EXPECT_THAT(Moved, testing::IsEmpty());
6019+
EXPECT_THAT(Order, testing::ElementsAre(1, 2));
6020+
6021+
// Moving an instruction must trigger only the move callback.
Ret->moveBefore(NewI);
6022+
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
6023+
EXPECT_THAT(Removed, testing::IsEmpty());
6024+
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
6025+
6026+
// Erasing an instruction must trigger only the erase callback.
Ret->eraseFromParent();
6027+
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
6028+
EXPECT_THAT(Removed, testing::ElementsAre(Ret));
6029+
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
6030+
6031+
NewI->eraseFromParent();
6032+
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
6033+
EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI));
6034+
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
6035+
6036+
// Check that after revert the callbacks have been called for the inverse
6037+
// operations of the changes made so far.
6038+
Ctx.revert();
6039+
EXPECT_THAT(Inserted, testing::ElementsAre(NewI, NewI, Ret));
6040+
EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI));
6041+
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI),
6042+
std::make_pair(Ret, nullptr)));
6043+
EXPECT_THAT(Order, testing::ElementsAre(1, 2, 1, 2, 1, 2));
6044+
6045+
// Check that deregistration works. Do an operation of each type after
6046+
// deregistering callbacks and check.
6047+
Inserted.clear();
6048+
Removed.clear();
6049+
Moved.clear();
6050+
Ctx.unregisterCreateInstrCallback(InsertCbId);
6051+
Ctx.unregisterEraseInstrCallback(RemoveCbId);
6052+
Ctx.unregisterMoveInstrCallback(MoveCbId);
6053+
Ctx.unregisterCreateInstrCallback(CheckOrderInsertCbId1);
6054+
Ctx.unregisterCreateInstrCallback(CheckOrderInsertCbId2);
6055+
auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
6056+
Ret->getIterator(), Ctx);
6057+
Ret->moveBefore(NewI2);
6058+
Ret->eraseFromParent();
6059+
EXPECT_THAT(Inserted, testing::IsEmpty());
6060+
EXPECT_THAT(Removed, testing::IsEmpty());
6061+
EXPECT_THAT(Moved, testing::IsEmpty());
6062+
}

0 commit comments

Comments
 (0)