Skip to content

Commit 15aa4ef

Browse files
authored
[SandboxIR] Add the ExtractElementInst class (#102706)
1 parent 290f7ea commit 15aa4ef

File tree

4 files changed

+180
-62
lines changed

4 files changed

+180
-62
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 68 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class Context;
111111
class Function;
112112
class Instruction;
113113
class SelectInst;
114+
class ExtractElementInst;
114115
class InsertElementInst;
115116
class BranchInst;
116117
class UnaryInstruction;
@@ -232,24 +233,25 @@ class Value {
232233
/// order.
233234
llvm::Value *Val = nullptr;
234235

235-
friend class Context; // For getting `Val`.
236-
friend class User; // For getting `Val`.
237-
friend class Use; // For getting `Val`.
238-
friend class SelectInst; // For getting `Val`.
239-
friend class InsertElementInst; // For getting `Val`.
240-
friend class BranchInst; // For getting `Val`.
241-
friend class LoadInst; // For getting `Val`.
242-
friend class StoreInst; // For getting `Val`.
243-
friend class ReturnInst; // For getting `Val`.
244-
friend class CallBase; // For getting `Val`.
245-
friend class CallInst; // For getting `Val`.
246-
friend class InvokeInst; // For getting `Val`.
247-
friend class CallBrInst; // For getting `Val`.
248-
friend class GetElementPtrInst; // For getting `Val`.
249-
friend class AllocaInst; // For getting `Val`.
250-
friend class CastInst; // For getting `Val`.
251-
friend class PHINode; // For getting `Val`.
252-
friend class UnreachableInst; // For getting `Val`.
236+
friend class Context; // For getting `Val`.
237+
friend class User; // For getting `Val`.
238+
friend class Use; // For getting `Val`.
239+
friend class SelectInst; // For getting `Val`.
240+
friend class ExtractElementInst; // For getting `Val`.
241+
friend class InsertElementInst; // For getting `Val`.
242+
friend class BranchInst; // For getting `Val`.
243+
friend class LoadInst; // For getting `Val`.
244+
friend class StoreInst; // For getting `Val`.
245+
friend class ReturnInst; // For getting `Val`.
246+
friend class CallBase; // For getting `Val`.
247+
friend class CallInst; // For getting `Val`.
248+
friend class InvokeInst; // For getting `Val`.
249+
friend class CallBrInst; // For getting `Val`.
250+
friend class GetElementPtrInst; // For getting `Val`.
251+
friend class AllocaInst; // For getting `Val`.
252+
friend class CastInst; // For getting `Val`.
253+
friend class PHINode; // For getting `Val`.
254+
friend class UnreachableInst; // For getting `Val`.
253255

254256
/// All values point to the context.
255257
Context &Ctx;
@@ -615,20 +617,21 @@ class Instruction : public sandboxir::User {
615617
/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
616618
/// returns its topmost LLVM IR instruction.
617619
llvm::Instruction *getTopmostLLVMInstruction() const;
618-
friend class SelectInst; // For getTopmostLLVMInstruction().
619-
friend class InsertElementInst; // For getTopmostLLVMInstruction().
620-
friend class BranchInst; // For getTopmostLLVMInstruction().
621-
friend class LoadInst; // For getTopmostLLVMInstruction().
622-
friend class StoreInst; // For getTopmostLLVMInstruction().
623-
friend class ReturnInst; // For getTopmostLLVMInstruction().
624-
friend class CallInst; // For getTopmostLLVMInstruction().
625-
friend class InvokeInst; // For getTopmostLLVMInstruction().
626-
friend class CallBrInst; // For getTopmostLLVMInstruction().
627-
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
628-
friend class AllocaInst; // For getTopmostLLVMInstruction().
629-
friend class CastInst; // For getTopmostLLVMInstruction().
630-
friend class PHINode; // For getTopmostLLVMInstruction().
631-
friend class UnreachableInst; // For getTopmostLLVMInstruction().
620+
friend class SelectInst; // For getTopmostLLVMInstruction().
621+
friend class ExtractElementInst; // For getTopmostLLVMInstruction().
622+
friend class InsertElementInst; // For getTopmostLLVMInstruction().
623+
friend class BranchInst; // For getTopmostLLVMInstruction().
624+
friend class LoadInst; // For getTopmostLLVMInstruction().
625+
friend class StoreInst; // For getTopmostLLVMInstruction().
626+
friend class ReturnInst; // For getTopmostLLVMInstruction().
627+
friend class CallInst; // For getTopmostLLVMInstruction().
628+
friend class InvokeInst; // For getTopmostLLVMInstruction().
629+
friend class CallBrInst; // For getTopmostLLVMInstruction().
630+
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
631+
friend class AllocaInst; // For getTopmostLLVMInstruction().
632+
friend class CastInst; // For getTopmostLLVMInstruction().
633+
friend class PHINode; // For getTopmostLLVMInstruction().
634+
friend class UnreachableInst; // For getTopmostLLVMInstruction().
632635

633636
/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
634637
/// order.
@@ -768,6 +771,37 @@ class InsertElementInst final
768771
}
769772
};
770773

774+
class ExtractElementInst final
775+
: public SingleLLVMInstructionImpl<llvm::ExtractElementInst> {
776+
/// Use Context::createExtractElementInst() instead.
777+
ExtractElementInst(llvm::Instruction *I, Context &Ctx)
778+
: SingleLLVMInstructionImpl(ClassID::ExtractElement,
779+
Opcode::ExtractElement, I, Ctx) {}
780+
friend class Context; // For accessing the constructor in
781+
// create*()
782+
783+
public:
784+
static Value *create(Value *Vec, Value *Idx, Instruction *InsertBefore,
785+
Context &Ctx, const Twine &Name = "");
786+
static Value *create(Value *Vec, Value *Idx, BasicBlock *InsertAtEnd,
787+
Context &Ctx, const Twine &Name = "");
788+
static bool classof(const Value *From) {
789+
return From->getSubclassID() == ClassID::ExtractElement;
790+
}
791+
792+
static bool isValidOperands(const Value *Vec, const Value *Idx) {
793+
return llvm::ExtractElementInst::isValidOperands(Vec->Val, Idx->Val);
794+
}
795+
Value *getVectorOperand() { return getOperand(0); }
796+
Value *getIndexOperand() { return getOperand(1); }
797+
const Value *getVectorOperand() const { return getOperand(0); }
798+
const Value *getIndexOperand() const { return getOperand(1); }
799+
800+
VectorType *getVectorOperandType() const {
801+
return cast<VectorType>(getVectorOperand()->getType());
802+
}
803+
};
804+
771805
class BranchInst : public SingleLLVMInstructionImpl<llvm::BranchInst> {
772806
/// Use Context::createBranchInst(). Don't call the constructor directly.
773807
BranchInst(llvm::BranchInst *BI, Context &Ctx)
@@ -1644,6 +1678,8 @@ class Context {
16441678
friend SelectInst; // For createSelectInst()
16451679
InsertElementInst *createInsertElementInst(llvm::InsertElementInst *IEI);
16461680
friend InsertElementInst; // For createInsertElementInst()
1681+
ExtractElementInst *createExtractElementInst(llvm::ExtractElementInst *EEI);
1682+
friend ExtractElementInst; // For createExtractElementInst()
16471683
BranchInst *createBranchInst(llvm::BranchInst *I);
16481684
friend BranchInst; // For createBranchInst()
16491685
LoadInst *createLoadInst(llvm::LoadInst *LI);

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,36 +32,37 @@ DEF_USER(Constant, Constant)
3232
#define OPCODES(...)
3333
#endif
3434
// clang-format off
35-
// ClassID, Opcode(s), Class
36-
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
37-
DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
38-
DEF_INSTR(Select, OP(Select), SelectInst)
39-
DEF_INSTR(Br, OP(Br), BranchInst)
40-
DEF_INSTR(Load, OP(Load), LoadInst)
41-
DEF_INSTR(Store, OP(Store), StoreInst)
42-
DEF_INSTR(Ret, OP(Ret), ReturnInst)
43-
DEF_INSTR(Call, OP(Call), CallInst)
44-
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
45-
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
46-
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
47-
DEF_INSTR(Alloca, OP(Alloca), AllocaInst)
48-
DEF_INSTR(Cast, OPCODES(\
49-
OP(ZExt) \
50-
OP(SExt) \
51-
OP(FPToUI) \
52-
OP(FPToSI) \
53-
OP(FPExt) \
54-
OP(PtrToInt) \
55-
OP(IntToPtr) \
56-
OP(SIToFP) \
57-
OP(UIToFP) \
58-
OP(Trunc) \
59-
OP(FPTrunc) \
60-
OP(BitCast) \
61-
OP(AddrSpaceCast) \
62-
), CastInst)
63-
DEF_INSTR(PHI, OP(PHI), PHINode)
64-
DEF_INSTR(Unreachable, OP(Unreachable), UnreachableInst)
35+
// ClassID, Opcode(s), Class
36+
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
37+
DEF_INSTR(ExtractElement, OP(ExtractElement), ExtractElementInst)
38+
DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
39+
DEF_INSTR(Select, OP(Select), SelectInst)
40+
DEF_INSTR(Br, OP(Br), BranchInst)
41+
DEF_INSTR(Load, OP(Load), LoadInst)
42+
DEF_INSTR(Store, OP(Store), StoreInst)
43+
DEF_INSTR(Ret, OP(Ret), ReturnInst)
44+
DEF_INSTR(Call, OP(Call), CallInst)
45+
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
46+
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
47+
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
48+
DEF_INSTR(Alloca, OP(Alloca), AllocaInst)
49+
DEF_INSTR(Cast, OPCODES(\
50+
OP(ZExt) \
51+
OP(SExt) \
52+
OP(FPToUI) \
53+
OP(FPToSI) \
54+
OP(FPExt) \
55+
OP(PtrToInt) \
56+
OP(IntToPtr) \
57+
OP(SIToFP) \
58+
OP(UIToFP) \
59+
OP(Trunc) \
60+
OP(FPTrunc) \
61+
OP(BitCast) \
62+
OP(AddrSpaceCast) \
63+
), CastInst)
64+
DEF_INSTR(PHI, OP(PHI), PHINode)
65+
DEF_INSTR(Unreachable, OP(Unreachable), UnreachableInst)
6566

6667
// clang-format on
6768
#ifdef DEF_VALUE

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,30 @@ Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx,
12351235
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
12361236
}
12371237

1238+
Value *ExtractElementInst::create(Value *Vec, Value *Idx,
1239+
Instruction *InsertBefore, Context &Ctx,
1240+
const Twine &Name) {
1241+
auto &Builder = Ctx.getLLVMIRBuilder();
1242+
Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
1243+
llvm::Value *NewV = Builder.CreateExtractElement(Vec->Val, Idx->Val, Name);
1244+
if (auto *NewExtract = dyn_cast<llvm::ExtractElementInst>(NewV))
1245+
return Ctx.createExtractElementInst(NewExtract);
1246+
assert(isa<llvm::Constant>(NewV) && "Expected constant");
1247+
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
1248+
}
1249+
1250+
Value *ExtractElementInst::create(Value *Vec, Value *Idx,
1251+
BasicBlock *InsertAtEnd, Context &Ctx,
1252+
const Twine &Name) {
1253+
auto &Builder = Ctx.getLLVMIRBuilder();
1254+
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
1255+
llvm::Value *NewV = Builder.CreateExtractElement(Vec->Val, Idx->Val, Name);
1256+
if (auto *NewExtract = dyn_cast<llvm::ExtractElementInst>(NewV))
1257+
return Ctx.createExtractElementInst(NewExtract);
1258+
assert(isa<llvm::Constant>(NewV) && "Expected constant");
1259+
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
1260+
}
1261+
12381262
Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx,
12391263
bool IsSigned) {
12401264
llvm::Constant *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned);
@@ -1356,6 +1380,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
13561380
It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
13571381
return It->second.get();
13581382
}
1383+
case llvm::Instruction::ExtractElement: {
1384+
auto *LLVMIns = cast<llvm::ExtractElementInst>(LLVMV);
1385+
It->second = std::unique_ptr<ExtractElementInst>(
1386+
new ExtractElementInst(LLVMIns, *this));
1387+
return It->second.get();
1388+
}
13591389
case llvm::Instruction::InsertElement: {
13601390
auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV);
13611391
It->second = std::unique_ptr<InsertElementInst>(
@@ -1459,6 +1489,13 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
14591489
return cast<SelectInst>(registerValue(std::move(NewPtr)));
14601490
}
14611491

1492+
ExtractElementInst *
1493+
Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
1494+
auto NewPtr =
1495+
std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this));
1496+
return cast<ExtractElementInst>(registerValue(std::move(NewPtr)));
1497+
}
1498+
14621499
InsertElementInst *
14631500
Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
14641501
auto NewPtr =

llvm/unittests/SandboxIR/SandboxIRTest.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,50 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
631631
}
632632
}
633633

634+
TEST_F(SandboxIRTest, ExtractElementInst) {
635+
parseIR(C, R"IR(
636+
define void @foo(<2 x i8> %vec, i32 %idx) {
637+
%ins0 = extractelement <2 x i8> %vec, i32 %idx
638+
ret void
639+
}
640+
)IR");
641+
Function &LLVMF = *M->getFunction("foo");
642+
sandboxir::Context Ctx(C);
643+
auto &F = *Ctx.createFunction(&LLVMF);
644+
auto *ArgVec = F.getArg(0);
645+
auto *ArgIdx = F.getArg(1);
646+
auto *BB = &*F.begin();
647+
auto It = BB->begin();
648+
auto *EI = cast<sandboxir::ExtractElementInst>(&*It++);
649+
auto *Ret = &*It++;
650+
651+
EXPECT_EQ(EI->getOpcode(), sandboxir::Instruction::Opcode::ExtractElement);
652+
EXPECT_EQ(EI->getOperand(0), ArgVec);
653+
EXPECT_EQ(EI->getOperand(1), ArgIdx);
654+
EXPECT_EQ(EI->getVectorOperand(), ArgVec);
655+
EXPECT_EQ(EI->getIndexOperand(), ArgIdx);
656+
EXPECT_EQ(EI->getVectorOperandType(), ArgVec->getType());
657+
658+
auto *NewI1 =
659+
cast<sandboxir::ExtractElementInst>(sandboxir::ExtractElementInst::create(
660+
ArgVec, ArgIdx, Ret, Ctx, "NewExtrBeforeRet"));
661+
EXPECT_EQ(NewI1->getOperand(0), ArgVec);
662+
EXPECT_EQ(NewI1->getOperand(1), ArgIdx);
663+
EXPECT_EQ(NewI1->getNextNode(), Ret);
664+
665+
auto *NewI2 =
666+
cast<sandboxir::ExtractElementInst>(sandboxir::ExtractElementInst::create(
667+
ArgVec, ArgIdx, BB, Ctx, "NewExtrAtEndOfBB"));
668+
EXPECT_EQ(NewI2->getPrevNode(), Ret);
669+
670+
auto *LLVMArgVec = LLVMF.getArg(0);
671+
auto *LLVMArgIdx = LLVMF.getArg(1);
672+
EXPECT_EQ(sandboxir::ExtractElementInst::isValidOperands(ArgVec, ArgIdx),
673+
llvm::ExtractElementInst::isValidOperands(LLVMArgVec, LLVMArgIdx));
674+
EXPECT_EQ(sandboxir::ExtractElementInst::isValidOperands(ArgIdx, ArgVec),
675+
llvm::ExtractElementInst::isValidOperands(LLVMArgIdx, LLVMArgVec));
676+
}
677+
634678
TEST_F(SandboxIRTest, InsertElementInst) {
635679
parseIR(C, R"IR(
636680
define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {

0 commit comments

Comments
 (0)