Skip to content

Commit 6e8c970

Browse files
authored
[SandboxIR] Implement CatchSwitchInst (#104652)
This patch implements sandboxir::CatchSwitchInst mirroring llvm::CatchSwitchInst.
1 parent 93e0f31 commit 6e8c970

File tree

7 files changed

+353
-0
lines changed

7 files changed

+353
-0
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class CastInst;
131131
class PtrToIntInst;
132132
class BitCastInst;
133133
class AllocaInst;
134+
class CatchSwitchInst;
134135
class SwitchInst;
135136
class UnaryOperator;
136137
class BinaryOperator;
@@ -254,6 +255,7 @@ class Value {
254255
friend class InvokeInst; // For getting `Val`.
255256
friend class CallBrInst; // For getting `Val`.
256257
friend class GetElementPtrInst; // For getting `Val`.
258+
friend class CatchSwitchInst; // For getting `Val`.
257259
friend class SwitchInst; // For getting `Val`.
258260
friend class UnaryOperator; // For getting `Val`.
259261
friend class BinaryOperator; // For getting `Val`.
@@ -263,6 +265,7 @@ class Value {
263265
friend class CastInst; // For getting `Val`.
264266
friend class PHINode; // For getting `Val`.
265267
friend class UnreachableInst; // For getting `Val`.
268+
friend class CatchSwitchAddHandler; // For `Val`.
266269

267270
/// All values point to the context.
268271
Context &Ctx;
@@ -674,6 +677,7 @@ class Instruction : public sandboxir::User {
674677
friend class InvokeInst; // For getTopmostLLVMInstruction().
675678
friend class CallBrInst; // For getTopmostLLVMInstruction().
676679
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
680+
friend class CatchSwitchInst; // For getTopmostLLVMInstruction().
677681
friend class SwitchInst; // For getTopmostLLVMInstruction().
678682
friend class UnaryOperator; // For getTopmostLLVMInstruction().
679683
friend class BinaryOperator; // For getTopmostLLVMInstruction().
@@ -1480,6 +1484,97 @@ class GetElementPtrInst final
14801484
// TODO: Add missing member functions.
14811485
};
14821486

1487+
class CatchSwitchInst
1488+
: public SingleLLVMInstructionImpl<llvm::CatchSwitchInst> {
1489+
public:
1490+
CatchSwitchInst(llvm::CatchSwitchInst *CSI, Context &Ctx)
1491+
: SingleLLVMInstructionImpl(ClassID::CatchSwitch, Opcode::CatchSwitch,
1492+
CSI, Ctx) {}
1493+
1494+
static CatchSwitchInst *create(Value *ParentPad, BasicBlock *UnwindBB,
1495+
unsigned NumHandlers, BBIterator WhereIt,
1496+
BasicBlock *WhereBB, Context &Ctx,
1497+
const Twine &Name = "");
1498+
1499+
Value *getParentPad() const;
1500+
void setParentPad(Value *ParentPad);
1501+
1502+
bool hasUnwindDest() const {
1503+
return cast<llvm::CatchSwitchInst>(Val)->hasUnwindDest();
1504+
}
1505+
bool unwindsToCaller() const {
1506+
return cast<llvm::CatchSwitchInst>(Val)->unwindsToCaller();
1507+
}
1508+
BasicBlock *getUnwindDest() const;
1509+
void setUnwindDest(BasicBlock *UnwindDest);
1510+
1511+
unsigned getNumHandlers() const {
1512+
return cast<llvm::CatchSwitchInst>(Val)->getNumHandlers();
1513+
}
1514+
1515+
private:
1516+
static BasicBlock *handler_helper(Value *V) { return cast<BasicBlock>(V); }
1517+
static const BasicBlock *handler_helper(const Value *V) {
1518+
return cast<BasicBlock>(V);
1519+
}
1520+
1521+
public:
1522+
using DerefFnTy = BasicBlock *(*)(Value *);
1523+
using handler_iterator = mapped_iterator<op_iterator, DerefFnTy>;
1524+
using handler_range = iterator_range<handler_iterator>;
1525+
using ConstDerefFnTy = const BasicBlock *(*)(const Value *);
1526+
using const_handler_iterator =
1527+
mapped_iterator<const_op_iterator, ConstDerefFnTy>;
1528+
using const_handler_range = iterator_range<const_handler_iterator>;
1529+
1530+
handler_iterator handler_begin() {
1531+
op_iterator It = op_begin() + 1;
1532+
if (hasUnwindDest())
1533+
++It;
1534+
return handler_iterator(It, DerefFnTy(handler_helper));
1535+
}
1536+
const_handler_iterator handler_begin() const {
1537+
const_op_iterator It = op_begin() + 1;
1538+
if (hasUnwindDest())
1539+
++It;
1540+
return const_handler_iterator(It, ConstDerefFnTy(handler_helper));
1541+
}
1542+
handler_iterator handler_end() {
1543+
return handler_iterator(op_end(), DerefFnTy(handler_helper));
1544+
}
1545+
const_handler_iterator handler_end() const {
1546+
return const_handler_iterator(op_end(), ConstDerefFnTy(handler_helper));
1547+
}
1548+
handler_range handlers() {
1549+
return make_range(handler_begin(), handler_end());
1550+
}
1551+
const_handler_range handlers() const {
1552+
return make_range(handler_begin(), handler_end());
1553+
}
1554+
1555+
void addHandler(BasicBlock *Dest);
1556+
1557+
// TODO: removeHandler() cannot be reverted because there is no equivalent
1558+
// addHandler() with a handler_iterator to specify the position. So we can't
1559+
// implement it for now.
1560+
1561+
unsigned getNumSuccessors() const { return getNumOperands() - 1; }
1562+
BasicBlock *getSuccessor(unsigned Idx) const {
1563+
assert(Idx < getNumSuccessors() &&
1564+
"Successor # out of range for catchswitch!");
1565+
return cast<BasicBlock>(getOperand(Idx + 1));
1566+
}
1567+
void setSuccessor(unsigned Idx, BasicBlock *NewSucc) {
1568+
assert(Idx < getNumSuccessors() &&
1569+
"Successor # out of range for catchswitch!");
1570+
setOperand(Idx + 1, NewSucc);
1571+
}
1572+
1573+
static bool classof(const Value *From) {
1574+
return From->getSubclassID() == ClassID::CatchSwitch;
1575+
}
1576+
};
1577+
14831578
class SwitchInst : public SingleLLVMInstructionImpl<llvm::SwitchInst> {
14841579
public:
14851580
SwitchInst(llvm::SwitchInst *SI, Context &Ctx)
@@ -2201,6 +2296,8 @@ class Context {
22012296
friend CallBrInst; // For createCallBrInst()
22022297
GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
22032298
friend GetElementPtrInst; // For createGetElementPtrInst()
2299+
CatchSwitchInst *createCatchSwitchInst(llvm::CatchSwitchInst *I);
2300+
friend CatchSwitchInst; // For createCatchSwitchInst()
22042301
SwitchInst *createSwitchInst(llvm::SwitchInst *I);
22052302
friend SwitchInst; // For createSwitchInst()
22062303
UnaryOperator *createUnaryOperator(llvm::UnaryOperator *I);

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ DEF_INSTR(Call, OP(Call), CallInst)
4646
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
4747
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
4848
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
49+
DEF_INSTR(CatchSwitch, OP(CatchSwitch), CatchSwitchInst)
4950
DEF_INSTR(Switch, OP(Switch), SwitchInst)
5051
DEF_INSTR(UnOp, OPCODES( \
5152
OP(FNeg) \

llvm/include/llvm/SandboxIR/Tracker.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class StoreInst;
5959
class Instruction;
6060
class Tracker;
6161
class AllocaInst;
62+
class CatchSwitchInst;
6263
class SwitchInst;
6364
class ConstantInt;
6465

@@ -263,6 +264,23 @@ class GenericSetterWithIdx final : public IRChangeBase {
263264
#endif
264265
};
265266

267+
class CatchSwitchAddHandler : public IRChangeBase {
268+
CatchSwitchInst *CSI;
269+
unsigned HandlerIdx;
270+
271+
public:
272+
CatchSwitchAddHandler(CatchSwitchInst *CSI);
273+
void revert(Tracker &Tracker) final;
274+
void accept() final {}
275+
#ifndef NDEBUG
276+
void dump(raw_ostream &OS) const final { OS << "CatchSwitchAddHandler"; }
277+
LLVM_DUMP_METHOD void dump() const final {
278+
dump(dbgs());
279+
dbgs() << "\n";
280+
}
281+
#endif // NDEBUG
282+
};
283+
266284
class SwitchAddCase : public IRChangeBase {
267285
SwitchInst *Switch;
268286
ConstantInt *Val;

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,51 @@ static llvm::Instruction::UnaryOps getLLVMUnaryOp(Instruction::Opcode Opc) {
12361236
}
12371237
}
12381238

1239+
CatchSwitchInst *CatchSwitchInst::create(Value *ParentPad, BasicBlock *UnwindBB,
1240+
unsigned NumHandlers,
1241+
BBIterator WhereIt,
1242+
BasicBlock *WhereBB, Context &Ctx,
1243+
const Twine &Name) {
1244+
auto &Builder = Ctx.getLLVMIRBuilder();
1245+
if (WhereIt != WhereBB->end())
1246+
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
1247+
else
1248+
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
1249+
llvm::CatchSwitchInst *LLVMCSI = Builder.CreateCatchSwitch(
1250+
ParentPad->Val, cast<llvm::BasicBlock>(UnwindBB->Val), NumHandlers, Name);
1251+
return Ctx.createCatchSwitchInst(LLVMCSI);
1252+
}
1253+
1254+
Value *CatchSwitchInst::getParentPad() const {
1255+
return Ctx.getValue(cast<llvm::CatchSwitchInst>(Val)->getParentPad());
1256+
}
1257+
1258+
void CatchSwitchInst::setParentPad(Value *ParentPad) {
1259+
Ctx.getTracker()
1260+
.emplaceIfTracking<GenericSetter<&CatchSwitchInst::getParentPad,
1261+
&CatchSwitchInst::setParentPad>>(this);
1262+
cast<llvm::CatchSwitchInst>(Val)->setParentPad(ParentPad->Val);
1263+
}
1264+
1265+
BasicBlock *CatchSwitchInst::getUnwindDest() const {
1266+
return cast_or_null<BasicBlock>(
1267+
Ctx.getValue(cast<llvm::CatchSwitchInst>(Val)->getUnwindDest()));
1268+
}
1269+
1270+
void CatchSwitchInst::setUnwindDest(BasicBlock *UnwindDest) {
1271+
Ctx.getTracker()
1272+
.emplaceIfTracking<GenericSetter<&CatchSwitchInst::getUnwindDest,
1273+
&CatchSwitchInst::setUnwindDest>>(this);
1274+
cast<llvm::CatchSwitchInst>(Val)->setUnwindDest(
1275+
cast<llvm::BasicBlock>(UnwindDest->Val));
1276+
}
1277+
1278+
void CatchSwitchInst::addHandler(BasicBlock *Dest) {
1279+
Ctx.getTracker().emplaceIfTracking<CatchSwitchAddHandler>(this);
1280+
cast<llvm::CatchSwitchInst>(Val)->addHandler(
1281+
cast<llvm::BasicBlock>(Dest->Val));
1282+
}
1283+
12391284
SwitchInst *SwitchInst::create(Value *V, BasicBlock *Dest, unsigned NumCases,
12401285
BasicBlock::iterator WhereIt,
12411286
BasicBlock *WhereBB, Context &Ctx,
@@ -1953,6 +1998,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
19531998
new GetElementPtrInst(LLVMGEP, *this));
19541999
return It->second.get();
19552000
}
2001+
case llvm::Instruction::CatchSwitch: {
2002+
auto *LLVMCatchSwitchInst = cast<llvm::CatchSwitchInst>(LLVMV);
2003+
It->second = std::unique_ptr<CatchSwitchInst>(
2004+
new CatchSwitchInst(LLVMCatchSwitchInst, *this));
2005+
return It->second.get();
2006+
}
19562007
case llvm::Instruction::Switch: {
19572008
auto *LLVMSwitchInst = cast<llvm::SwitchInst>(LLVMV);
19582009
It->second =
@@ -2117,6 +2168,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
21172168
std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
21182169
return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
21192170
}
2171+
CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) {
2172+
auto NewPtr = std::unique_ptr<CatchSwitchInst>(new CatchSwitchInst(I, *this));
2173+
return cast<CatchSwitchInst>(registerValue(std::move(NewPtr)));
2174+
}
21202175
SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
21212176
auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
21222177
return cast<SwitchInst>(registerValue(std::move(NewPtr)));

llvm/lib/SandboxIR/Tracker.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,16 @@ void RemoveFromParent::dump() const {
160160
}
161161
#endif
162162

163+
CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI)
164+
: CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {}
165+
166+
void CatchSwitchAddHandler::revert(Tracker &Tracker) {
167+
// TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler()
168+
// once it gets implemented.
169+
auto *LLVMCSI = cast<llvm::CatchSwitchInst>(CSI->Val);
170+
LLVMCSI->removeHandler(LLVMCSI->handler_begin() + HandlerIdx);
171+
}
172+
163173
void SwitchRemoveCase::revert(Tracker &Tracker) { Switch->addCase(Val, Dest); }
164174

165175
#ifndef NDEBUG

llvm/unittests/SandboxIR/SandboxIRTest.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,6 +1643,110 @@ define void @foo(i32 %arg, float %farg) {
16431643
EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
16441644
}
16451645

1646+
TEST_F(SandboxIRTest, CatchSwitchInst) {
1647+
parseIR(C, R"IR(
1648+
define void @foo(i32 %cond0, i32 %cond1) {
1649+
bb0:
1650+
%cs0 = catchswitch within none [label %handler0, label %handler1] unwind to caller
1651+
bb1:
1652+
%cs1 = catchswitch within %cs0 [label %handler0, label %handler1] unwind label %cleanup
1653+
handler0:
1654+
ret void
1655+
handler1:
1656+
ret void
1657+
cleanup:
1658+
ret void
1659+
}
1660+
)IR");
1661+
Function &LLVMF = *M->getFunction("foo");
1662+
auto *LLVMBB0 = getBasicBlockByName(LLVMF, "bb0");
1663+
auto *LLVMBB1 = getBasicBlockByName(LLVMF, "bb1");
1664+
auto *LLVMHandler0 = getBasicBlockByName(LLVMF, "handler0");
1665+
auto *LLVMHandler1 = getBasicBlockByName(LLVMF, "handler1");
1666+
auto *LLVMCleanup = getBasicBlockByName(LLVMF, "cleanup");
1667+
auto *LLVMCS0 = cast<llvm::CatchSwitchInst>(&*LLVMBB0->begin());
1668+
auto *LLVMCS1 = cast<llvm::CatchSwitchInst>(&*LLVMBB1->begin());
1669+
1670+
sandboxir::Context Ctx(C);
1671+
[[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
1672+
auto *BB0 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB0));
1673+
auto *BB1 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB1));
1674+
auto *Handler0 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMHandler0));
1675+
auto *Handler1 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMHandler1));
1676+
auto *Cleanup = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMCleanup));
1677+
auto *CS0 = cast<sandboxir::CatchSwitchInst>(&*BB0->begin());
1678+
auto *CS1 = cast<sandboxir::CatchSwitchInst>(&*BB1->begin());
1679+
1680+
// Check getParentPad().
1681+
EXPECT_EQ(CS0->getParentPad(), Ctx.getValue(LLVMCS0->getParentPad()));
1682+
EXPECT_EQ(CS1->getParentPad(), Ctx.getValue(LLVMCS1->getParentPad()));
1683+
// Check setParentPad().
1684+
auto *OrigPad = CS0->getParentPad();
1685+
auto *NewPad = CS1;
1686+
EXPECT_NE(NewPad, OrigPad);
1687+
CS0->setParentPad(NewPad);
1688+
EXPECT_EQ(CS0->getParentPad(), NewPad);
1689+
CS0->setParentPad(OrigPad);
1690+
EXPECT_EQ(CS0->getParentPad(), OrigPad);
1691+
// Check hasUnwindDest().
1692+
EXPECT_EQ(CS0->hasUnwindDest(), LLVMCS0->hasUnwindDest());
1693+
EXPECT_EQ(CS1->hasUnwindDest(), LLVMCS1->hasUnwindDest());
1694+
// Check unwindsToCaller().
1695+
EXPECT_EQ(CS0->unwindsToCaller(), LLVMCS0->unwindsToCaller());
1696+
EXPECT_EQ(CS1->unwindsToCaller(), LLVMCS1->unwindsToCaller());
1697+
// Check getUnwindDest().
1698+
EXPECT_EQ(CS0->getUnwindDest(), Ctx.getValue(LLVMCS0->getUnwindDest()));
1699+
EXPECT_EQ(CS1->getUnwindDest(), Ctx.getValue(LLVMCS1->getUnwindDest()));
1700+
// Check setUnwindDest().
1701+
auto *OrigUnwindDest = CS1->getUnwindDest();
1702+
auto *NewUnwindDest = BB0;
1703+
EXPECT_NE(NewUnwindDest, OrigUnwindDest);
1704+
CS1->setUnwindDest(NewUnwindDest);
1705+
EXPECT_EQ(CS1->getUnwindDest(), NewUnwindDest);
1706+
CS1->setUnwindDest(OrigUnwindDest);
1707+
EXPECT_EQ(CS1->getUnwindDest(), OrigUnwindDest);
1708+
// Check getNumHandlers().
1709+
EXPECT_EQ(CS0->getNumHandlers(), LLVMCS0->getNumHandlers());
1710+
EXPECT_EQ(CS1->getNumHandlers(), LLVMCS1->getNumHandlers());
1711+
// Check handler_begin(), handler_end().
1712+
auto It = CS0->handler_begin();
1713+
EXPECT_EQ(*It++, Handler0);
1714+
EXPECT_EQ(*It++, Handler1);
1715+
EXPECT_EQ(It, CS0->handler_end());
1716+
// Check handlers().
1717+
SmallVector<sandboxir::BasicBlock *, 2> Handlers;
1718+
for (sandboxir::BasicBlock *Handler : CS0->handlers())
1719+
Handlers.push_back(Handler);
1720+
EXPECT_EQ(Handlers.size(), 2u);
1721+
EXPECT_EQ(Handlers[0], Handler0);
1722+
EXPECT_EQ(Handlers[1], Handler1);
1723+
// Check addHandler().
1724+
CS0->addHandler(BB0);
1725+
EXPECT_EQ(CS0->getNumHandlers(), 3u);
1726+
EXPECT_EQ(*std::next(CS0->handler_begin(), 2), BB0);
1727+
// Check getNumSuccessors().
1728+
EXPECT_EQ(CS0->getNumSuccessors(), LLVMCS0->getNumSuccessors());
1729+
EXPECT_EQ(CS1->getNumSuccessors(), LLVMCS1->getNumSuccessors());
1730+
// Check getSuccessor().
1731+
for (auto SuccIdx : seq<unsigned>(0, CS0->getNumSuccessors()))
1732+
EXPECT_EQ(CS0->getSuccessor(SuccIdx),
1733+
Ctx.getValue(LLVMCS0->getSuccessor(SuccIdx)));
1734+
// Check setSuccessor().
1735+
auto *OrigSuccessor = CS0->getSuccessor(0);
1736+
auto *NewSuccessor = BB0;
1737+
EXPECT_NE(NewSuccessor, OrigSuccessor);
1738+
CS0->setSuccessor(0, NewSuccessor);
1739+
EXPECT_EQ(CS0->getSuccessor(0), NewSuccessor);
1740+
CS0->setSuccessor(0, OrigSuccessor);
1741+
EXPECT_EQ(CS0->getSuccessor(0), OrigSuccessor);
1742+
// Check create().
1743+
CS1->eraseFromParent();
1744+
auto *NewCSI = sandboxir::CatchSwitchInst::create(
1745+
CS0, Cleanup, 2, BB1->begin(), BB1, Ctx, "NewCSI");
1746+
EXPECT_TRUE(isa<sandboxir::CatchSwitchInst>(NewCSI));
1747+
EXPECT_EQ(NewCSI->getParentPad(), CS0);
1748+
}
1749+
16461750
TEST_F(SandboxIRTest, SwitchInst) {
16471751
parseIR(C, R"IR(
16481752
define void @foo(i32 %cond0, i32 %cond1) {

0 commit comments

Comments
 (0)