Skip to content

[SandboxIR] Implement SelectInst #99996

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class BasicBlock;
class Context;
class Function;
class Instruction;
class SelectInst;
class LoadInst;
class ReturnInst;
class StoreInst;
Expand Down Expand Up @@ -177,6 +178,7 @@ class Value {
friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class SelectInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.
Expand Down Expand Up @@ -411,6 +413,8 @@ class Constant : public sandboxir::User {
}

public:
static Constant *createInt(Type *Ty, uint64_t V, Context &Ctx,
bool IsSigned = false);
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From) {
return From->getSubclassID() == ClassID::Constant ||
Expand Down Expand Up @@ -499,6 +503,7 @@ class Instruction : public sandboxir::User {
/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
friend class SelectInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
Expand Down Expand Up @@ -566,6 +571,52 @@ class Instruction : public sandboxir::User {
#endif
};

class SelectInst : public Instruction {
/// Use Context::createSelectInst(). Don't call the
/// constructor directly.
SelectInst(llvm::SelectInst *CI, Context &Ctx)
: Instruction(ClassID::Select, Opcode::Select, CI, Ctx) {}
friend Context; // for SelectInst()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}
static Value *createCommon(Value *Cond, Value *True, Value *False,
const Twine &Name, IRBuilder<> &Builder,
Context &Ctx);

public:
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
static Value *create(Value *Cond, Value *True, Value *False,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name = "");
static Value *create(Value *Cond, Value *True, Value *False,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name = "");
Value *getCondition() { return getOperand(0); }
Value *getTrueValue() { return getOperand(1); }
Value *getFalseValue() { return getOperand(2); }

void setCondition(Value *New) { setOperand(0, New); }
void setTrueValue(Value *New) { setOperand(1, New); }
void setFalseValue(Value *New) { setOperand(2, New); }
void swapValues() { cast<llvm::SelectInst>(Val)->swapValues(); }
/// For isa/dyn_cast.
static bool classof(const Value *From);
#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::SelectInst>(Val) && "Expected SelectInst!");
}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};

class LoadInst final : public Instruction {
/// Use LoadInst::create() instead of calling the constructor.
LoadInst(llvm::LoadInst *LI, Context &Ctx)
Expand Down Expand Up @@ -803,6 +854,11 @@ class Context {
Value *getOrCreateValue(llvm::Value *LLVMV) {
return getOrCreateValueInternal(LLVMV, 0);
}
/// Get or create a sandboxir::Constant from an existing LLVM IR \p LLVMC.
Constant *getOrCreateConstant(llvm::Constant *LLVMC) {
return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
}
friend class Constant; // For getOrCreateConstant().
/// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will
/// also create all contents of the block.
BasicBlock *createBasicBlock(llvm::BasicBlock *BB);
Expand All @@ -812,6 +868,8 @@ class Context {
IRBuilder<ConstantFolder> LLVMIRBuilder;
auto &getLLVMIRBuilder() { return LLVMIRBuilder; }

SelectInst *createSelectInst(llvm::SelectInst *SI);
friend SelectInst; // For createSelectInst()
LoadInst *createLoadInst(llvm::LoadInst *LI);
friend LoadInst; // For createLoadInst()
StoreInst *createStoreInst(llvm::StoreInst *SI);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ DEF_USER(Constant, Constant)
#endif
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)
Expand Down
63 changes: 63 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,51 @@ void Instruction::dump() const {
}
#endif // NDEBUG

Value *SelectInst::createCommon(Value *Cond, Value *True, Value *False,
const Twine &Name, IRBuilder<> &Builder,
Context &Ctx) {
llvm::Value *NewV =
Builder.CreateSelect(Cond->Val, True->Val, False->Val, Name);
if (auto *NewSI = dyn_cast<llvm::SelectInst>(NewV))
return Ctx.createSelectInst(NewSI);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a createCommon here too like ReturnInst to capture the duplication?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes.

}

Value *SelectInst::create(Value *Cond, Value *True, Value *False,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction();
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(BeforeIR);
return createCommon(Cond, True, False, Name, Builder, Ctx);
}

Value *SelectInst::create(Value *Cond, Value *True, Value *False,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name) {
auto *IRInsertAtEnd = cast<llvm::BasicBlock>(InsertAtEnd->Val);
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(IRInsertAtEnd);
return createCommon(Cond, True, False, Name, Builder, Ctx);
}

bool SelectInst::classof(const Value *From) {
return From->getSubclassID() == ClassID::Select;
}

#ifndef NDEBUG
void SelectInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}

void SelectInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
Expand Down Expand Up @@ -592,7 +637,15 @@ void OpaqueInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx,
bool IsSigned) {
llvm::Constant *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned);
return Ctx.getOrCreateConstant(LLVMC);
}

#ifndef NDEBUG
void Constant::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
Expand Down Expand Up @@ -700,6 +753,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction");

switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) {
case llvm::Instruction::Select: {
auto *LLVMSel = cast<llvm::SelectInst>(LLVMV);
It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
return It->second.get();
}
case llvm::Instruction::Load: {
auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
Expand Down Expand Up @@ -733,6 +791,11 @@ BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
return BB;
}

SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this));
return cast<SelectInst>(registerValue(std::move(NewPtr)));
}

LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
return cast<LoadInst>(registerValue(std::move(NewPtr)));
Expand Down
68 changes: 68 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,74 @@ define void @foo(i8 %v1) {
EXPECT_EQ(I0->getNextNode(), Ret);
}

TEST_F(SandboxIRTest, SelectInst) {
parseIR(C, R"IR(
define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
%sel = select i1 %c0, i8 %v0, i8 %v1
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
sandboxir::Function *F = Ctx.createFunction(LLVMF);
auto *Cond0 = F->getArg(0);
auto *V0 = F->getArg(1);
auto *V1 = F->getArg(2);
auto *Cond1 = F->getArg(3);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *Select = cast<sandboxir::SelectInst>(&*It++);
auto *Ret = &*It++;

// Check getCondition().
EXPECT_EQ(Select->getCondition(), Cond0);
// Check getTrueValue().
EXPECT_EQ(Select->getTrueValue(), V0);
// Check getFalseValue().
EXPECT_EQ(Select->getFalseValue(), V1);
// Check setCondition().
Select->setCondition(Cond1);
EXPECT_EQ(Select->getCondition(), Cond1);
// Check setTrueValue().
Select->setTrueValue(V1);
EXPECT_EQ(Select->getTrueValue(), V1);
// Check setFalseValue().
Select->setFalseValue(V0);
EXPECT_EQ(Select->getFalseValue(), V0);

{
// Check SelectInst::create() InsertBefore.
auto *NewSel = cast<sandboxir::SelectInst>(sandboxir::SelectInst::create(
Cond0, V0, V1, /*InsertBefore=*/Ret, Ctx));
EXPECT_EQ(NewSel->getCondition(), Cond0);
EXPECT_EQ(NewSel->getTrueValue(), V0);
EXPECT_EQ(NewSel->getFalseValue(), V1);
EXPECT_EQ(NewSel->getNextNode(), Ret);
}
{
// Check SelectInst::create() InsertAtEnd.
auto *NewSel = cast<sandboxir::SelectInst>(
sandboxir::SelectInst::create(Cond0, V0, V1, /*InsertAtEnd=*/BB, Ctx));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth checking *(BB->end())->getOpcodeName() == Select here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, this is automatically handled by the .def files, so it shouldn't normally be wrong. But yeah, at some point we should check this too.
It will be checked once we add tests for the dump() functions for the whole class hierarchy in a future patch.

EXPECT_EQ(NewSel->getCondition(), Cond0);
EXPECT_EQ(NewSel->getTrueValue(), V0);
EXPECT_EQ(NewSel->getFalseValue(), V1);
EXPECT_EQ(NewSel->getPrevNode(), Ret);
}
{
// Check SelectInst::create() Folded.
auto *False =
sandboxir::Constant::createInt(llvm::Type::getInt1Ty(C), 0, Ctx,
/*IsSigned=*/false);
auto *FortyTwo =
sandboxir::Constant::createInt(llvm::Type::getInt1Ty(C), 42, Ctx,
/*IsSigned=*/false);
auto *NewSel =
sandboxir::SelectInst::create(False, FortyTwo, FortyTwo, Ret, Ctx);
EXPECT_TRUE(isa<sandboxir::Constant>(NewSel));
EXPECT_EQ(NewSel, FortyTwo);
}
}

TEST_F(SandboxIRTest, LoadInst) {
parseIR(C, R"IR(
define void @foo(ptr %arg0, ptr %arg1) {
Expand Down
Loading