Skip to content

Commit 0931a2a

Browse files
authored
[SandboxIR] SetUse callback (#126985)
This patch implements a callback mechanism similar to the existing ones, but for getting notified whenever a Use edge gets updated. This is going to be used in a follow up patch by the Dependency Graph.
1 parent db1e15a commit 0931a2a

File tree

5 files changed

+111
-9
lines changed

5 files changed

+111
-9
lines changed

llvm/include/llvm/SandboxIR/Context.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class BBIterator;
2626
class Constant;
2727
class Module;
2828
class Value;
29+
class Use;
2930

3031
class Context {
3132
public:
@@ -37,6 +38,8 @@ class Context {
3738
// destination BB and an iterator pointing to the insertion position.
3839
using MoveInstrCallback =
3940
std::function<void(Instruction *, const BBIterator &)>;
41+
// A SetUseCallback receives the Use that is about to get its source set.
42+
using SetUseCallback = std::function<void(const Use &, Value *)>;
4043

4144
/// An ID for a registered callback. Used for deregistration. A dedicated type
4245
/// is employed so as to keep IDs opaque to the end user; only Context should
@@ -98,6 +101,9 @@ class Context {
98101
/// Callbacks called when an IR instruction is about to get moved. Keys are
99102
/// used as IDs for deregistration.
100103
MapVector<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
104+
/// Callbacks called when a Use gets its source set. Keys are used as IDs for
105+
/// deregistration.
106+
MapVector<CallbackID, SetUseCallback> SetUseCallbacks;
101107

102108
/// A counter used for assigning callback IDs during registration. The same
103109
/// counter is used for all kinds of callbacks so we can detect mismatched
@@ -129,6 +135,10 @@ class Context {
129135
void runEraseInstrCallbacks(Instruction *I);
130136
void runCreateInstrCallbacks(Instruction *I);
131137
void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);
138+
void runSetUseCallbacks(const Use &U, Value *NewSrc);
139+
140+
friend class User; // For runSetUseCallbacks().
141+
friend class Value; // For runSetUseCallbacks().
132142

133143
// Friends for getOrCreateConstant().
134144
#define DEF_CONST(ID, CLASS) friend class CLASS;
@@ -281,7 +291,10 @@ class Context {
281291
CallbackID registerMoveInstrCallback(MoveInstrCallback CB);
282292
void unregisterMoveInstrCallback(CallbackID ID);
283293

284-
// TODO: Add callbacks for instructions inserted/removed if needed.
294+
/// Register a callback that gets called when a Use gets set.
295+
/// \Returns a callback ID for later deregistration.
296+
CallbackID registerSetUseCallback(SetUseCallback CB);
297+
void unregisterSetUseCallback(CallbackID ID);
285298
};
286299

287300
} // namespace sandboxir

llvm/lib/SandboxIR/Context.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,11 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
687687
CBEntry.second(I, WhereIt);
688688
}
689689

690+
void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) {
691+
for (auto &CBEntry : SetUseCallbacks)
692+
CBEntry.second(U, NewSrc);
693+
}
694+
690695
// An arbitrary limit, to check for accidental misuse. We expect a small number
691696
// of callbacks to be registered at a time, but we can increase this number if
692697
// we discover we needed more.
@@ -732,4 +737,17 @@ void Context::unregisterMoveInstrCallback(CallbackID ID) {
732737
"Callback ID not found in MoveInstrCallbacks during deregistration");
733738
}
734739

740+
Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) {
741+
assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks &&
742+
"SetUseCallbacks size limit exceeded");
743+
CallbackID ID{NextCallbackID++};
744+
SetUseCallbacks[ID] = CB;
745+
return ID;
746+
}
747+
void Context::unregisterSetUseCallback(CallbackID ID) {
748+
[[maybe_unused]] bool Erased = SetUseCallbacks.erase(ID);
749+
assert(Erased &&
750+
"Callback ID not found in SetUseCallbacks during deregistration");
751+
}
752+
735753
} // namespace llvm::sandboxir

llvm/lib/SandboxIR/User.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,20 @@ bool User::classof(const Value *From) {
9090

9191
void User::setOperand(unsigned OperandIdx, Value *Operand) {
9292
assert(isa<llvm::User>(Val) && "No operands!");
93-
Ctx.getTracker().emplaceIfTracking<UseSet>(getOperandUse(OperandIdx));
93+
const auto &U = getOperandUse(OperandIdx);
94+
Ctx.getTracker().emplaceIfTracking<UseSet>(U);
95+
Ctx.runSetUseCallbacks(U, Operand);
9496
// We are delegating to llvm::User::setOperand().
9597
cast<llvm::User>(Val)->setOperand(OperandIdx, Operand->Val);
9698
}
9799

98100
bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
99101
auto &Tracker = Ctx.getTracker();
100-
if (Tracker.isTracking()) {
101-
for (auto OpIdx : seq<unsigned>(0, getNumOperands())) {
102-
auto Use = getOperandUse(OpIdx);
103-
if (Use.get() == FromV)
102+
for (auto OpIdx : seq<unsigned>(0, getNumOperands())) {
103+
auto Use = getOperandUse(OpIdx);
104+
if (Use.get() == FromV) {
105+
Ctx.runSetUseCallbacks(Use, ToV);
106+
if (Tracker.isTracking())
104107
Tracker.emplaceIfTracking<UseSet>(Use);
105108
}
106109
}

llvm/lib/SandboxIR/Value.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,15 @@ void Value::replaceUsesWithIf(
5151
llvm::Value *OtherVal = OtherV->Val;
5252
// We are delegating RUWIf to LLVM IR's RUWIf.
5353
Val->replaceUsesWithIf(
54-
OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
54+
OtherVal, [&ShouldReplace, this, OtherV](llvm::Use &LLVMUse) -> bool {
5555
User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
5656
if (DstU == nullptr)
5757
return false;
5858
Use UseToReplace(&LLVMUse, DstU, Ctx);
5959
if (!ShouldReplace(UseToReplace))
6060
return false;
6161
Ctx.getTracker().emplaceIfTracking<UseSet>(UseToReplace);
62+
Ctx.runSetUseCallbacks(UseToReplace, OtherV);
6263
return true;
6364
});
6465
}
@@ -67,8 +68,9 @@ void Value::replaceAllUsesWith(Value *Other) {
6768
assert(getType() == Other->getType() &&
6869
"Replacing with Value of different type!");
6970
auto &Tracker = Ctx.getTracker();
70-
if (Tracker.isTracking()) {
71-
for (auto Use : uses())
71+
for (auto Use : uses()) {
72+
Ctx.runSetUseCallbacks(Use, Other);
73+
if (Tracker.isTracking())
7274
Tracker.track(std::make_unique<UseSet>(Use));
7375
}
7476
// We are delegating RAUW to LLVM IR's RAUW.

llvm/unittests/SandboxIR/SandboxIRTest.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6081,6 +6081,72 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
60816081
EXPECT_THAT(Moved, testing::IsEmpty());
60826082
}
60836083

6084+
// Check callbacks when we set a Use.
6085+
TEST_F(SandboxIRTest, SetUseCallbacks) {
6086+
parseIR(C, R"IR(
6087+
define void @foo(i8 %v0, i8 %v1) {
6088+
%add0 = add i8 %v0, %v1
6089+
%add1 = add i8 %add0, %v1
6090+
ret void
6091+
}
6092+
)IR");
6093+
llvm::Function *LLVMF = &*M->getFunction("foo");
6094+
sandboxir::Context Ctx(C);
6095+
auto *F = Ctx.createFunction(LLVMF);
6096+
auto *Arg0 = F->getArg(0);
6097+
auto *BB = &*F->begin();
6098+
auto It = BB->begin();
6099+
auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
6100+
auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
6101+
6102+
SmallVector<std::pair<sandboxir::Use, sandboxir::Value *>> UsesSet;
6103+
auto Id = Ctx.registerSetUseCallback(
6104+
[&UsesSet](sandboxir::Use U, sandboxir::Value *NewSrc) {
6105+
UsesSet.push_back({U, NewSrc});
6106+
});
6107+
6108+
// Now change %add1 operand to not use %add0.
6109+
Add1->setOperand(0, Arg0);
6110+
EXPECT_EQ(UsesSet.size(), 1u);
6111+
EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
6112+
EXPECT_EQ(UsesSet[0].second, Arg0);
6113+
// Restore to previous state.
6114+
Add1->setOperand(0, Add0);
6115+
UsesSet.clear();
6116+
6117+
// RAUW
6118+
Add0->replaceAllUsesWith(Arg0);
6119+
EXPECT_EQ(UsesSet.size(), 1u);
6120+
EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
6121+
EXPECT_EQ(UsesSet[0].second, Arg0);
6122+
// Restore to previous state.
6123+
Add1->setOperand(0, Add0);
6124+
UsesSet.clear();
6125+
6126+
// RUWIf
6127+
Add0->replaceUsesWithIf(Arg0, [](const auto &U) { return true; });
6128+
EXPECT_EQ(UsesSet.size(), 1u);
6129+
EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
6130+
EXPECT_EQ(UsesSet[0].second, Arg0);
6131+
// Restore to previous state.
6132+
Add1->setOperand(0, Add0);
6133+
UsesSet.clear();
6134+
6135+
// RUOW
6136+
Add1->replaceUsesOfWith(Add0, Arg0);
6137+
EXPECT_EQ(UsesSet.size(), 1u);
6138+
EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
6139+
EXPECT_EQ(UsesSet[0].second, Arg0);
6140+
// Restore to previous state.
6141+
Add1->setOperand(0, Add0);
6142+
UsesSet.clear();
6143+
6144+
// Check unregister.
6145+
Ctx.unregisterSetUseCallback(Id);
6146+
Add0->replaceAllUsesWith(Arg0);
6147+
EXPECT_TRUE(UsesSet.empty());
6148+
}
6149+
60846150
TEST_F(SandboxIRTest, FunctionObjectAlreadyExists) {
60856151
parseIR(C, R"IR(
60866152
define void @foo() {

0 commit comments

Comments
 (0)