Skip to content

[SandboxIR] Implement ReturnInst #99784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
// |
// +- PHINode
// |
// +- RetInst
// +- ReturnInst
// |
// +- SelectInst
// |
Expand Down Expand Up @@ -76,6 +76,7 @@ class Context;
class Function;
class Instruction;
class LoadInst;
class ReturnInst;
class StoreInst;
class User;
class Value;
Expand Down Expand Up @@ -173,11 +174,12 @@ class Value {
/// order.
llvm::Value *Val = nullptr;

friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -497,8 +499,9 @@ class Instruction : public sandboxir::User {
/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().

/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
Expand Down Expand Up @@ -639,6 +642,43 @@ class StoreInst final : public Instruction {
#endif
};

class ReturnInst final : public Instruction {
/// Use ReturnInst::create() instead of calling the constructor.
ReturnInst(llvm::Instruction *I, Context &Ctx)
: Instruction(ClassID::Ret, Opcode::Ret, I, Ctx) {}
ReturnInst(ClassID SubclassID, llvm::Instruction *I, Context &Ctx)
: Instruction(SubclassID, Opcode::Ret, I, Ctx) {}
friend class Context; // For accessing the constructor in create*()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}
static ReturnInst *createCommon(Value *RetVal, IRBuilder<> &Builder,
Context &Ctx);

public:
static ReturnInst *create(Value *RetVal, Instruction *InsertBefore,
Context &Ctx);
static ReturnInst *create(Value *RetVal, BasicBlock *InsertAtEnd,
Context &Ctx);
static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::Ret;
}
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
/// \Returns null if there is no return value.
Value *getReturnValue() const;
#ifndef NDEBUG
void verify() const final {}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};

/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
/// an OpaqueInstr.
class OpaqueInst : public sandboxir::Instruction {
Expand Down Expand Up @@ -776,6 +816,8 @@ class Context {
friend LoadInst; // For createLoadInst()
StoreInst *createStoreInst(llvm::StoreInst *SI);
friend StoreInst; // For createStoreInst()
ReturnInst *createReturnInst(llvm::ReturnInst *I);
friend ReturnInst; // For createReturnInst()

public:
Context(LLVMContext &LLVMCtx)
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ DEF_USER(Constant, Constant)
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)

#ifdef DEF_VALUE
#undef DEF_VALUE
Expand Down
54 changes: 53 additions & 1 deletion llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,48 @@ void StoreInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

ReturnInst *ReturnInst::createCommon(Value *RetVal, IRBuilder<> &Builder,
Context &Ctx) {
llvm::ReturnInst *NewRI;
if (RetVal != nullptr)
NewRI = Builder.CreateRet(RetVal->Val);
else
NewRI = Builder.CreateRetVoid();
return Ctx.createReturnInst(NewRI);
}

ReturnInst *ReturnInst::create(Value *RetVal, Instruction *InsertBefore,
Context &Ctx) {
llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction();
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(BeforeIR);
return createCommon(RetVal, Builder, Ctx);
}

ReturnInst *ReturnInst::create(Value *RetVal, BasicBlock *InsertAtEnd,
Context &Ctx) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
return createCommon(RetVal, Builder, Ctx);
}

Value *ReturnInst::getReturnValue() const {
auto *LLVMRetVal = cast<llvm::ReturnInst>(Val)->getReturnValue();
return LLVMRetVal != nullptr ? Ctx.getValue(LLVMRetVal) : nullptr;
}

#ifndef NDEBUG
void ReturnInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}

void ReturnInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}

void OpaqueInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
Expand Down Expand Up @@ -626,7 +668,7 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
"Can't register a user!");
Value *V = VPtr.get();
[[maybe_unused]] auto Pair =
LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
assert(Pair.second && "Already exists!");
return V;
}
Expand Down Expand Up @@ -668,6 +710,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
return It->second.get();
}
case llvm::Instruction::Ret: {
auto *LLVMRet = cast<llvm::ReturnInst>(LLVMV);
It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
return It->second.get();
}
default:
break;
}
Expand Down Expand Up @@ -696,6 +743,11 @@ StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
return cast<StoreInst>(registerValue(std::move(NewPtr)));
}

ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this));
return cast<ReturnInst>(registerValue(std::move(NewPtr)));
}

Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
if (It != LLVMValueToValueMap.end())
Expand Down
53 changes: 46 additions & 7 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ define i32 @foo(i32 %v0, i32 %v1) {
auto *Arg1 = F.getArg(1);
auto It = BB.begin();
auto *I0 = &*It++;
auto *Ret = &*It++;
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

SmallVector<sandboxir::Argument *> Args{Arg0, Arg1};
unsigned OpIdx = 0;
Expand Down Expand Up @@ -245,7 +245,7 @@ define i32 @foo(i32 %arg0, i32 %arg1) {
auto *I0 = &*It++;
auto *I1 = &*It++;
auto *I2 = &*It++;
auto *Ret = &*It++;
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

bool Replaced;
// Try to replace an operand that doesn't match.
Expand Down Expand Up @@ -401,7 +401,7 @@ void @foo(i32 %arg0, i32 %arg1) {
br label %bb1 ; SB3. (Opaque)

bb1:
ret void ; SB5. (Opaque)
ret void ; SB5. (Ret)
}
)IR");
}
Expand Down Expand Up @@ -488,7 +488,7 @@ define void @foo(i8 %v1) {
auto It = BB->begin();
auto *I0 = &*It++;
auto *I1 = &*It++;
auto *Ret = &*It++;
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

// Check getPrevNode().
EXPECT_EQ(Ret->getPrevNode(), I1);
Expand All @@ -508,7 +508,7 @@ define void @foo(i8 %v1) {
// Check getOpcode().
EXPECT_EQ(I0->getOpcode(), sandboxir::Instruction::Opcode::Opaque);
EXPECT_EQ(I1->getOpcode(), sandboxir::Instruction::Opcode::Opaque);
EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Opaque);
EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Ret);

// Check moveBefore(I).
I1->moveBefore(I0);
Expand Down Expand Up @@ -576,7 +576,7 @@ define void @foo(ptr %arg0, ptr %arg1) {
auto *BB = &*F->begin();
auto It = BB->begin();
auto *Ld = cast<sandboxir::LoadInst>(&*It++);
auto *Ret = &*It++;
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

// Check getPointerOperand()
EXPECT_EQ(Ld->getPointerOperand(), Arg0);
Expand Down Expand Up @@ -607,7 +607,7 @@ define void @foo(i8 %val, ptr %ptr) {
auto *BB = &*F->begin();
auto It = BB->begin();
auto *St = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = &*It++;
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

// Check that the StoreInst has been created correctly.
// Check getPointerOperand()
Expand All @@ -624,3 +624,42 @@ define void @foo(i8 %val, ptr %ptr) {
EXPECT_EQ(NewSt->getPointerOperand(), Ptr);
EXPECT_EQ(NewSt->getAlign(), 8);
}

TEST_F(SandboxIRTest, ReturnInst) {
parseIR(C, R"IR(
define i8 @foo(i8 %val) {
%add = add i8 %val, 42
ret i8 %val
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
sandboxir::Function *F = Ctx.createFunction(LLVMF);
auto *Val = F->getArg(0);
auto *BB = &*F->begin();
auto It = BB->begin();
It++;
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

// Check that the ReturnInst has been created correctly.
// Check getReturnValue().
EXPECT_EQ(Ret->getReturnValue(), Val);

// Check create(InsertBefore) a void ReturnInst.
auto *NewRet1 = cast<sandboxir::ReturnInst>(
sandboxir::ReturnInst::create(nullptr, /*InsertBefore=*/Ret, Ctx));
EXPECT_EQ(NewRet1->getReturnValue(), nullptr);
// Check create(InsertBefore) a non-void ReturnInst.
auto *NewRet2 = cast<sandboxir::ReturnInst>(
sandboxir::ReturnInst::create(Val, /*InsertBefore=*/Ret, Ctx));
EXPECT_EQ(NewRet2->getReturnValue(), Val);

// Check create(InsertAtEnd) a void ReturnInst.
auto *NewRet3 = cast<sandboxir::ReturnInst>(
sandboxir::ReturnInst::create(nullptr, /*InsertAtEnd=*/BB, Ctx));
EXPECT_EQ(NewRet3->getReturnValue(), nullptr);
// Check create(InsertAtEnd) a non-void ReturnInst.
auto *NewRet4 = cast<sandboxir::ReturnInst>(
sandboxir::ReturnInst::create(Val, /*InsertAtEnd=*/BB, Ctx));
EXPECT_EQ(NewRet4->getReturnValue(), Val);
}
Loading