Skip to content

[SandboxIR] Preserve the order of switch cases after revert. #115577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions llvm/include/llvm/SandboxIR/Tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,15 @@ class SwitchAddCase : public IRChangeBase {

class SwitchRemoveCase : public IRChangeBase {
SwitchInst *Switch;
ConstantInt *Val;
BasicBlock *Dest;
struct Case {
ConstantInt *Val;
BasicBlock *Dest;
};
SmallVector<Case> Cases;

public:
SwitchRemoveCase(SwitchInst *Switch, ConstantInt *Val, BasicBlock *Dest)
: Switch(Switch), Val(Val), Dest(Dest) {}
SwitchRemoveCase(SwitchInst *Switch);

void revert(Tracker &Tracker) final;
void accept() final {}
#ifndef NDEBUG
Expand Down
4 changes: 1 addition & 3 deletions llvm/lib/SandboxIR/Instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1131,9 +1131,7 @@ void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest) {
}

SwitchInst::CaseIt SwitchInst::removeCase(CaseIt It) {
auto &Case = *It;
Ctx.getTracker().emplaceIfTracking<SwitchRemoveCase>(
this, Case.getCaseValue(), Case.getCaseSuccessor());
Ctx.getTracker().emplaceIfTracking<SwitchRemoveCase>(this);

auto *LLVMSwitch = cast<llvm::SwitchInst>(Val);
unsigned CaseNum = It - case_begin();
Expand Down
19 changes: 18 additions & 1 deletion llvm/lib/SandboxIR/Tracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,24 @@ void CatchSwitchAddHandler::revert(Tracker &Tracker) {
LLVMCSI->removeHandler(LLVMCSI->handler_begin() + HandlerIdx);
}

void SwitchRemoveCase::revert(Tracker &Tracker) { Switch->addCase(Val, Dest); }
SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) {
for (const auto &C : Switch->cases())
Cases.push_back({C.getCaseValue(), C.getCaseSuccessor()});
}

void SwitchRemoveCase::revert(Tracker &Tracker) {
// SwitchInst::removeCase doesn't provide any guarantees about the order of
// cases after removal. In order to preserve the original ordering, we save
// all of them and, when reverting, clear them all then insert them in the
// desired order. This still relies on the fact that `addCase` will insert
// them at the end, but it is documented to invalidate `case_end()` so it's
// probably okay.
unsigned NumCases = Switch->getNumCases();
for (unsigned I = 0; I < NumCases; ++I)
Switch->removeCase(Switch->case_begin());
for (auto &Case : Cases)
Switch->addCase(Case.Val, Case.Dest);
}

#ifndef NDEBUG
void SwitchRemoveCase::dump() const {
Expand Down
82 changes: 82 additions & 0 deletions llvm/unittests/SandboxIR/TrackerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,88 @@ define void @foo(i32 %cond0, i32 %cond1) {
EXPECT_EQ(Switch->findCaseDest(BB1), One);
}

TEST_F(TrackerTest, SwitchInstPreservesSuccesorOrder) {
parseIR(C, R"IR(
define void @foo(i32 %cond0) {
entry:
switch i32 %cond0, label %default [ i32 0, label %bb0
i32 1, label %bb1
i32 2, label %bb2 ]
bb0:
ret void
bb1:
ret void
bb2:
ret void
default:
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
auto *LLVMEntry = getBasicBlockByName(LLVMF, "entry");

sandboxir::Context Ctx(C);
[[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
auto *Entry = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMEntry));
auto *BB0 = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
auto *BB1 = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "bb1")));
auto *BB2 = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "bb2")));
auto *Switch = cast<sandboxir::SwitchInst>(&*Entry->begin());

auto *DefaultDest = Switch->getDefaultDest();
auto *Zero = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 0);
auto *One = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 1);
auto *Two = sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 2);

// Check that we can properly revert a removeCase multiple positions apart
// from the end of the operand list.
Ctx.save();
Switch->removeCase(Switch->findCaseValue(Zero));
EXPECT_EQ(Switch->getNumCases(), 2u);
Ctx.revert();
EXPECT_EQ(Switch->getNumCases(), 3u);
EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
EXPECT_EQ(Switch->findCaseDest(BB1), One);
EXPECT_EQ(Switch->findCaseDest(BB2), Two);
EXPECT_EQ(Switch->getSuccessor(0), DefaultDest);
EXPECT_EQ(Switch->getSuccessor(1), BB0);
EXPECT_EQ(Switch->getSuccessor(2), BB1);
EXPECT_EQ(Switch->getSuccessor(3), BB2);

// Check that we can properly revert a removeCase of the last case.
Ctx.save();
Switch->removeCase(Switch->findCaseValue(Two));
EXPECT_EQ(Switch->getNumCases(), 2u);
Ctx.revert();
EXPECT_EQ(Switch->getNumCases(), 3u);
EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
EXPECT_EQ(Switch->findCaseDest(BB1), One);
EXPECT_EQ(Switch->findCaseDest(BB2), Two);
EXPECT_EQ(Switch->getSuccessor(0), DefaultDest);
EXPECT_EQ(Switch->getSuccessor(1), BB0);
EXPECT_EQ(Switch->getSuccessor(2), BB1);
EXPECT_EQ(Switch->getSuccessor(3), BB2);

// Check order is preserved after reverting multiple removeCase invocations.
Ctx.save();
Switch->removeCase(Switch->findCaseValue(One));
Switch->removeCase(Switch->findCaseValue(Zero));
Switch->removeCase(Switch->findCaseValue(Two));
EXPECT_EQ(Switch->getNumCases(), 0u);
Ctx.revert();
EXPECT_EQ(Switch->getNumCases(), 3u);
EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this would be a good place to add a comment that if these checks fail, then it might be that the implementation in LLVM changed and that we should update the way revert() works.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the implementation so it doesn't rely on the specific ordering changes introduced by removeCase()

EXPECT_EQ(Switch->findCaseDest(BB1), One);
EXPECT_EQ(Switch->findCaseDest(BB2), Two);
EXPECT_EQ(Switch->getSuccessor(0), DefaultDest);
EXPECT_EQ(Switch->getSuccessor(1), BB0);
EXPECT_EQ(Switch->getSuccessor(2), BB1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test where we remove case 0 of a 3-case switch, to check that we revert correctly in that case too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

EXPECT_EQ(Switch->getSuccessor(3), BB2);
}

TEST_F(TrackerTest, SelectInst) {
parseIR(C, R"IR(
define void @foo(i1 %c0, i8 %v0, i8 %v1) {
Expand Down
Loading