Skip to content

[SandboxIR] Add InsertValueInst #106273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
// |
// +- ShuffleVectorInst
// |
// +- InsertValueInst
// |
// +- StoreInst
// |
// +- UnaryInstruction -+- LoadInst
Expand Down Expand Up @@ -117,6 +119,7 @@ class SelectInst;
class ExtractElementInst;
class InsertElementInst;
class ShuffleVectorInst;
class InsertValueInst;
class BranchInst;
class UnaryInstruction;
class LoadInst;
Expand Down Expand Up @@ -260,6 +263,7 @@ class Value {
friend class ExtractElementInst; // For getting `Val`.
friend class InsertElementInst; // For getting `Val`.
friend class ShuffleVectorInst; // For getting `Val`.
friend class InsertValueInst; // For getting `Val`.
friend class BranchInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
Expand Down Expand Up @@ -692,6 +696,7 @@ class Instruction : public sandboxir::User {
friend class ExtractElementInst; // For getTopmostLLVMInstruction().
friend class InsertElementInst; // For getTopmostLLVMInstruction().
friend class ShuffleVectorInst; // For getTopmostLLVMInstruction().
friend class InsertValueInst; // For getTopmostLLVMInstruction().
friend class BranchInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
Expand Down Expand Up @@ -1451,6 +1456,67 @@ class ShuffleVectorInst final
}
};

class InsertValueInst
: public SingleLLVMInstructionImpl<llvm::InsertValueInst> {
/// Use Context::createInsertValueInst(). Don't call the constructor directly.
InsertValueInst(llvm::InsertValueInst *IVI, Context &Ctx)
: SingleLLVMInstructionImpl(ClassID::InsertValue, Opcode::InsertValue,
IVI, Ctx) {}
friend Context; // for InsertValueInst()

public:
static Value *create(Value *Agg, Value *Val, ArrayRef<unsigned> Idxs,
BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
const Twine &Name = "");

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing getInedexedType(). If it is missing for a reason then we should add a TODO comment.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure? I only see getIndexedType methods in GetElementPtrInst and ExtractValueInst, not in InsertValueInst.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops I was looking at the ExtractValueInst :)

static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::InsertValue;
}

using idx_iterator = llvm::InsertValueInst::idx_iterator;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No

static bool classof(const Value *From) {
    return From->getSubclassID() == ClassID::InsertValueInst;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm @slackito does the test pass without the classof? Shouldn't it crash here?

  auto *Ins0 = cast<sandboxir::InsertValueInst>(&*It++);

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tschuett good catch, thanks!

@vporpo It does pass, and I just double-checked I'm running them with asserts enabled (debug build).

I think it's falling back to Instruction::classof and accidentally passing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there is a good way to check that classof has been implemented. Perhaps a test that uses the .def file and does something with <subclass>::classof() so that we get a compile-time error if we forget to implement it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote a small test that checks EXPECT_NE(&<subclass>::classof, &Instruction::classof); and one of the insturctions is missing a classof!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Can you commit that test (and the missing classof) separately from this PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I will upload a PR in a few minutes.

inline idx_iterator idx_begin() const {
return cast<llvm::InsertValueInst>(Val)->idx_begin();
}
inline idx_iterator idx_end() const {
return cast<llvm::InsertValueInst>(Val)->idx_end();
}
inline iterator_range<idx_iterator> indices() const {
return cast<llvm::InsertValueInst>(Val)->indices();
}

Value *getAggregateOperand() {
return getOperand(getAggregateOperandIndex());
}
const Value *getAggregateOperand() const {
return getOperand(getAggregateOperandIndex());
}
static unsigned getAggregateOperandIndex() {
return llvm::InsertValueInst::getAggregateOperandIndex();
}

Value *getInsertedValueOperand() {
return getOperand(getInsertedValueOperandIndex());
}
const Value *getInsertedValueOperand() const {
return getOperand(getInsertedValueOperandIndex());
}
static unsigned getInsertedValueOperandIndex() {
return llvm::InsertValueInst::getInsertedValueOperandIndex();
}

ArrayRef<unsigned> getIndices() const {
return cast<llvm::InsertValueInst>(Val)->getIndices();
}

unsigned getNumIndices() const {
return cast<llvm::InsertValueInst>(Val)->getNumIndices();
}

unsigned hasIndices() const {
return cast<llvm::InsertValueInst>(Val)->hasIndices();
}
};

class BranchInst : public SingleLLVMInstructionImpl<llvm::BranchInst> {
/// Use Context::createBranchInst(). Don't call the constructor directly.
BranchInst(llvm::BranchInst *BI, Context &Ctx)
Expand Down Expand Up @@ -2999,6 +3065,8 @@ class Context {
friend ExtractElementInst; // For createExtractElementInst()
ShuffleVectorInst *createShuffleVectorInst(llvm::ShuffleVectorInst *SVI);
friend ShuffleVectorInst; // For createShuffleVectorInst()
InsertValueInst *createInsertValueInst(llvm::InsertValueInst *IVI);
friend InsertValueInst; // For createInsertValueInst()
BranchInst *createBranchInst(llvm::BranchInst *I);
friend BranchInst; // For createBranchInst()
LoadInst *createLoadInst(llvm::LoadInst *LI);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
DEF_INSTR(Freeze, OP(Freeze), FreezeInst)
DEF_INSTR(Fence, OP(Fence), FenceInst)
DEF_INSTR(ShuffleVector, OP(ShuffleVector), ShuffleVectorInst)
DEF_INSTR(InsertValue, OP(InsertValue), InsertValueInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2114,6 +2114,21 @@ Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(
llvm::ShuffleVectorInst::convertShuffleMaskForBitcode(Mask, ResultTy));
}

Value *InsertValueInst::create(Value *Agg, Value *Val, ArrayRef<unsigned> Idxs,
BBIterator WhereIt, BasicBlock *WhereBB,
Context &Ctx, const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
if (WhereIt != WhereBB->end())
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
else
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
llvm::Value *NewV = Builder.CreateInsertValue(Agg->Val, Val->Val, Idxs, Name);
if (auto *NewInsertValueInst = dyn_cast<llvm::InsertValueInst>(NewV))
return Ctx.createInsertValueInst(NewInsertValueInst);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

#ifndef NDEBUG
void Constant::dumpOS(raw_ostream &OS) const {
dumpCommonPrefix(OS);
Expand Down Expand Up @@ -2269,6 +2284,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
new ShuffleVectorInst(LLVMIns, *this));
return It->second.get();
}
case llvm::Instruction::InsertValue: {
auto *LLVMIns = cast<llvm::InsertValueInst>(LLVMV);
It->second =
std::unique_ptr<InsertValueInst>(new InsertValueInst(LLVMIns, *this));
return It->second.get();
}
case llvm::Instruction::Br: {
auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
Expand Down Expand Up @@ -2480,6 +2501,12 @@ Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
return cast<ShuffleVectorInst>(registerValue(std::move(NewPtr)));
}

InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) {
auto NewPtr =
std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this));
return cast<InsertValueInst>(registerValue(std::move(NewPtr)));
}

BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
return cast<BranchInst>(registerValue(std::move(NewPtr)));
Expand Down
104 changes: 104 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,110 @@ define void @foo(<2 x i8> %v1, <2 x i8> %v2) {
}
}

TEST_F(SandboxIRTest, InsertValueInst) {
parseIR(C, R"IR(
define void @foo({i32, float} %agg, i32 %i) {
%ins_simple = insertvalue {i32, float} %agg, i32 %i, 0
%ins_nested = insertvalue {float, {i32}} undef, i32 %i, 1, 0
%const1 = insertvalue {i32, float} {i32 99, float 99.0}, i32 %i, 0
%const2 = insertvalue {i32, float} {i32 0, float 99.0}, i32 %i, 0
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto *ArgAgg = F.getArg(0);
auto *ArgInt = F.getArg(1);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *InsSimple = cast<sandboxir::InsertValueInst>(&*It++);
auto *InsNested = cast<sandboxir::InsertValueInst>(&*It++);
// These "const" instructions are helpers to create constant struct operands.
// TODO: Remove them once sandboxir::ConstantStruct gets added.
auto *Const1 = cast<sandboxir::InsertValueInst>(&*It++);
auto *Const2 = cast<sandboxir::InsertValueInst>(&*It++);
auto *Ret = &*It++;

EXPECT_EQ(InsSimple->getOperand(0), ArgAgg);
EXPECT_EQ(InsSimple->getOperand(1), ArgInt);

// create before instruction
auto *NewInsBeforeRet =
cast<sandboxir::InsertValueInst>(sandboxir::InsertValueInst::create(
ArgAgg, ArgInt, ArrayRef<unsigned>({0}), Ret->getIterator(),
Ret->getParent(), Ctx, "NewInsBeforeRet"));
EXPECT_EQ(NewInsBeforeRet->getNextNode(), Ret);
#ifndef NDEBUG
EXPECT_EQ(NewInsBeforeRet->getName(), "NewInsBeforeRet");
#endif // NDEBUG

// create at end of BB
auto *NewInsAtEnd =
cast<sandboxir::InsertValueInst>(sandboxir::InsertValueInst::create(
ArgAgg, ArgInt, ArrayRef<unsigned>({0}), BB->end(), BB, Ctx,
"NewInsAtEnd"));
EXPECT_EQ(NewInsAtEnd->getPrevNode(), Ret);
#ifndef NDEBUG
EXPECT_EQ(NewInsAtEnd->getName(), "NewInsAtEnd");
#endif // NDEBUG

// Test the path that creates a folded constant.
auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx);
auto *ShouldBeConstant = sandboxir::InsertValueInst::create(
Const1->getOperand(0), Zero, ArrayRef<unsigned>({0}), BB->end(), BB, Ctx);
auto *ExpectedConstant = Const2->getOperand(0);
EXPECT_TRUE(isa<sandboxir::Constant>(ShouldBeConstant));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Constants are unique, so we could also do something EXPECT_EQ(ShouldBeConstant, ExpectedConstant)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

EXPECT_EQ(ShouldBeConstant, ExpectedConstant);

// idx_begin / idx_end
{
SmallVector<int, 2> IndicesSimple(InsSimple->idx_begin(),
InsSimple->idx_end());
EXPECT_THAT(IndicesSimple, testing::ElementsAre(0u));

SmallVector<int, 2> IndicesNested(InsNested->idx_begin(),
InsNested->idx_end());
EXPECT_THAT(IndicesNested, testing::ElementsAre(1u, 0u));
}

// indices
{
SmallVector<int, 2> IndicesSimple(InsSimple->indices());
EXPECT_THAT(IndicesSimple, testing::ElementsAre(0u));

SmallVector<int, 2> IndicesNested(InsNested->indices());
EXPECT_THAT(IndicesNested, testing::ElementsAre(1u, 0u));
}

// getAggregateOperand
EXPECT_EQ(InsSimple->getAggregateOperand(), ArgAgg);
const auto *ConstInsSimple = InsSimple;
EXPECT_EQ(ConstInsSimple->getAggregateOperand(), ArgAgg);

// getAggregateOperandIndex
EXPECT_EQ(sandboxir::InsertValueInst::getAggregateOperandIndex(),
llvm::InsertValueInst::getAggregateOperandIndex());

// getInsertedValueOperand
EXPECT_EQ(InsSimple->getInsertedValueOperand(), ArgInt);
EXPECT_EQ(ConstInsSimple->getInsertedValueOperand(), ArgInt);

// getInsertedValueOperandIndex
EXPECT_EQ(sandboxir::InsertValueInst::getInsertedValueOperandIndex(),
llvm::InsertValueInst::getInsertedValueOperandIndex());

// getIndices
EXPECT_EQ(InsSimple->getIndices().size(), 1u);
EXPECT_EQ(InsSimple->getIndices()[0], 0u);

// getNumIndices
EXPECT_EQ(InsSimple->getNumIndices(), 1u);

// hasIndices
EXPECT_EQ(InsSimple->hasIndices(), true);
}

TEST_F(SandboxIRTest, BranchInst) {
parseIR(C, R"IR(
define void @foo(i1 %cond0, i1 %cond2) {
Expand Down
Loading