Skip to content

[SandboxIR] Implement GetElementPtrInst #101078

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 129 additions & 20 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class CallBase;
class CallInst;
class InvokeInst;
class CallBrInst;
class GetElementPtrInst;

/// Iterator for the `Use` edges of a User's operands.
/// \Returns the operand `Use` when dereferenced.
Expand Down Expand Up @@ -196,18 +197,19 @@ class Value {
/// order.
llvm::Value *Val = nullptr;

friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class SelectInst; // For getting `Val`.
friend class BranchInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.
friend class CallBase; // For getting `Val`.
friend class CallInst; // For getting `Val`.
friend class InvokeInst; // For getting `Val`.
friend class CallBrInst; // For getting `Val`.
friend class Context; // For getting `Val`.
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class SelectInst; // For getting `Val`.
friend class BranchInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.
friend class CallBase; // For getting `Val`.
friend class CallInst; // For getting `Val`.
friend class InvokeInst; // For getting `Val`.
friend class CallBrInst; // For getting `Val`.
friend class GetElementPtrInst; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -540,14 +542,15 @@ class Instruction : public sandboxir::User {
/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
friend class SelectInst; // For getTopmostLLVMInstruction().
friend class BranchInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
friend class CallInst; // For getTopmostLLVMInstruction().
friend class InvokeInst; // For getTopmostLLVMInstruction().
friend class CallBrInst; // For getTopmostLLVMInstruction().
friend class SelectInst; // For getTopmostLLVMInstruction().
friend class BranchInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
friend class CallInst; // For getTopmostLLVMInstruction().
friend class InvokeInst; // For getTopmostLLVMInstruction().
friend class CallBrInst; // For getTopmostLLVMInstruction().
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().

/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
Expand Down Expand Up @@ -1175,6 +1178,110 @@ class CallBrInst final : public CallBase {
#endif
};

class GetElementPtrInst final : public Instruction {
/// Use Context::createGetElementPtrInst(). Don't call
/// the constructor directly.
GetElementPtrInst(llvm::Instruction *I, Context &Ctx)
: Instruction(ClassID::GetElementPtr, Opcode::GetElementPtr, I, Ctx) {}
GetElementPtrInst(ClassID SubclassID, llvm::Instruction *I, Context &Ctx)
: Instruction(SubclassID, Opcode::GetElementPtr, I, Ctx) {}
friend class Context; // For accessing the constructor in
// create*()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}

public:
static Value *create(Type *Ty, Value *Ptr, ArrayRef<Value *> IdxList,
BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
const Twine &NameStr = "");
static Value *create(Type *Ty, Value *Ptr, ArrayRef<Value *> IdxList,
Instruction *InsertBefore, Context &Ctx,
const Twine &NameStr = "");
static Value *create(Type *Ty, Value *Ptr, ArrayRef<Value *> IdxList,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &NameStr = "");

static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::GetElementPtr;
}
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }

Type *getSourceElementType() const {
return cast<llvm::GetElementPtrInst>(Val)->getSourceElementType();
}
Type *getResultElementType() const {
return cast<llvm::GetElementPtrInst>(Val)->getResultElementType();
}
unsigned getAddressSpace() const {
return cast<llvm::GetElementPtrInst>(Val)->getAddressSpace();
}

inline op_iterator idx_begin() { return op_begin() + 1; }
inline const_op_iterator idx_begin() const {
return const_cast<GetElementPtrInst *>(this)->idx_begin();
}
inline op_iterator idx_end() { return op_end(); }
inline const_op_iterator idx_end() const {
return const_cast<GetElementPtrInst *>(this)->idx_end();
}
inline iterator_range<op_iterator> indices() {
return make_range(idx_begin(), idx_end());
}
inline iterator_range<const_op_iterator> indices() const {
return const_cast<GetElementPtrInst *>(this)->indices();
}

Value *getPointerOperand() const;
static unsigned getPointerOperandIndex() {
return llvm::GetElementPtrInst::getPointerOperandIndex();
}
Type *getPointerOperandType() const {
return cast<llvm::GetElementPtrInst>(Val)->getPointerOperandType();
}
unsigned getPointerAddressSpace() const {
return cast<llvm::GetElementPtrInst>(Val)->getPointerAddressSpace();
}
unsigned getNumIndices() const {
return cast<llvm::GetElementPtrInst>(Val)->getNumIndices();
}
bool hasIndices() const {
return cast<llvm::GetElementPtrInst>(Val)->hasIndices();
}
bool hasAllConstantIndices() const {
return cast<llvm::GetElementPtrInst>(Val)->hasAllConstantIndices();
}
GEPNoWrapFlags getNoWrapFlags() const {
return cast<llvm::GetElementPtrInst>(Val)->getNoWrapFlags();
}
bool isInBounds() const {
return cast<llvm::GetElementPtrInst>(Val)->isInBounds();
}
bool hasNoUnsignedSignedWrap() const {
return cast<llvm::GetElementPtrInst>(Val)->hasNoUnsignedSignedWrap();
}
bool hasNoUnsignedWrap() const {
return cast<llvm::GetElementPtrInst>(Val)->hasNoUnsignedWrap();
}
bool accumulateConstantOffset(const DataLayout &DL, APInt &Offset) const {
return cast<llvm::GetElementPtrInst>(Val)->accumulateConstantOffset(DL,
Offset);
}
// TODO: Add missing member functions.

#ifndef NDEBUG
void verify() const final {}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};

/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
/// an OpaqueInstr.
class OpaqueInst : public sandboxir::Instruction {
Expand Down Expand Up @@ -1329,6 +1436,8 @@ class Context {
friend InvokeInst; // For createInvokeInst()
CallBrInst *createCallBrInst(llvm::CallBrInst *I);
friend CallBrInst; // For createCallBrInst()
GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
friend GetElementPtrInst; // For createGetElementPtrInst()

public:
Context(LLVMContext &LLVMCtx)
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ DEF_INSTR(Ret, OP(Ret), ReturnInst)
DEF_INSTR(Call, OP(Call), CallInst)
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)

#ifdef DEF_VALUE
#undef DEF_VALUE
Expand Down
67 changes: 67 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,60 @@ void CallBrInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

Value *GetElementPtrInst::create(Type *Ty, Value *Ptr,
ArrayRef<Value *> IdxList,
BasicBlock::iterator WhereIt,
BasicBlock *WhereBB, Context &Ctx,
const Twine &NameStr) {
auto &Builder = Ctx.getLLVMIRBuilder();
if (WhereIt != WhereBB->end())
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
else
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
SmallVector<llvm::Value *> LLVMIdxList;
LLVMIdxList.reserve(IdxList.size());
for (Value *Idx : IdxList)
LLVMIdxList.push_back(Idx->Val);
llvm::Value *NewV = Builder.CreateGEP(Ty, Ptr->Val, LLVMIdxList, NameStr);
if (auto *NewGEP = dyn_cast<llvm::GetElementPtrInst>(NewV))
return Ctx.createGetElementPtrInst(NewGEP);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

Value *GetElementPtrInst::create(Type *Ty, Value *Ptr,
ArrayRef<Value *> IdxList,
Instruction *InsertBefore, Context &Ctx,
const Twine &NameStr) {
return GetElementPtrInst::create(Ty, Ptr, IdxList,
InsertBefore->getIterator(),
InsertBefore->getParent(), Ctx, NameStr);
}

Value *GetElementPtrInst::create(Type *Ty, Value *Ptr,
ArrayRef<Value *> IdxList,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &NameStr) {
return GetElementPtrInst::create(Ty, Ptr, IdxList, InsertAtEnd->end(),
InsertAtEnd, Ctx, NameStr);
}

Value *GetElementPtrInst::getPointerOperand() const {
return Ctx.getValue(cast<llvm::GetElementPtrInst>(Val)->getPointerOperand());
}

#ifndef NDEBUG
void GetElementPtrInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}

void GetElementPtrInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}

void OpaqueInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
Expand Down Expand Up @@ -1165,6 +1219,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
return It->second.get();
}
case llvm::Instruction::GetElementPtr: {
auto *LLVMGEP = cast<llvm::GetElementPtrInst>(LLVMV);
It->second = std::unique_ptr<GetElementPtrInst>(
new GetElementPtrInst(LLVMGEP, *this));
return It->second.get();
}
default:
break;
}
Expand Down Expand Up @@ -1223,6 +1283,13 @@ CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) {
return cast<CallBrInst>(registerValue(std::move(NewPtr)));
}

GetElementPtrInst *
Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
auto NewPtr =
std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
}

Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
if (It != LLVMValueToValueMap.end())
Expand Down
Loading
Loading