Skip to content

[SandboxIR] Add the ExtractElementInst class #102706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 68 additions & 32 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class Context;
class Function;
class Instruction;
class SelectInst;
class ExtractElementInst;
class InsertElementInst;
class BranchInst;
class UnaryInstruction;
Expand Down Expand Up @@ -232,24 +233,25 @@ class Value {
/// order.
llvm::Value *Val = nullptr;

friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class SelectInst; // For getting `Val`.
friend class InsertElementInst; // For getting `Val`.
friend class BranchInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.
friend class CallBase; // For getting `Val`.
friend class CallInst; // For getting `Val`.
friend class InvokeInst; // For getting `Val`.
friend class CallBrInst; // For getting `Val`.
friend class GetElementPtrInst; // For getting `Val`.
friend class AllocaInst; // For getting `Val`.
friend class CastInst; // For getting `Val`.
friend class PHINode; // For getting `Val`.
friend class UnreachableInst; // For getting `Val`.
friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class SelectInst; // For getting `Val`.
friend class ExtractElementInst; // For getting `Val`.
friend class InsertElementInst; // For getting `Val`.
friend class BranchInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.
friend class CallBase; // For getting `Val`.
friend class CallInst; // For getting `Val`.
friend class InvokeInst; // For getting `Val`.
friend class CallBrInst; // For getting `Val`.
friend class GetElementPtrInst; // For getting `Val`.
friend class AllocaInst; // For getting `Val`.
friend class CastInst; // For getting `Val`.
friend class PHINode; // For getting `Val`.
friend class UnreachableInst; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -615,20 +617,21 @@ class Instruction : public sandboxir::User {
/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
friend class SelectInst; // For getTopmostLLVMInstruction().
friend class InsertElementInst; // For getTopmostLLVMInstruction().
friend class BranchInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
friend class CallInst; // For getTopmostLLVMInstruction().
friend class InvokeInst; // For getTopmostLLVMInstruction().
friend class CallBrInst; // For getTopmostLLVMInstruction().
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
friend class AllocaInst; // For getTopmostLLVMInstruction().
friend class CastInst; // For getTopmostLLVMInstruction().
friend class PHINode; // For getTopmostLLVMInstruction().
friend class UnreachableInst; // For getTopmostLLVMInstruction().
friend class SelectInst; // For getTopmostLLVMInstruction().
friend class ExtractElementInst; // For getTopmostLLVMInstruction().
friend class InsertElementInst; // For getTopmostLLVMInstruction().
friend class BranchInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
friend class CallInst; // For getTopmostLLVMInstruction().
friend class InvokeInst; // For getTopmostLLVMInstruction().
friend class CallBrInst; // For getTopmostLLVMInstruction().
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
friend class AllocaInst; // For getTopmostLLVMInstruction().
friend class CastInst; // For getTopmostLLVMInstruction().
friend class PHINode; // For getTopmostLLVMInstruction().
friend class UnreachableInst; // For getTopmostLLVMInstruction().

/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
Expand Down Expand Up @@ -768,6 +771,37 @@ class InsertElementInst final
}
};

class ExtractElementInst final
: public SingleLLVMInstructionImpl<llvm::ExtractElementInst> {
/// Use Context::createExtractElementInst() instead.
ExtractElementInst(llvm::Instruction *I, Context &Ctx)
: SingleLLVMInstructionImpl(ClassID::ExtractElement,
Opcode::ExtractElement, I, Ctx) {}
friend class Context; // For accessing the constructor in
// create*()

public:
static Value *create(Value *Vec, Value *Idx, Instruction *InsertBefore,
Context &Ctx, const Twine &Name = "");
static Value *create(Value *Vec, Value *Idx, BasicBlock *InsertAtEnd,
Context &Ctx, const Twine &Name = "");
static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::ExtractElement;
}

static bool isValidOperands(const Value *Vec, const Value *Idx) {
return llvm::ExtractElementInst::isValidOperands(Vec->Val, Idx->Val);
}
Value *getVectorOperand() { return getOperand(0); }
Value *getIndexOperand() { return getOperand(1); }
const Value *getVectorOperand() const { return getOperand(0); }
const Value *getIndexOperand() const { return getOperand(1); }

VectorType *getVectorOperandType() const {
return cast<VectorType>(getVectorOperand()->getType());
}
};

class BranchInst : public SingleLLVMInstructionImpl<llvm::BranchInst> {
/// Use Context::createBranchInst(). Don't call the constructor directly.
BranchInst(llvm::BranchInst *BI, Context &Ctx)
Expand Down Expand Up @@ -1644,6 +1678,8 @@ class Context {
friend SelectInst; // For createSelectInst()
InsertElementInst *createInsertElementInst(llvm::InsertElementInst *IEI);
friend InsertElementInst; // For createInsertElementInst()
ExtractElementInst *createExtractElementInst(llvm::ExtractElementInst *EEI);
friend ExtractElementInst; // For createExtractElementInst()
BranchInst *createBranchInst(llvm::BranchInst *I);
friend BranchInst; // For createBranchInst()
LoadInst *createLoadInst(llvm::LoadInst *LI);
Expand Down
61 changes: 31 additions & 30 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -32,36 +32,37 @@ DEF_USER(Constant, Constant)
#define OPCODES(...)
#endif
// clang-format off
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)
DEF_INSTR(Call, OP(Call), CallInst)
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
DEF_INSTR(Alloca, OP(Alloca), AllocaInst)
DEF_INSTR(Cast, OPCODES(\
OP(ZExt) \
OP(SExt) \
OP(FPToUI) \
OP(FPToSI) \
OP(FPExt) \
OP(PtrToInt) \
OP(IntToPtr) \
OP(SIToFP) \
OP(UIToFP) \
OP(Trunc) \
OP(FPTrunc) \
OP(BitCast) \
OP(AddrSpaceCast) \
), CastInst)
DEF_INSTR(PHI, OP(PHI), PHINode)
DEF_INSTR(Unreachable, OP(Unreachable), UnreachableInst)
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(ExtractElement, OP(ExtractElement), ExtractElementInst)
DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)
DEF_INSTR(Call, OP(Call), CallInst)
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
DEF_INSTR(Alloca, OP(Alloca), AllocaInst)
DEF_INSTR(Cast, OPCODES(\
OP(ZExt) \
OP(SExt) \
OP(FPToUI) \
OP(FPToSI) \
OP(FPExt) \
OP(PtrToInt) \
OP(IntToPtr) \
OP(SIToFP) \
OP(UIToFP) \
OP(Trunc) \
OP(FPTrunc) \
OP(BitCast) \
OP(AddrSpaceCast) \
), CastInst)
DEF_INSTR(PHI, OP(PHI), PHINode)
DEF_INSTR(Unreachable, OP(Unreachable), UnreachableInst)

// clang-format on
#ifdef DEF_VALUE
Expand Down
37 changes: 37 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,30 @@ Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx,
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

Value *ExtractElementInst::create(Value *Vec, Value *Idx,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
llvm::Value *NewV = Builder.CreateExtractElement(Vec->Val, Idx->Val, Name);
if (auto *NewExtract = dyn_cast<llvm::ExtractElementInst>(NewV))
return Ctx.createExtractElementInst(NewExtract);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

Value *ExtractElementInst::create(Value *Vec, Value *Idx,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
llvm::Value *NewV = Builder.CreateExtractElement(Vec->Val, Idx->Val, Name);
if (auto *NewExtract = dyn_cast<llvm::ExtractElementInst>(NewV))
return Ctx.createExtractElementInst(NewExtract);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx,
bool IsSigned) {
llvm::Constant *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned);
Expand Down Expand Up @@ -1356,6 +1380,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
return It->second.get();
}
case llvm::Instruction::ExtractElement: {
auto *LLVMIns = cast<llvm::ExtractElementInst>(LLVMV);
It->second = std::unique_ptr<ExtractElementInst>(
new ExtractElementInst(LLVMIns, *this));
return It->second.get();
}
case llvm::Instruction::InsertElement: {
auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV);
It->second = std::unique_ptr<InsertElementInst>(
Expand Down Expand Up @@ -1459,6 +1489,13 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
return cast<SelectInst>(registerValue(std::move(NewPtr)));
}

ExtractElementInst *
Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
auto NewPtr =
std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this));
return cast<ExtractElementInst>(registerValue(std::move(NewPtr)));
}

InsertElementInst *
Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
auto NewPtr =
Expand Down
44 changes: 44 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,50 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
}
}

TEST_F(SandboxIRTest, ExtractElementInst) {
parseIR(C, R"IR(
define void @foo(<2 x i8> %vec, i32 %idx) {
%ins0 = extractelement <2 x i8> %vec, i32 %idx
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto *ArgVec = F.getArg(0);
auto *ArgIdx = F.getArg(1);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *EI = cast<sandboxir::ExtractElementInst>(&*It++);
auto *Ret = &*It++;

EXPECT_EQ(EI->getOpcode(), sandboxir::Instruction::Opcode::ExtractElement);
EXPECT_EQ(EI->getOperand(0), ArgVec);
EXPECT_EQ(EI->getOperand(1), ArgIdx);
EXPECT_EQ(EI->getVectorOperand(), ArgVec);
EXPECT_EQ(EI->getIndexOperand(), ArgIdx);
EXPECT_EQ(EI->getVectorOperandType(), ArgVec->getType());

auto *NewI1 =
cast<sandboxir::ExtractElementInst>(sandboxir::ExtractElementInst::create(
ArgVec, ArgIdx, Ret, Ctx, "NewExtrBeforeRet"));
EXPECT_EQ(NewI1->getOperand(0), ArgVec);
EXPECT_EQ(NewI1->getOperand(1), ArgIdx);
EXPECT_EQ(NewI1->getNextNode(), Ret);

auto *NewI2 =
cast<sandboxir::ExtractElementInst>(sandboxir::ExtractElementInst::create(
ArgVec, ArgIdx, BB, Ctx, "NewExtrAtEndOfBB"));
EXPECT_EQ(NewI2->getPrevNode(), Ret);

auto *LLVMArgVec = LLVMF.getArg(0);
auto *LLVMArgIdx = LLVMF.getArg(1);
EXPECT_EQ(sandboxir::ExtractElementInst::isValidOperands(ArgVec, ArgIdx),
llvm::ExtractElementInst::isValidOperands(LLVMArgVec, LLVMArgIdx));
EXPECT_EQ(sandboxir::ExtractElementInst::isValidOperands(ArgIdx, ArgVec),
llvm::ExtractElementInst::isValidOperands(LLVMArgIdx, LLVMArgVec));
}

TEST_F(SandboxIRTest, InsertElementInst) {
parseIR(C, R"IR(
define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {
Expand Down
Loading