Skip to content

[SandboxIR] Implement SandboxIR Type #106294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 39 additions & 38 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/SandboxIR/Tracker.h"
#include "llvm/SandboxIR/Type.h"
#include "llvm/SandboxIR/Use.h"
#include "llvm/Support/raw_ostream.h"
#include <iterator>
Expand Down Expand Up @@ -386,7 +387,7 @@ class Value {
return Cnt == Num;
}

Type *getType() const { return Val->getType(); }
Type *getType() const;

Context &getContext() const { return Ctx; }

Expand Down Expand Up @@ -574,8 +575,7 @@ class ConstantInt : public Constant {
public:
/// If Ty is a vector type, return a Constant with a splat of the given
/// value. Otherwise return a ConstantInt for the given value.
static ConstantInt *get(Type *Ty, uint64_t V, Context &Ctx,
bool IsSigned = false);
static ConstantInt *get(Type *Ty, uint64_t V, bool IsSigned = false);

// TODO: Implement missing functions.

Expand Down Expand Up @@ -1024,10 +1024,7 @@ class ExtractElementInst final
Value *getIndexOperand() { return getOperand(1); }
const Value *getVectorOperand() const { return getOperand(0); }
const Value *getIndexOperand() const { return getOperand(1); }

VectorType *getVectorOperandType() const {
return cast<VectorType>(getVectorOperand()->getType());
}
VectorType *getVectorOperandType() const;
};

class ShuffleVectorInst final
Expand Down Expand Up @@ -1072,9 +1069,7 @@ class ShuffleVectorInst final
}

/// Overload to return most specific vector type.
VectorType *getType() const {
return cast<llvm::ShuffleVectorInst>(Val)->getType();
}
VectorType *getType() const;

/// Return the shuffle mask value of this instruction for the given element
/// index. Return PoisonMaskElem if the element is undef.
Expand All @@ -1100,7 +1095,7 @@ class ShuffleVectorInst final
Constant *getShuffleMaskForBitcode() const;

static Constant *convertShuffleMaskForBitcode(ArrayRef<int> Mask,
Type *ResultTy, Context &Ctx);
Type *ResultTy);

void setShuffleMask(ArrayRef<int> Mask);

Expand Down Expand Up @@ -1646,9 +1641,7 @@ class ExtractValueInst : public UnaryInstruction {
/// with an extractvalue instruction with the specified parameters.
///
/// Null is returned if the indices are invalid for the specified type.
static Type *getIndexedType(Type *Agg, ArrayRef<unsigned> Idxs) {
return llvm::ExtractValueInst::getIndexedType(Agg, Idxs);
}
static Type *getIndexedType(Type *Agg, ArrayRef<unsigned> Idxs);

using idx_iterator = llvm::ExtractValueInst::idx_iterator;

Expand Down Expand Up @@ -1843,9 +1836,7 @@ class CallBase : public SingleLLVMInstructionImpl<llvm::CallBase> {
Opc == Instruction::ClassID::CallBr;
}

FunctionType *getFunctionType() const {
return cast<llvm::CallBase>(Val)->getFunctionType();
}
FunctionType *getFunctionType() const;

op_iterator data_operands_begin() { return op_begin(); }
const_op_iterator data_operands_begin() const {
Expand Down Expand Up @@ -2261,12 +2252,8 @@ class GetElementPtrInst final
return From->getSubclassID() == ClassID::GetElementPtr;
}

Type *getSourceElementType() const {
return cast<llvm::GetElementPtrInst>(Val)->getSourceElementType();
}
Type *getResultElementType() const {
return cast<llvm::GetElementPtrInst>(Val)->getResultElementType();
}
Type *getSourceElementType() const;
Type *getResultElementType() const;
unsigned getAddressSpace() const {
return cast<llvm::GetElementPtrInst>(Val)->getAddressSpace();
}
Expand All @@ -2290,9 +2277,7 @@ class GetElementPtrInst final
static unsigned getPointerOperandIndex() {
return llvm::GetElementPtrInst::getPointerOperandIndex();
}
Type *getPointerOperandType() const {
return cast<llvm::GetElementPtrInst>(Val)->getPointerOperandType();
}
Type *getPointerOperandType() const;
unsigned getPointerAddressSpace() const {
return cast<llvm::GetElementPtrInst>(Val)->getPointerAddressSpace();
}
Expand Down Expand Up @@ -2843,9 +2828,7 @@ class AllocaInst final : public UnaryInstruction {
return const_cast<AllocaInst *>(this)->getArraySize();
}
/// Overload to return most specific pointer type.
PointerType *getType() const {
return cast<llvm::AllocaInst>(Val)->getType();
}
PointerType *getType() const;
/// Return the address space for the allocation.
unsigned getAddressSpace() const {
return cast<llvm::AllocaInst>(Val)->getAddressSpace();
Expand All @@ -2861,9 +2844,7 @@ class AllocaInst final : public UnaryInstruction {
return cast<llvm::AllocaInst>(Val)->getAllocationSizeInBits(DL);
}
/// Return the type that is being allocated by the instruction.
Type *getAllocatedType() const {
return cast<llvm::AllocaInst>(Val)->getAllocatedType();
}
Type *getAllocatedType() const;
/// for use only in special circumstances that need to generically
/// transform a whole instruction (eg: IR linking and vectorization).
void setAllocatedType(Type *Ty);
Expand Down Expand Up @@ -2945,8 +2926,8 @@ class CastInst : public UnaryInstruction {
const Twine &Name = "");
/// For isa/dyn_cast.
static bool classof(const Value *From);
Type *getSrcTy() const { return cast<llvm::CastInst>(Val)->getSrcTy(); }
Type *getDestTy() const { return cast<llvm::CastInst>(Val)->getDestTy(); }
Type *getSrcTy() const;
Type *getDestTy() const;
};

/// Instruction that can have a nneg flag (zext/uitofp).
Expand Down Expand Up @@ -3126,13 +3107,25 @@ class OpaqueInst : public SingleLLVMInstructionImpl<llvm::Instruction> {
class Context {
protected:
LLVMContext &LLVMCtx;
friend class Type; // For LLVMCtx.
friend class PointerType; // For LLVMCtx.
Tracker IRTracker;

/// Maps LLVM Value to the corresponding sandboxir::Value. Owns all
/// SandboxIR objects.
DenseMap<llvm::Value *, std::unique_ptr<sandboxir::Value>>
LLVMValueToValueMap;

/// Type has a protected destructor to prohibit the user from managing the
/// lifetime of the Type objects. Context is friend of Type, and this custom
/// deleter can destroy Type.
struct TypeDeleter {
void operator()(Type *Ty) { delete Ty; }
};
/// Maps LLVM Type to the corresonding sandboxir::Type. Owns all Sandbox IR
/// Type objects.
DenseMap<llvm::Type *, std::unique_ptr<Type, TypeDeleter>> LLVMTypeToTypeMap;

/// Remove \p V from the maps and returns the unique_ptr.
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
/// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively
Expand Down Expand Up @@ -3167,7 +3160,6 @@ class Context {
/// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will
/// also create all contents of the block.
BasicBlock *createBasicBlock(llvm::BasicBlock *BB);

friend class BasicBlock; // For getOrCreateValue().

IRBuilder<ConstantFolder> LLVMIRBuilder;
Expand Down Expand Up @@ -3257,6 +3249,17 @@ class Context {
const sandboxir::Value *getValue(const llvm::Value *V) const {
return getValue(const_cast<llvm::Value *>(V));
}

Type *getType(llvm::Type *LLVMTy) {
if (LLVMTy == nullptr)
return nullptr;
auto Pair = LLVMTypeToTypeMap.insert({LLVMTy, nullptr});
auto It = Pair.first;
if (Pair.second)
It->second = std::unique_ptr<Type, TypeDeleter>(new Type(LLVMTy, *this));
return It->second.get();
}

/// Create a sandboxir::Function for an existing LLVM IR \p F, including all
/// blocks and instructions.
/// This is the main API function for creating Sandbox IR.
Expand Down Expand Up @@ -3303,9 +3306,7 @@ class Function : public Constant {
LLVMBBToBB BBGetter(Ctx);
return iterator(cast<llvm::Function>(Val)->end(), BBGetter);
}
FunctionType *getFunctionType() const {
return cast<llvm::Function>(Val)->getFunctionType();
}
FunctionType *getFunctionType() const;

#ifndef NDEBUG
void verify() const final {
Expand Down
Loading
Loading