Skip to content

[SandboxIR] Implement CastInst #101097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class CallInst;
class InvokeInst;
class CallBrInst;
class GetElementPtrInst;
class CastInst;

/// Iterator for the `Use` edges of a User's operands.
/// \Returns the operand `Use` when dereferenced.
Expand Down Expand Up @@ -210,6 +211,7 @@ class Value {
friend class InvokeInst; // For getting `Val`.
friend class CallBrInst; // For getting `Val`.
friend class GetElementPtrInst; // For getting `Val`.
friend class CastInst; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -525,9 +527,8 @@ class BBIterator {
class Instruction : public sandboxir::User {
public:
enum class Opcode {
#define DEF_VALUE(ID, CLASS)
#define DEF_USER(ID, CLASS)
#define OP(OPC) OPC,
#define OPCODES(...) __VA_ARGS__
#define DEF_INSTR(ID, OPC, CLASS) OPC
#include "llvm/SandboxIR/SandboxIRValues.def"
};
Expand All @@ -551,6 +552,7 @@ class Instruction : public sandboxir::User {
friend class InvokeInst; // For getTopmostLLVMInstruction().
friend class CallBrInst; // For getTopmostLLVMInstruction().
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
friend class CastInst; // For getTopmostLLVMInstruction().

/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
Expand Down Expand Up @@ -1282,6 +1284,79 @@ class GetElementPtrInst final : public Instruction {
#endif
};

class CastInst : public Instruction {
static Opcode getCastOpcode(llvm::Instruction::CastOps CastOp) {
switch (CastOp) {
case llvm::Instruction::ZExt:
return Opcode::ZExt;
case llvm::Instruction::SExt:
return Opcode::SExt;
case llvm::Instruction::FPToUI:
return Opcode::FPToUI;
case llvm::Instruction::FPToSI:
return Opcode::FPToSI;
case llvm::Instruction::FPExt:
return Opcode::FPExt;
case llvm::Instruction::PtrToInt:
return Opcode::PtrToInt;
case llvm::Instruction::IntToPtr:
return Opcode::IntToPtr;
case llvm::Instruction::SIToFP:
return Opcode::SIToFP;
case llvm::Instruction::UIToFP:
return Opcode::UIToFP;
case llvm::Instruction::Trunc:
return Opcode::Trunc;
case llvm::Instruction::FPTrunc:
return Opcode::FPTrunc;
case llvm::Instruction::BitCast:
return Opcode::BitCast;
case llvm::Instruction::AddrSpaceCast:
return Opcode::AddrSpaceCast;
case llvm::Instruction::CastOpsEnd:
llvm_unreachable("Bad CastOp!");
}
llvm_unreachable("Unhandled CastOp!");
}
/// Use Context::createCastInst(). Don't call the
/// constructor directly.
CastInst(llvm::CastInst *CI, Context &Ctx)
: Instruction(ClassID::Cast, getCastOpcode(CI->getOpcode()), CI, Ctx) {}
friend Context; // for SBCastInstruction()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}

public:
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
static Value *create(Type *DestTy, Opcode Op, Value *Operand,
BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
const Twine &Name = "");
static Value *create(Type *DestTy, Opcode Op, Value *Operand,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name = "");
static Value *create(Type *DestTy, Opcode Op, Value *Operand,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name = "");
/// For isa/dyn_cast.
static bool classof(const Value *From);
Type *getSrcTy() const { return cast<llvm::CastInst>(Val)->getSrcTy(); }
Type *getDestTy() const { return cast<llvm::CastInst>(Val)->getDestTy(); }
#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::CastInst>(Val) && "Expected CastInst!");
}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};

/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
/// an OpaqueInstr.
class OpaqueInst : public sandboxir::Instruction {
Expand Down Expand Up @@ -1438,6 +1513,8 @@ class Context {
friend CallBrInst; // For createCallBrInst()
GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
friend GetElementPtrInst; // For createGetElementPtrInst()
CastInst *createCastInst(llvm::CastInst *I);
friend CastInst; // For createCastInst()

public:
Context(LLVMContext &LLVMCtx)
Expand Down
49 changes: 38 additions & 11 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,42 @@ DEF_USER(Constant, Constant)
#ifndef DEF_INSTR
#define DEF_INSTR(ID, OPCODE, CLASS)
#endif
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)
DEF_INSTR(Call, OP(Call), CallInst)
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)

#ifndef OP
#define OP(OPCODE)
#endif

#ifndef OPCODES
#define OPCODES(...)
#endif
// clang-format off
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)
DEF_INSTR(Call, OP(Call), CallInst)
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
DEF_INSTR(Cast, OPCODES(\
OP(ZExt) \
OP(SExt) \
OP(FPToUI) \
OP(FPToSI) \
OP(FPExt) \
OP(PtrToInt) \
OP(IntToPtr) \
OP(SIToFP) \
OP(UIToFP) \
OP(Trunc) \
OP(FPTrunc) \
OP(BitCast) \
OP(AddrSpaceCast) \
), CastInst)
// clang-format on
#ifdef DEF_VALUE
#undef DEF_VALUE
#endif
Expand All @@ -47,3 +71,6 @@ DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
#ifdef OP
#undef OP
#endif
#ifdef OPCODES
#undef OPCODES
#endif
106 changes: 104 additions & 2 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,10 @@ BBIterator &BBIterator::operator--() {

const char *Instruction::getOpcodeName(Opcode Opc) {
switch (Opc) {
#define DEF_VALUE(ID, CLASS)
#define DEF_USER(ID, CLASS)
#define OP(OPC) \
case Opcode::OPC: \
return #OPC;
#define OPCODES(...) __VA_ARGS__
#define DEF_INSTR(ID, OPC, CLASS) OPC
#include "llvm/SandboxIR/SandboxIRValues.def"
}
Expand Down Expand Up @@ -1050,6 +1049,87 @@ void GetElementPtrInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
switch (Opc) {
case Instruction::Opcode::ZExt:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::ZExt);
case Instruction::Opcode::SExt:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::SExt);
case Instruction::Opcode::FPToUI:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::FPToUI);
case Instruction::Opcode::FPToSI:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::FPToSI);
case Instruction::Opcode::FPExt:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::FPExt);
case Instruction::Opcode::PtrToInt:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::PtrToInt);
case Instruction::Opcode::IntToPtr:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::IntToPtr);
case Instruction::Opcode::SIToFP:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::SIToFP);
case Instruction::Opcode::UIToFP:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::UIToFP);
case Instruction::Opcode::Trunc:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::Trunc);
case Instruction::Opcode::FPTrunc:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::FPTrunc);
case Instruction::Opcode::BitCast:
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::BitCast);
case Instruction::Opcode::AddrSpaceCast:
return static_cast<llvm::Instruction::CastOps>(
llvm::Instruction::AddrSpaceCast);
default:
llvm_unreachable("Opcode not suitable for CastInst!");
}
}

Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand,
BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
const Twine &Name) {
assert(getLLVMCastOp(Op) && "Opcode not suitable for CastInst!");
auto &Builder = Ctx.getLLVMIRBuilder();
if (WhereIt == WhereBB->end())
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
else
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
auto *NewV =
Builder.CreateCast(getLLVMCastOp(Op), Operand->Val, DestTy, Name);
if (auto *NewCI = dyn_cast<llvm::CastInst>(NewV))
return Ctx.createCastInst(NewCI);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
return create(DestTy, Op, Operand, InsertBefore->getIterator(),
InsertBefore->getParent(), Ctx, Name);
}

Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name) {
return create(DestTy, Op, Operand, InsertAtEnd->end(), InsertAtEnd, Ctx,
Name);
}

bool CastInst::classof(const Value *From) {
return From->getSubclassID() == ClassID::Cast;
}

#ifndef NDEBUG
void CastInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}

void CastInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}

void OpaqueInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
Expand Down Expand Up @@ -1225,6 +1305,23 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
new GetElementPtrInst(LLVMGEP, *this));
return It->second.get();
}
case llvm::Instruction::ZExt:
case llvm::Instruction::SExt:
case llvm::Instruction::FPToUI:
case llvm::Instruction::FPToSI:
case llvm::Instruction::FPExt:
case llvm::Instruction::PtrToInt:
case llvm::Instruction::IntToPtr:
case llvm::Instruction::SIToFP:
case llvm::Instruction::UIToFP:
case llvm::Instruction::Trunc:
case llvm::Instruction::FPTrunc:
case llvm::Instruction::BitCast:
case llvm::Instruction::AddrSpaceCast: {
auto *LLVMCast = cast<llvm::CastInst>(LLVMV);
It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
return It->second.get();
}
default:
break;
}
Expand Down Expand Up @@ -1290,6 +1387,11 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
}

CastInst *Context::createCastInst(llvm::CastInst *I) {
auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
return cast<CastInst>(registerValue(std::move(NewPtr)));
}

Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
if (It != LLVMValueToValueMap.end())
Expand Down
Loading
Loading