Skip to content

Commit eb03279

Browse files
authored
[SandboxIR] Implement CastInst (#101097)
This patch implements sandboxir::CastInst which mirrors llvm::CastInst. Just like in llvm::CastInst there are multiple opcodes that correspond to a CastInst, like ZExt, FPToUI etc. These are implemented in follow-up patches.
1 parent 6992ebc commit eb03279

File tree

4 files changed

+382
-15
lines changed

4 files changed

+382
-15
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class CallInst;
9494
class InvokeInst;
9595
class CallBrInst;
9696
class GetElementPtrInst;
97+
class CastInst;
9798

9899
/// Iterator for the `Use` edges of a User's operands.
99100
/// \Returns the operand `Use` when dereferenced.
@@ -210,6 +211,7 @@ class Value {
210211
friend class InvokeInst; // For getting `Val`.
211212
friend class CallBrInst; // For getting `Val`.
212213
friend class GetElementPtrInst; // For getting `Val`.
214+
friend class CastInst; // For getting `Val`.
213215

214216
/// All values point to the context.
215217
Context &Ctx;
@@ -525,9 +527,8 @@ class BBIterator {
525527
class Instruction : public sandboxir::User {
526528
public:
527529
enum class Opcode {
528-
#define DEF_VALUE(ID, CLASS)
529-
#define DEF_USER(ID, CLASS)
530530
#define OP(OPC) OPC,
531+
#define OPCODES(...) __VA_ARGS__
531532
#define DEF_INSTR(ID, OPC, CLASS) OPC
532533
#include "llvm/SandboxIR/SandboxIRValues.def"
533534
};
@@ -551,6 +552,7 @@ class Instruction : public sandboxir::User {
551552
friend class InvokeInst; // For getTopmostLLVMInstruction().
552553
friend class CallBrInst; // For getTopmostLLVMInstruction().
553554
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
555+
friend class CastInst; // For getTopmostLLVMInstruction().
554556

555557
/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
556558
/// order.
@@ -1290,6 +1292,79 @@ class GetElementPtrInst final : public Instruction {
12901292
#endif
12911293
};
12921294

1295+
class CastInst : public Instruction {
1296+
static Opcode getCastOpcode(llvm::Instruction::CastOps CastOp) {
1297+
switch (CastOp) {
1298+
case llvm::Instruction::ZExt:
1299+
return Opcode::ZExt;
1300+
case llvm::Instruction::SExt:
1301+
return Opcode::SExt;
1302+
case llvm::Instruction::FPToUI:
1303+
return Opcode::FPToUI;
1304+
case llvm::Instruction::FPToSI:
1305+
return Opcode::FPToSI;
1306+
case llvm::Instruction::FPExt:
1307+
return Opcode::FPExt;
1308+
case llvm::Instruction::PtrToInt:
1309+
return Opcode::PtrToInt;
1310+
case llvm::Instruction::IntToPtr:
1311+
return Opcode::IntToPtr;
1312+
case llvm::Instruction::SIToFP:
1313+
return Opcode::SIToFP;
1314+
case llvm::Instruction::UIToFP:
1315+
return Opcode::UIToFP;
1316+
case llvm::Instruction::Trunc:
1317+
return Opcode::Trunc;
1318+
case llvm::Instruction::FPTrunc:
1319+
return Opcode::FPTrunc;
1320+
case llvm::Instruction::BitCast:
1321+
return Opcode::BitCast;
1322+
case llvm::Instruction::AddrSpaceCast:
1323+
return Opcode::AddrSpaceCast;
1324+
case llvm::Instruction::CastOpsEnd:
1325+
llvm_unreachable("Bad CastOp!");
1326+
}
1327+
llvm_unreachable("Unhandled CastOp!");
1328+
}
1329+
/// Use Context::createCastInst(). Don't call the
1330+
/// constructor directly.
1331+
CastInst(llvm::CastInst *CI, Context &Ctx)
1332+
: Instruction(ClassID::Cast, getCastOpcode(CI->getOpcode()), CI, Ctx) {}
1333+
friend Context; // for SBCastInstruction()
1334+
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
1335+
return getOperandUseDefault(OpIdx, Verify);
1336+
}
1337+
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
1338+
return {cast<llvm::Instruction>(Val)};
1339+
}
1340+
1341+
public:
1342+
unsigned getUseOperandNo(const Use &Use) const final {
1343+
return getUseOperandNoDefault(Use);
1344+
}
1345+
unsigned getNumOfIRInstrs() const final { return 1u; }
1346+
static Value *create(Type *DestTy, Opcode Op, Value *Operand,
1347+
BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
1348+
const Twine &Name = "");
1349+
static Value *create(Type *DestTy, Opcode Op, Value *Operand,
1350+
Instruction *InsertBefore, Context &Ctx,
1351+
const Twine &Name = "");
1352+
static Value *create(Type *DestTy, Opcode Op, Value *Operand,
1353+
BasicBlock *InsertAtEnd, Context &Ctx,
1354+
const Twine &Name = "");
1355+
/// For isa/dyn_cast.
1356+
static bool classof(const Value *From);
1357+
Type *getSrcTy() const { return cast<llvm::CastInst>(Val)->getSrcTy(); }
1358+
Type *getDestTy() const { return cast<llvm::CastInst>(Val)->getDestTy(); }
1359+
#ifndef NDEBUG
1360+
void verify() const final {
1361+
assert(isa<llvm::CastInst>(Val) && "Expected CastInst!");
1362+
}
1363+
void dump(raw_ostream &OS) const override;
1364+
LLVM_DUMP_METHOD void dump() const override;
1365+
#endif
1366+
};
1367+
12931368
/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
12941369
/// an OpaqueInstr.
12951370
class OpaqueInst : public sandboxir::Instruction {
@@ -1446,6 +1521,8 @@ class Context {
14461521
friend CallBrInst; // For createCallBrInst()
14471522
GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
14481523
friend GetElementPtrInst; // For createGetElementPtrInst()
1524+
CastInst *createCastInst(llvm::CastInst *I);
1525+
friend CastInst; // For createCastInst()
14491526

14501527
public:
14511528
Context(LLVMContext &LLVMCtx)

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,42 @@ DEF_USER(Constant, Constant)
2323
#ifndef DEF_INSTR
2424
#define DEF_INSTR(ID, OPCODE, CLASS)
2525
#endif
26-
// ClassID, Opcode(s), Class
27-
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
28-
DEF_INSTR(Select, OP(Select), SelectInst)
29-
DEF_INSTR(Br, OP(Br), BranchInst)
30-
DEF_INSTR(Load, OP(Load), LoadInst)
31-
DEF_INSTR(Store, OP(Store), StoreInst)
32-
DEF_INSTR(Ret, OP(Ret), ReturnInst)
33-
DEF_INSTR(Call, OP(Call), CallInst)
34-
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
35-
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
36-
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
3726

27+
#ifndef OP
28+
#define OP(OPCODE)
29+
#endif
30+
31+
#ifndef OPCODES
32+
#define OPCODES(...)
33+
#endif
34+
// clang-format off
35+
// ClassID, Opcode(s), Class
36+
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
37+
DEF_INSTR(Select, OP(Select), SelectInst)
38+
DEF_INSTR(Br, OP(Br), BranchInst)
39+
DEF_INSTR(Load, OP(Load), LoadInst)
40+
DEF_INSTR(Store, OP(Store), StoreInst)
41+
DEF_INSTR(Ret, OP(Ret), ReturnInst)
42+
DEF_INSTR(Call, OP(Call), CallInst)
43+
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
44+
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
45+
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
46+
DEF_INSTR(Cast, OPCODES(\
47+
OP(ZExt) \
48+
OP(SExt) \
49+
OP(FPToUI) \
50+
OP(FPToSI) \
51+
OP(FPExt) \
52+
OP(PtrToInt) \
53+
OP(IntToPtr) \
54+
OP(SIToFP) \
55+
OP(UIToFP) \
56+
OP(Trunc) \
57+
OP(FPTrunc) \
58+
OP(BitCast) \
59+
OP(AddrSpaceCast) \
60+
), CastInst)
61+
// clang-format on
3862
#ifdef DEF_VALUE
3963
#undef DEF_VALUE
4064
#endif
@@ -47,3 +71,6 @@ DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
4771
#ifdef OP
4872
#undef OP
4973
#endif
74+
#ifdef OPCODES
75+
#undef OPCODES
76+
#endif

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,11 +313,10 @@ BBIterator &BBIterator::operator--() {
313313

314314
const char *Instruction::getOpcodeName(Opcode Opc) {
315315
switch (Opc) {
316-
#define DEF_VALUE(ID, CLASS)
317-
#define DEF_USER(ID, CLASS)
318316
#define OP(OPC) \
319317
case Opcode::OPC: \
320318
return #OPC;
319+
#define OPCODES(...) __VA_ARGS__
321320
#define DEF_INSTR(ID, OPC, CLASS) OPC
322321
#include "llvm/SandboxIR/SandboxIRValues.def"
323322
}
@@ -1061,6 +1060,87 @@ void GetElementPtrInst::dump() const {
10611060
dump(dbgs());
10621061
dbgs() << "\n";
10631062
}
1063+
#endif // NDEBUG
1064+
1065+
static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
1066+
switch (Opc) {
1067+
case Instruction::Opcode::ZExt:
1068+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::ZExt);
1069+
case Instruction::Opcode::SExt:
1070+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::SExt);
1071+
case Instruction::Opcode::FPToUI:
1072+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::FPToUI);
1073+
case Instruction::Opcode::FPToSI:
1074+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::FPToSI);
1075+
case Instruction::Opcode::FPExt:
1076+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::FPExt);
1077+
case Instruction::Opcode::PtrToInt:
1078+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::PtrToInt);
1079+
case Instruction::Opcode::IntToPtr:
1080+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::IntToPtr);
1081+
case Instruction::Opcode::SIToFP:
1082+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::SIToFP);
1083+
case Instruction::Opcode::UIToFP:
1084+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::UIToFP);
1085+
case Instruction::Opcode::Trunc:
1086+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::Trunc);
1087+
case Instruction::Opcode::FPTrunc:
1088+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::FPTrunc);
1089+
case Instruction::Opcode::BitCast:
1090+
return static_cast<llvm::Instruction::CastOps>(llvm::Instruction::BitCast);
1091+
case Instruction::Opcode::AddrSpaceCast:
1092+
return static_cast<llvm::Instruction::CastOps>(
1093+
llvm::Instruction::AddrSpaceCast);
1094+
default:
1095+
llvm_unreachable("Opcode not suitable for CastInst!");
1096+
}
1097+
}
1098+
1099+
Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand,
1100+
BBIterator WhereIt, BasicBlock *WhereBB, Context &Ctx,
1101+
const Twine &Name) {
1102+
assert(getLLVMCastOp(Op) && "Opcode not suitable for CastInst!");
1103+
auto &Builder = Ctx.getLLVMIRBuilder();
1104+
if (WhereIt == WhereBB->end())
1105+
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
1106+
else
1107+
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
1108+
auto *NewV =
1109+
Builder.CreateCast(getLLVMCastOp(Op), Operand->Val, DestTy, Name);
1110+
if (auto *NewCI = dyn_cast<llvm::CastInst>(NewV))
1111+
return Ctx.createCastInst(NewCI);
1112+
assert(isa<llvm::Constant>(NewV) && "Expected constant");
1113+
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
1114+
}
1115+
1116+
Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand,
1117+
Instruction *InsertBefore, Context &Ctx,
1118+
const Twine &Name) {
1119+
return create(DestTy, Op, Operand, InsertBefore->getIterator(),
1120+
InsertBefore->getParent(), Ctx, Name);
1121+
}
1122+
1123+
Value *CastInst::create(Type *DestTy, Opcode Op, Value *Operand,
1124+
BasicBlock *InsertAtEnd, Context &Ctx,
1125+
const Twine &Name) {
1126+
return create(DestTy, Op, Operand, InsertAtEnd->end(), InsertAtEnd, Ctx,
1127+
Name);
1128+
}
1129+
1130+
bool CastInst::classof(const Value *From) {
1131+
return From->getSubclassID() == ClassID::Cast;
1132+
}
1133+
1134+
#ifndef NDEBUG
1135+
void CastInst::dump(raw_ostream &OS) const {
1136+
dumpCommonPrefix(OS);
1137+
dumpCommonSuffix(OS);
1138+
}
1139+
1140+
void CastInst::dump() const {
1141+
dump(dbgs());
1142+
dbgs() << "\n";
1143+
}
10641144

10651145
void OpaqueInst::dump(raw_ostream &OS) const {
10661146
dumpCommonPrefix(OS);
@@ -1236,6 +1316,23 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
12361316
new GetElementPtrInst(LLVMGEP, *this));
12371317
return It->second.get();
12381318
}
1319+
case llvm::Instruction::ZExt:
1320+
case llvm::Instruction::SExt:
1321+
case llvm::Instruction::FPToUI:
1322+
case llvm::Instruction::FPToSI:
1323+
case llvm::Instruction::FPExt:
1324+
case llvm::Instruction::PtrToInt:
1325+
case llvm::Instruction::IntToPtr:
1326+
case llvm::Instruction::SIToFP:
1327+
case llvm::Instruction::UIToFP:
1328+
case llvm::Instruction::Trunc:
1329+
case llvm::Instruction::FPTrunc:
1330+
case llvm::Instruction::BitCast:
1331+
case llvm::Instruction::AddrSpaceCast: {
1332+
auto *LLVMCast = cast<llvm::CastInst>(LLVMV);
1333+
It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
1334+
return It->second.get();
1335+
}
12391336
default:
12401337
break;
12411338
}
@@ -1301,6 +1398,11 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
13011398
return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
13021399
}
13031400

1401+
CastInst *Context::createCastInst(llvm::CastInst *I) {
1402+
auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
1403+
return cast<CastInst>(registerValue(std::move(NewPtr)));
1404+
}
1405+
13041406
Value *Context::getValue(llvm::Value *V) const {
13051407
auto It = LLVMValueToValueMap.find(V);
13061408
if (It != LLVMValueToValueMap.end())

0 commit comments

Comments
 (0)