Skip to content

[SandboxIR] SetUse callback #126985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion llvm/include/llvm/SandboxIR/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class BBIterator;
class Constant;
class Module;
class Value;
class Use;

class Context {
public:
Expand All @@ -37,6 +38,8 @@ class Context {
// destination BB and an iterator pointing to the insertion position.
using MoveInstrCallback =
std::function<void(Instruction *, const BBIterator &)>;
// A SetUseCallback receives the Use that is about to get its source set.
using SetUseCallback = std::function<void(const Use &, Value *)>;

/// An ID for a registered callback. Used for deregistration. A dedicated type
/// is employed so as to keep IDs opaque to the end user; only Context should
Expand Down Expand Up @@ -98,6 +101,9 @@ class Context {
/// Callbacks called when an IR instruction is about to get moved. Keys are
/// used as IDs for deregistration.
MapVector<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
/// Callbacks called when a Use gets its source set. Keys are used as IDs for
/// deregistration.
MapVector<CallbackID, SetUseCallback> SetUseCallbacks;

/// A counter used for assigning callback IDs during registration. The same
/// counter is used for all kinds of callbacks so we can detect mismatched
Expand Down Expand Up @@ -129,6 +135,10 @@ class Context {
void runEraseInstrCallbacks(Instruction *I);
void runCreateInstrCallbacks(Instruction *I);
void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);
void runSetUseCallbacks(const Use &U, Value *NewSrc);

friend class User; // For runSetUseCallbacks().
friend class Value; // For runSetUseCallbacks().

// Friends for getOrCreateConstant().
#define DEF_CONST(ID, CLASS) friend class CLASS;
Expand Down Expand Up @@ -281,7 +291,10 @@ class Context {
CallbackID registerMoveInstrCallback(MoveInstrCallback CB);
void unregisterMoveInstrCallback(CallbackID ID);

// TODO: Add callbacks for instructions inserted/removed if needed.
/// Register a callback that gets called when a Use gets set.
/// \Returns a callback ID for later deregistration.
CallbackID registerSetUseCallback(SetUseCallback CB);
void unregisterSetUseCallback(CallbackID ID);
};

} // namespace sandboxir
Expand Down
18 changes: 18 additions & 0 deletions llvm/lib/SandboxIR/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,11 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
CBEntry.second(I, WhereIt);
}

void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) {
for (auto &CBEntry : SetUseCallbacks)
CBEntry.second(U, NewSrc);
}

// An arbitrary limit, to check for accidental misuse. We expect a small number
// of callbacks to be registered at a time, but we can increase this number if
// we discover we needed more.
Expand Down Expand Up @@ -732,4 +737,17 @@ void Context::unregisterMoveInstrCallback(CallbackID ID) {
"Callback ID not found in MoveInstrCallbacks during deregistration");
}

Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) {
assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks &&
"SetUseCallbacks size limit exceeded");
CallbackID ID{NextCallbackID++};
SetUseCallbacks[ID] = CB;
return ID;
}
void Context::unregisterSetUseCallback(CallbackID ID) {
[[maybe_unused]] bool Erased = SetUseCallbacks.erase(ID);
assert(Erased &&
"Callback ID not found in SetUseCallbacks during deregistration");
}

} // namespace llvm::sandboxir
13 changes: 8 additions & 5 deletions llvm/lib/SandboxIR/User.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,20 @@ bool User::classof(const Value *From) {

void User::setOperand(unsigned OperandIdx, Value *Operand) {
assert(isa<llvm::User>(Val) && "No operands!");
Ctx.getTracker().emplaceIfTracking<UseSet>(getOperandUse(OperandIdx));
const auto &U = getOperandUse(OperandIdx);
Ctx.getTracker().emplaceIfTracking<UseSet>(U);
Ctx.runSetUseCallbacks(U, Operand);
// We are delegating to llvm::User::setOperand().
cast<llvm::User>(Val)->setOperand(OperandIdx, Operand->Val);
}

bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking()) {
for (auto OpIdx : seq<unsigned>(0, getNumOperands())) {
auto Use = getOperandUse(OpIdx);
if (Use.get() == FromV)
for (auto OpIdx : seq<unsigned>(0, getNumOperands())) {
auto Use = getOperandUse(OpIdx);
if (Use.get() == FromV) {
Ctx.runSetUseCallbacks(Use, ToV);
if (Tracker.isTracking())
Tracker.emplaceIfTracking<UseSet>(Use);
}
}
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/SandboxIR/Value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,15 @@ void Value::replaceUsesWithIf(
llvm::Value *OtherVal = OtherV->Val;
// We are delegating RUWIf to LLVM IR's RUWIf.
Val->replaceUsesWithIf(
OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
OtherVal, [&ShouldReplace, this, OtherV](llvm::Use &LLVMUse) -> bool {
User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
if (DstU == nullptr)
return false;
Use UseToReplace(&LLVMUse, DstU, Ctx);
if (!ShouldReplace(UseToReplace))
return false;
Ctx.getTracker().emplaceIfTracking<UseSet>(UseToReplace);
Ctx.runSetUseCallbacks(UseToReplace, OtherV);
return true;
});
}
Expand All @@ -67,8 +68,9 @@ void Value::replaceAllUsesWith(Value *Other) {
assert(getType() == Other->getType() &&
"Replacing with Value of different type!");
auto &Tracker = Ctx.getTracker();
if (Tracker.isTracking()) {
for (auto Use : uses())
for (auto Use : uses()) {
Ctx.runSetUseCallbacks(Use, Other);
if (Tracker.isTracking())
Tracker.track(std::make_unique<UseSet>(Use));
}
// We are delegating RAUW to LLVM IR's RAUW.
Expand Down
66 changes: 66 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6081,6 +6081,72 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
EXPECT_THAT(Moved, testing::IsEmpty());
}

// Check callbacks when we set a Use.
TEST_F(SandboxIRTest, SetUseCallbacks) {
parseIR(C, R"IR(
define void @foo(i8 %v0, i8 %v1) {
%add0 = add i8 %v0, %v1
%add1 = add i8 %add0, %v1
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *Arg0 = F->getArg(0);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);

SmallVector<std::pair<sandboxir::Use, sandboxir::Value *>> UsesSet;
auto Id = Ctx.registerSetUseCallback(
[&UsesSet](sandboxir::Use U, sandboxir::Value *NewSrc) {
UsesSet.push_back({U, NewSrc});
});

// Now change %add1 operand to not use %add0.
Add1->setOperand(0, Arg0);
EXPECT_EQ(UsesSet.size(), 1u);
EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
EXPECT_EQ(UsesSet[0].second, Arg0);
// Restore to previous state.
Add1->setOperand(0, Add0);
UsesSet.clear();

// RAUW
Add0->replaceAllUsesWith(Arg0);
EXPECT_EQ(UsesSet.size(), 1u);
EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
EXPECT_EQ(UsesSet[0].second, Arg0);
// Restore to previous state.
Add1->setOperand(0, Add0);
UsesSet.clear();

// RUWIf
Add0->replaceUsesWithIf(Arg0, [](const auto &U) { return true; });
EXPECT_EQ(UsesSet.size(), 1u);
EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
EXPECT_EQ(UsesSet[0].second, Arg0);
// Restore to previous state.
Add1->setOperand(0, Add0);
UsesSet.clear();

// RUOW
Add1->replaceUsesOfWith(Add0, Arg0);
EXPECT_EQ(UsesSet.size(), 1u);
EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
EXPECT_EQ(UsesSet[0].second, Arg0);
// Restore to previous state.
Add1->setOperand(0, Add0);
UsesSet.clear();

// Check unregister.
Ctx.unregisterSetUseCallback(Id);
Add0->replaceAllUsesWith(Arg0);
EXPECT_TRUE(UsesSet.empty());
}

TEST_F(SandboxIRTest, FunctionObjectAlreadyExists) {
parseIR(C, R"IR(
define void @foo() {
Expand Down