Skip to content

[SandboxIR] Implement InvokeInst #100796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 74 additions & 4 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
// |
// +- CastInst
// |
// +- CallBase ----- CallInst
// |
// +- CmpInst
// +- CallBase ------+- CallInst
// | |
// +- CmpInst +- InvokeInst
// |
// +- ExtractElementInst
// |
Expand Down Expand Up @@ -90,6 +90,7 @@ class User;
class Value;
class CallBase;
class CallInst;
class InvokeInst;

/// Iterator for the `Use` edges of a User's operands.
/// \Returns the operand `Use` when dereferenced.
Expand Down Expand Up @@ -203,6 +204,7 @@ class Value {
friend class ReturnInst; // For getting `Val`.
friend class CallBase; // For getting `Val`.
friend class CallInst; // For getting `Val`.
friend class InvokeInst; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -541,6 +543,7 @@ class Instruction : public sandboxir::User {
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
friend class CallInst; // For getTopmostLLVMInstruction().
friend class InvokeInst; // For getTopmostLLVMInstruction().

/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
Expand Down Expand Up @@ -861,7 +864,8 @@ class ReturnInst final : public Instruction {
class CallBase : public Instruction {
CallBase(ClassID ID, Opcode Opc, llvm::Instruction *I, Context &Ctx)
: Instruction(ID, Opc, I, Ctx) {}
friend class CallInst; // For constructor.
friend class CallInst; // For constructor.
friend class InvokeInst; // For constructor.

public:
static bool classof(const Value *From) {
Expand Down Expand Up @@ -1029,6 +1033,70 @@ class CallInst final : public CallBase {
#endif
};

class InvokeInst final : public CallBase {
/// Use Context::createInvokeInst(). Don't call the
/// constructor directly.
InvokeInst(llvm::Instruction *I, Context &Ctx)
: CallBase(ClassID::Invoke, Opcode::Invoke, I, Ctx) {}
friend class Context; // For accessing the constructor in
// create*()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}

public:
static InvokeInst *create(FunctionType *FTy, Value *Func,
BasicBlock *IfNormal, BasicBlock *IfException,
ArrayRef<Value *> Args, BBIterator WhereIt,
BasicBlock *WhereBB, Context &Ctx,
const Twine &NameStr = "");
static InvokeInst *create(FunctionType *FTy, Value *Func,
BasicBlock *IfNormal, BasicBlock *IfException,
ArrayRef<Value *> Args, Instruction *InsertBefore,
Context &Ctx, const Twine &NameStr = "");
static InvokeInst *create(FunctionType *FTy, Value *Func,
BasicBlock *IfNormal, BasicBlock *IfException,
ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
Context &Ctx, const Twine &NameStr = "");

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is InvokeInstr missing a constructor with "InsertAtEnd" parameter? CallInstr has it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the issue is that in other instructions we had just two create() functions, one with InsertBefore and one with InsertAtEnd. But both of these can be replaced by one create() function that takes as arguments a BasicBlock::Iterator WhereIt and a BasicBlock *WhereBB. Anyway, it's probably OK to either keep all three variants, or keep only one. I would need to revisit them with a refactoring patch. Any preference?

But anyway, I will add the missing create() function with the InsertAtEnd parameter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a preference but the LLVM IR has InsertPosition. Should we just create one for SandboxIR too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should introduce a similar class in SandboxIR.

static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::Invoke;
}
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
BasicBlock *getNormalDest() const;
BasicBlock *getUnwindDest() const;
void setNormalDest(BasicBlock *BB);
void setUnwindDest(BasicBlock *BB);
// TODO: Return a `LandingPadInst` once implemented.
Instruction *getLandingPadInst() const;
BasicBlock *getSuccessor(unsigned SuccIdx) const;
void setSuccessor(unsigned SuccIdx, BasicBlock *NewSucc) {
assert(SuccIdx < 2 && "Successor # out of range for invoke!");
if (SuccIdx == 0)
setNormalDest(NewSucc);
else
setUnwindDest(NewSucc);
}
unsigned getNumSuccessors() const {
return cast<llvm::InvokeInst>(Val)->getNumSuccessors();
}
#ifndef NDEBUG
void verify() const final {}
friend raw_ostream &operator<<(raw_ostream &OS, const InvokeInst &I) {
I.dump(OS);
return OS;
}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};

/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
/// an OpaqueInstr.
class OpaqueInst : public sandboxir::Instruction {
Expand Down Expand Up @@ -1179,6 +1247,8 @@ class Context {
friend ReturnInst; // For createReturnInst()
CallInst *createCallInst(llvm::CallInst *I);
friend CallInst; // For createCallInst()
InvokeInst *createInvokeInst(llvm::InvokeInst *I);
friend InvokeInst; // For createInvokeInst()

public:
Context(LLVMContext &LLVMCtx)
Expand Down
85 changes: 85 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,81 @@ void CallInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

InvokeInst *InvokeInst::create(FunctionType *FTy, Value *Func,
BasicBlock *IfNormal, BasicBlock *IfException,
ArrayRef<Value *> Args, BBIterator WhereIt,
BasicBlock *WhereBB, Context &Ctx,
const Twine &NameStr) {
auto &Builder = Ctx.getLLVMIRBuilder();
if (WhereIt != WhereBB->end())
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
else
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
SmallVector<llvm::Value *> LLVMArgs;
LLVMArgs.reserve(Args.size());
for (Value *Arg : Args)
LLVMArgs.push_back(Arg->Val);
llvm::InvokeInst *Invoke = Builder.CreateInvoke(
FTy, Func->Val, cast<llvm::BasicBlock>(IfNormal->Val),
cast<llvm::BasicBlock>(IfException->Val), LLVMArgs, NameStr);
return Ctx.createInvokeInst(Invoke);
}

InvokeInst *InvokeInst::create(FunctionType *FTy, Value *Func,
BasicBlock *IfNormal, BasicBlock *IfException,
ArrayRef<Value *> Args,
Instruction *InsertBefore, Context &Ctx,
const Twine &NameStr) {
return create(FTy, Func, IfNormal, IfException, Args,
InsertBefore->getIterator(), InsertBefore->getParent(), Ctx,
NameStr);
}

InvokeInst *InvokeInst::create(FunctionType *FTy, Value *Func,
BasicBlock *IfNormal, BasicBlock *IfException,
ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
Context &Ctx, const Twine &NameStr) {
return create(FTy, Func, IfNormal, IfException, Args, InsertAtEnd->end(),
InsertAtEnd, Ctx, NameStr);
}

BasicBlock *InvokeInst::getNormalDest() const {
return cast<BasicBlock>(
Ctx.getValue(cast<llvm::InvokeInst>(Val)->getNormalDest()));
}
BasicBlock *InvokeInst::getUnwindDest() const {
return cast<BasicBlock>(
Ctx.getValue(cast<llvm::InvokeInst>(Val)->getUnwindDest()));
}
void InvokeInst::setNormalDest(BasicBlock *BB) {
setOperand(1, BB);
assert(getNormalDest() == BB && "LLVM IR uses a different operan index!");
}
void InvokeInst::setUnwindDest(BasicBlock *BB) {
setOperand(2, BB);
assert(getUnwindDest() == BB && "LLVM IR uses a different operan index!");
}
Instruction *InvokeInst::getLandingPadInst() const {
return cast<Instruction>(
Ctx.getValue(cast<llvm::InvokeInst>(Val)->getLandingPadInst()));
;
}
BasicBlock *InvokeInst::getSuccessor(unsigned SuccIdx) const {
return cast<BasicBlock>(
Ctx.getValue(cast<llvm::InvokeInst>(Val)->getSuccessor(SuccIdx)));
}

#ifndef NDEBUG
void InvokeInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}
void InvokeInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}

void OpaqueInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
Expand Down Expand Up @@ -968,6 +1043,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
return It->second.get();
}
case llvm::Instruction::Invoke: {
auto *LLVMInvoke = cast<llvm::InvokeInst>(LLVMV);
It->second = std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
return It->second.get();
}
default:
break;
}
Expand Down Expand Up @@ -1016,6 +1096,11 @@ CallInst *Context::createCallInst(llvm::CallInst *I) {
return cast<CallInst>(registerValue(std::move(NewPtr)));
}

InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
auto NewPtr = std::unique_ptr<InvokeInst>(new InvokeInst(I, *this));
return cast<InvokeInst>(registerValue(std::move(NewPtr)));
}

Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
if (It != LLVMValueToValueMap.end())
Expand Down
91 changes: 91 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1036,3 +1036,94 @@ define i8 @foo(i8 %arg) {
EXPECT_EQ(Call->getArgOperand(0), Arg0);
}
}

TEST_F(SandboxIRTest, InvokeInst) {
parseIR(C, R"IR(
define void @foo(i8 %arg) {
bb0:
invoke i8 @foo(i8 %arg) to label %normal_bb
unwind label %exception_bb
normal_bb:
ret void
exception_bb:
%lpad = landingpad { ptr, i32}
cleanup
ret void
other_bb:
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto *Arg = F.getArg(0);
auto *BB0 = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
auto *NormalBB = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "normal_bb")));
auto *ExceptionBB = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "exception_bb")));
auto *LandingPad = &*ExceptionBB->begin();
auto *OtherBB = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "other_bb")));
auto It = BB0->begin();
// Check classof(Instruction *).
auto *Invoke = cast<sandboxir::InvokeInst>(&*It++);

// Check getNormalDest().
EXPECT_EQ(Invoke->getNormalDest(), NormalBB);
// Check getUnwindDest().
EXPECT_EQ(Invoke->getUnwindDest(), ExceptionBB);
// Check getSuccessor().
EXPECT_EQ(Invoke->getSuccessor(0), NormalBB);
EXPECT_EQ(Invoke->getSuccessor(1), ExceptionBB);
// Check setNormalDest().
Invoke->setNormalDest(OtherBB);
EXPECT_EQ(Invoke->getNormalDest(), OtherBB);
EXPECT_EQ(Invoke->getUnwindDest(), ExceptionBB);
// Check setUnwindDest().
Invoke->setUnwindDest(OtherBB);
EXPECT_EQ(Invoke->getNormalDest(), OtherBB);
EXPECT_EQ(Invoke->getUnwindDest(), OtherBB);
// Check setSuccessor().
Invoke->setSuccessor(0, NormalBB);
EXPECT_EQ(Invoke->getNormalDest(), NormalBB);
Invoke->setSuccessor(1, ExceptionBB);
EXPECT_EQ(Invoke->getUnwindDest(), ExceptionBB);
// Check getLandingPadInst().
EXPECT_EQ(Invoke->getLandingPadInst(), LandingPad);

{
// Check create() WhereIt, WhereBB.
SmallVector<sandboxir::Value *> Args({Arg});
auto *InsertBefore = &*BB0->begin();
auto *NewInvoke = cast<sandboxir::InvokeInst>(sandboxir::InvokeInst::create(
F.getFunctionType(), &F, NormalBB, ExceptionBB, Args,
/*WhereIt=*/InsertBefore->getIterator(), /*WhereBB=*/BB0, Ctx));
EXPECT_EQ(NewInvoke->getNormalDest(), NormalBB);
EXPECT_EQ(NewInvoke->getUnwindDest(), ExceptionBB);
EXPECT_EQ(NewInvoke->getNextNode(), InsertBefore);
}
{
// Check create() InsertBefore.
SmallVector<sandboxir::Value *> Args({Arg});
auto *InsertBefore = &*BB0->begin();
auto *NewInvoke = cast<sandboxir::InvokeInst>(
sandboxir::InvokeInst::create(F.getFunctionType(), &F, NormalBB,
ExceptionBB, Args, InsertBefore, Ctx));
EXPECT_EQ(NewInvoke->getNormalDest(), NormalBB);
EXPECT_EQ(NewInvoke->getUnwindDest(), ExceptionBB);
EXPECT_EQ(NewInvoke->getNextNode(), InsertBefore);
}
{
// Check create() InsertAtEnd.
SmallVector<sandboxir::Value *> Args({Arg});
auto *NewInvoke = cast<sandboxir::InvokeInst>(sandboxir::InvokeInst::create(
F.getFunctionType(), &F, NormalBB, ExceptionBB, Args,
/*InsertAtEnd=*/BB0, Ctx));
EXPECT_EQ(NewInvoke->getNormalDest(), NormalBB);
EXPECT_EQ(NewInvoke->getUnwindDest(), ExceptionBB);
EXPECT_EQ(NewInvoke->getParent(), BB0);
EXPECT_EQ(NewInvoke->getNextNode(), nullptr);
}
}
56 changes: 56 additions & 0 deletions llvm/unittests/SandboxIR/TrackerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,59 @@ define void @foo(i8 %arg0, i8 %arg1) {
Ctx.revert();
EXPECT_EQ(Call->getCalledFunction(), Bar1F);
}

TEST_F(TrackerTest, InvokeSetters) {
parseIR(C, R"IR(
define void @foo(i8 %arg) {
bb0:
invoke i8 @foo(i8 %arg) to label %normal_bb
unwind label %exception_bb
normal_bb:
ret void
exception_bb:
ret void
other_bb:
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
[[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
auto *BB0 = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
auto *NormalBB = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "normal_bb")));
auto *ExceptionBB = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "exception_bb")));
auto *OtherBB = cast<sandboxir::BasicBlock>(
Ctx.getValue(getBasicBlockByName(LLVMF, "other_bb")));
auto It = BB0->begin();
auto *Invoke = cast<sandboxir::InvokeInst>(&*It++);

// Check setNormalDest().
Ctx.save();
Invoke->setNormalDest(OtherBB);
EXPECT_EQ(Invoke->getNormalDest(), OtherBB);
Ctx.revert();
EXPECT_EQ(Invoke->getNormalDest(), NormalBB);

// Check setUnwindDest().
Ctx.save();
Invoke->setUnwindDest(OtherBB);
EXPECT_EQ(Invoke->getUnwindDest(), OtherBB);
Ctx.revert();
EXPECT_EQ(Invoke->getUnwindDest(), ExceptionBB);

// Check setSuccessor().
Ctx.save();
Invoke->setSuccessor(0, OtherBB);
EXPECT_EQ(Invoke->getSuccessor(0), OtherBB);
Ctx.revert();
EXPECT_EQ(Invoke->getSuccessor(0), NormalBB);

Ctx.save();
Invoke->setSuccessor(1, OtherBB);
EXPECT_EQ(Invoke->getSuccessor(1), OtherBB);
Ctx.revert();
EXPECT_EQ(Invoke->getSuccessor(1), ExceptionBB);
}
Loading