Skip to content

[SandboxIR] Add ExtractValueInst. #106613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
// |
// +- ShuffleVectorInst
// |
// +- ExtractValueInst
// |
// +- InsertValueInst
// |
// +- StoreInst
Expand Down Expand Up @@ -120,6 +122,7 @@ class SelectInst;
class ExtractElementInst;
class InsertElementInst;
class ShuffleVectorInst;
class ExtractValueInst;
class InsertValueInst;
class BranchInst;
class UnaryInstruction;
Expand Down Expand Up @@ -270,6 +273,7 @@ class Value {
friend class ExtractElementInst; // For getting `Val`.
friend class InsertElementInst; // For getting `Val`.
friend class ShuffleVectorInst; // For getting `Val`.
friend class ExtractValueInst; // For getting `Val`.
friend class InsertValueInst; // For getting `Val`.
friend class BranchInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
Expand Down Expand Up @@ -710,6 +714,7 @@ class Instruction : public sandboxir::User {
friend class ExtractElementInst; // For getTopmostLLVMInstruction().
friend class InsertElementInst; // For getTopmostLLVMInstruction().
friend class ShuffleVectorInst; // For getTopmostLLVMInstruction().
friend class ExtractValueInst; // For getTopmostLLVMInstruction().
friend class InsertValueInst; // For getTopmostLLVMInstruction().
friend class BranchInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
Expand Down Expand Up @@ -1621,6 +1626,65 @@ class UnaryInstruction
}
};

class ExtractValueInst : public UnaryInstruction {
/// Use Context::createExtractValueInst() instead.
ExtractValueInst(llvm::ExtractValueInst *EVI, Context &Ctx)
: UnaryInstruction(ClassID::ExtractValue, Opcode::ExtractValue, EVI,
Ctx) {}
friend Context; // for ExtractValueInst()

public:
static Value *create(Value *Agg, ArrayRef<unsigned> Idxs, BBIterator WhereIt,
BasicBlock *WhereBB, Context &Ctx,
const Twine &Name = "");

static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::ExtractValue;
}

/// Returns the type of the element that would be extracted
/// with an extractvalue instruction with the specified parameters.
///
/// Null is returned if the indices are invalid for the specified type.
static Type *getIndexedType(Type *Agg, ArrayRef<unsigned> Idxs) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It is not straight forward what this function does, so I would copy LLVM IR's comments.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return llvm::ExtractValueInst::getIndexedType(Agg, Idxs);
}

using idx_iterator = llvm::ExtractValueInst::idx_iterator;

inline idx_iterator idx_begin() const {
return cast<llvm::ExtractValueInst>(Val)->idx_begin();
}
inline idx_iterator idx_end() const {
return cast<llvm::ExtractValueInst>(Val)->idx_end();
}
inline iterator_range<idx_iterator> indices() const {
return cast<llvm::ExtractValueInst>(Val)->indices();
}

Value *getAggregateOperand() {
return getOperand(getAggregateOperandIndex());
}
const Value *getAggregateOperand() const {
return getOperand(getAggregateOperandIndex());
}
static unsigned getAggregateOperandIndex() {
return llvm::ExtractValueInst::getAggregateOperandIndex();
}

ArrayRef<unsigned> getIndices() const {
return cast<llvm::ExtractValueInst>(Val)->getIndices();
}

unsigned getNumIndices() const {
return cast<llvm::ExtractValueInst>(Val)->getNumIndices();
}

unsigned hasIndices() const {
return cast<llvm::ExtractValueInst>(Val)->hasIndices();
}
};

class VAArgInst : public UnaryInstruction {
VAArgInst(llvm::VAArgInst *FI, Context &Ctx)
: UnaryInstruction(ClassID::VAArg, Opcode::VAArg, FI, Ctx) {}
Expand Down Expand Up @@ -3123,6 +3187,8 @@ class Context {
friend ExtractElementInst; // For createExtractElementInst()
ShuffleVectorInst *createShuffleVectorInst(llvm::ShuffleVectorInst *SVI);
friend ShuffleVectorInst; // For createShuffleVectorInst()
ExtractValueInst *createExtractValueInst(llvm::ExtractValueInst *IVI);
friend ExtractValueInst; // For createExtractValueInst()
InsertValueInst *createInsertValueInst(llvm::InsertValueInst *IVI);
friend InsertValueInst; // For createInsertValueInst()
BranchInst *createBranchInst(llvm::BranchInst *I);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ DEF_INSTR(VAArg, OP(VAArg), VAArgInst)
DEF_INSTR(Freeze, OP(Freeze), FreezeInst)
DEF_INSTR(Fence, OP(Fence), FenceInst)
DEF_INSTR(ShuffleVector, OP(ShuffleVector), ShuffleVectorInst)
DEF_INSTR(ExtractValue, OP(ExtractValue), ExtractValueInst)
DEF_INSTR(InsertValue, OP(InsertValue), InsertValueInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Br, OP(Br), BranchInst)
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2145,6 +2145,21 @@ Constant *ShuffleVectorInst::convertShuffleMaskForBitcode(
llvm::ShuffleVectorInst::convertShuffleMaskForBitcode(Mask, ResultTy));
}

Value *ExtractValueInst::create(Value *Agg, ArrayRef<unsigned> Idxs,
BBIterator WhereIt, BasicBlock *WhereBB,
Context &Ctx, const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
if (WhereIt != WhereBB->end())
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
else
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
llvm::Value *NewV = Builder.CreateExtractValue(Agg->Val, Idxs, Name);
if (auto *NewExtractValueInst = dyn_cast<llvm::ExtractValueInst>(NewV))
return Ctx.createExtractValueInst(NewExtractValueInst);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

Value *InsertValueInst::create(Value *Agg, Value *Val, ArrayRef<unsigned> Idxs,
BBIterator WhereIt, BasicBlock *WhereBB,
Context &Ctx, const Twine &Name) {
Expand Down Expand Up @@ -2320,6 +2335,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
new ShuffleVectorInst(LLVMIns, *this));
return It->second.get();
}
case llvm::Instruction::ExtractValue: {
auto *LLVMIns = cast<llvm::ExtractValueInst>(LLVMV);
It->second =
std::unique_ptr<ExtractValueInst>(new ExtractValueInst(LLVMIns, *this));
return It->second.get();
}
case llvm::Instruction::InsertValue: {
auto *LLVMIns = cast<llvm::InsertValueInst>(LLVMV);
It->second =
Expand Down Expand Up @@ -2548,6 +2569,12 @@ Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
return cast<ShuffleVectorInst>(registerValue(std::move(NewPtr)));
}

ExtractValueInst *Context::createExtractValueInst(llvm::ExtractValueInst *EVI) {
auto NewPtr =
std::unique_ptr<ExtractValueInst>(new ExtractValueInst(EVI, *this));
return cast<ExtractValueInst>(registerValue(std::move(NewPtr)));
}

InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) {
auto NewPtr =
std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this));
Expand Down
100 changes: 100 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,106 @@ define void @foo(<2 x i8> %v1, <2 x i8> %v2) {
}
}

TEST_F(SandboxIRTest, ExtractValueInst) {
parseIR(C, R"IR(
define void @foo({i32, float} %agg) {
%ext_simple = extractvalue {i32, float} %agg, 0
%ext_nested = extractvalue {float, {i32}} undef, 1, 0
%const1 = extractvalue {i32, float} {i32 0, float 99.0}, 0
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto *ArgAgg = F.getArg(0);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *ExtSimple = cast<sandboxir::ExtractValueInst>(&*It++);
auto *ExtNested = cast<sandboxir::ExtractValueInst>(&*It++);
auto *Const1 = cast<sandboxir::ExtractValueInst>(&*It++);
auto *Ret = &*It++;

EXPECT_EQ(ExtSimple->getOperand(0), ArgAgg);

// create before instruction
auto *NewExtBeforeRet =
cast<sandboxir::ExtractValueInst>(sandboxir::ExtractValueInst::create(
ArgAgg, ArrayRef<unsigned>({0}), Ret->getIterator(), Ret->getParent(),
Ctx, "NewExtBeforeRet"));
EXPECT_EQ(NewExtBeforeRet->getNextNode(), Ret);
#ifndef NDEBUG
EXPECT_EQ(NewExtBeforeRet->getName(), "NewExtBeforeRet");
#endif // NDEBUG

// create at end of BB
auto *NewExtAtEnd =
cast<sandboxir::ExtractValueInst>(sandboxir::ExtractValueInst::create(
ArgAgg, ArrayRef<unsigned>({0}), BB->end(), BB, Ctx, "NewExtAtEnd"));
EXPECT_EQ(NewExtAtEnd->getPrevNode(), Ret);
#ifndef NDEBUG
EXPECT_EQ(NewExtAtEnd->getName(), "NewExtAtEnd");
#endif // NDEBUG

// Test the path that creates a folded constant.
auto *ShouldBeConstant = sandboxir::ExtractValueInst::create(
Const1->getOperand(0), ArrayRef<unsigned>({0}), BB->end(), BB, Ctx);
EXPECT_TRUE(isa<sandboxir::Constant>(ShouldBeConstant));

auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx);
EXPECT_EQ(ShouldBeConstant, Zero);

// getIndexedType
Type *AggType = ExtNested->getAggregateOperand()->getType();
EXPECT_EQ(sandboxir::ExtractValueInst::getIndexedType(
AggType, ArrayRef<unsigned>({1, 0})),
llvm::ExtractValueInst::getIndexedType(AggType,
ArrayRef<unsigned>({1, 0})));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you have two checks for getIndexedType(), one that returns non-null and one that returns null. The reason is that once we migrate to sandboxir::Type the null case would need special treatment.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a simple test for the nullptr case. Please take another look.


EXPECT_EQ(sandboxir::ExtractValueInst::getIndexedType(
AggType, ArrayRef<unsigned>({2})),
nullptr);

// idx_begin / idx_end
{
SmallVector<int, 2> IndicesSimple(ExtSimple->idx_begin(),
ExtSimple->idx_end());
EXPECT_THAT(IndicesSimple, testing::ElementsAre(0u));

SmallVector<int, 2> IndicesNested(ExtNested->idx_begin(),
ExtNested->idx_end());
EXPECT_THAT(IndicesNested, testing::ElementsAre(1u, 0u));
}

// indices
{
SmallVector<int, 2> IndicesSimple(ExtSimple->indices());
EXPECT_THAT(IndicesSimple, testing::ElementsAre(0u));

SmallVector<int, 2> IndicesNested(ExtNested->indices());
EXPECT_THAT(IndicesNested, testing::ElementsAre(1u, 0u));
}

// getAggregateOperand
EXPECT_EQ(ExtSimple->getAggregateOperand(), ArgAgg);
const auto *ConstExtSimple = ExtSimple;
EXPECT_EQ(ConstExtSimple->getAggregateOperand(), ArgAgg);

// getAggregateOperandIndex
EXPECT_EQ(sandboxir::ExtractValueInst::getAggregateOperandIndex(),
llvm::ExtractValueInst::getAggregateOperandIndex());

// getIndices
EXPECT_EQ(ExtSimple->getIndices().size(), 1u);
EXPECT_EQ(ExtSimple->getIndices()[0], 0u);

// getNumIndices
EXPECT_EQ(ExtSimple->getNumIndices(), 1u);

// hasIndices
EXPECT_EQ(ExtSimple->hasIndices(), true);
}

TEST_F(SandboxIRTest, InsertValueInst) {
parseIR(C, R"IR(
define void @foo({i32, float} %agg, i32 %i) {
Expand Down
Loading