Skip to content

Commit ec29660

Browse files
authored
[SandboxIR] Implement UnaryOperator (llvm#104509)
This patch implements sandboxir::UnaryOperator mirroring llvm::UnaryOperator.
1 parent 3e1d4ec commit ec29660

File tree

4 files changed

+251
-1
lines changed

4 files changed

+251
-1
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ class CastInst;
130130
class PtrToIntInst;
131131
class BitCastInst;
132132
class AllocaInst;
133+
class UnaryOperator;
133134
class BinaryOperator;
134135
class AtomicCmpXchgInst;
135136

@@ -250,6 +251,7 @@ class Value {
250251
friend class InvokeInst; // For getting `Val`.
251252
friend class CallBrInst; // For getting `Val`.
252253
friend class GetElementPtrInst; // For getting `Val`.
254+
friend class UnaryOperator; // For getting `Val`.
253255
friend class BinaryOperator; // For getting `Val`.
254256
friend class AtomicCmpXchgInst; // For getting `Val`.
255257
friend class AllocaInst; // For getting `Val`.
@@ -632,6 +634,7 @@ class Instruction : public sandboxir::User {
632634
friend class InvokeInst; // For getTopmostLLVMInstruction().
633635
friend class CallBrInst; // For getTopmostLLVMInstruction().
634636
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
637+
friend class UnaryOperator; // For getTopmostLLVMInstruction().
635638
friend class BinaryOperator; // For getTopmostLLVMInstruction().
636639
friend class AtomicCmpXchgInst; // For getTopmostLLVMInstruction().
637640
friend class AllocaInst; // For getTopmostLLVMInstruction().
@@ -1435,6 +1438,47 @@ class GetElementPtrInst final
14351438
// TODO: Add missing member functions.
14361439
};
14371440

1441+
class UnaryOperator : public UnaryInstruction {
1442+
static Opcode getUnaryOpcode(llvm::Instruction::UnaryOps UnOp) {
1443+
switch (UnOp) {
1444+
case llvm::Instruction::FNeg:
1445+
return Opcode::FNeg;
1446+
case llvm::Instruction::UnaryOpsEnd:
1447+
llvm_unreachable("Bad UnOp!");
1448+
}
1449+
llvm_unreachable("Unhandled UnOp!");
1450+
}
1451+
UnaryOperator(llvm::UnaryOperator *UO, Context &Ctx)
1452+
: UnaryInstruction(ClassID::UnOp, getUnaryOpcode(UO->getOpcode()), UO,
1453+
Ctx) {}
1454+
friend Context; // for constructor.
1455+
public:
1456+
static Value *create(Instruction::Opcode Op, Value *OpV, BBIterator WhereIt,
1457+
BasicBlock *WhereBB, Context &Ctx,
1458+
const Twine &Name = "");
1459+
static Value *create(Instruction::Opcode Op, Value *OpV,
1460+
Instruction *InsertBefore, Context &Ctx,
1461+
const Twine &Name = "");
1462+
static Value *create(Instruction::Opcode Op, Value *OpV,
1463+
BasicBlock *InsertAtEnd, Context &Ctx,
1464+
const Twine &Name = "");
1465+
static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
1466+
Value *CopyFrom, BBIterator WhereIt,
1467+
BasicBlock *WhereBB, Context &Ctx,
1468+
const Twine &Name = "");
1469+
static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
1470+
Value *CopyFrom,
1471+
Instruction *InsertBefore, Context &Ctx,
1472+
const Twine &Name = "");
1473+
static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
1474+
Value *CopyFrom, BasicBlock *InsertAtEnd,
1475+
Context &Ctx, const Twine &Name = "");
1476+
/// For isa/dyn_cast.
1477+
static bool classof(const Value *From) {
1478+
return From->getSubclassID() == ClassID::UnOp;
1479+
}
1480+
};
1481+
14381482
class BinaryOperator : public SingleLLVMInstructionImpl<llvm::BinaryOperator> {
14391483
static Opcode getBinOpOpcode(llvm::Instruction::BinaryOps BinOp) {
14401484
switch (BinOp) {
@@ -1959,6 +2003,8 @@ class Context {
19592003
friend CallBrInst; // For createCallBrInst()
19602004
GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
19612005
friend GetElementPtrInst; // For createGetElementPtrInst()
2006+
UnaryOperator *createUnaryOperator(llvm::UnaryOperator *I);
2007+
friend UnaryOperator; // For createUnaryOperator()
19622008
BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
19632009
friend BinaryOperator; // For createBinaryOperator()
19642010
AtomicCmpXchgInst *createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I);

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ DEF_INSTR(Call, OP(Call), CallInst)
4545
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
4646
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
4747
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
48-
DEF_INSTR(BinaryOperator, OPCODES( \
48+
DEF_INSTR(UnOp, OPCODES( \
49+
OP(FNeg) \
50+
), UnaryOperator)
51+
DEF_INSTR(BinaryOperator, OPCODES(\
4952
OP(Add) \
5053
OP(FAdd) \
5154
OP(Sub) \

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,71 @@ static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
12191219
}
12201220
}
12211221

1222+
/// \Returns the LLVM opcode that corresponds to \p Opc.
1223+
static llvm::Instruction::UnaryOps getLLVMUnaryOp(Instruction::Opcode Opc) {
1224+
switch (Opc) {
1225+
case Instruction::Opcode::FNeg:
1226+
return static_cast<llvm::Instruction::UnaryOps>(llvm::Instruction::FNeg);
1227+
default:
1228+
llvm_unreachable("Not a unary op!");
1229+
}
1230+
}
1231+
1232+
Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
1233+
BBIterator WhereIt, BasicBlock *WhereBB,
1234+
Context &Ctx, const Twine &Name) {
1235+
auto &Builder = Ctx.getLLVMIRBuilder();
1236+
if (WhereIt == WhereBB->end())
1237+
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
1238+
else
1239+
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
1240+
auto *NewLLVMV = Builder.CreateUnOp(getLLVMUnaryOp(Op), OpV->Val, Name);
1241+
if (auto *NewUnOpV = dyn_cast<llvm::UnaryOperator>(NewLLVMV)) {
1242+
return Ctx.createUnaryOperator(NewUnOpV);
1243+
}
1244+
assert(isa<llvm::Constant>(NewLLVMV) && "Expected constant");
1245+
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewLLVMV));
1246+
}
1247+
1248+
Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
1249+
Instruction *InsertBefore, Context &Ctx,
1250+
const Twine &Name) {
1251+
return create(Op, OpV, InsertBefore->getIterator(), InsertBefore->getParent(),
1252+
Ctx, Name);
1253+
}
1254+
1255+
Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
1256+
BasicBlock *InsertAfter, Context &Ctx,
1257+
const Twine &Name) {
1258+
return create(Op, OpV, InsertAfter->end(), InsertAfter, Ctx, Name);
1259+
}
1260+
1261+
Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
1262+
Value *CopyFrom, BBIterator WhereIt,
1263+
BasicBlock *WhereBB, Context &Ctx,
1264+
const Twine &Name) {
1265+
auto *NewV = create(Op, OpV, WhereIt, WhereBB, Ctx, Name);
1266+
if (auto *UnI = dyn_cast<llvm::UnaryOperator>(NewV->Val))
1267+
UnI->copyIRFlags(CopyFrom->Val);
1268+
return NewV;
1269+
}
1270+
1271+
Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
1272+
Value *CopyFrom,
1273+
Instruction *InsertBefore,
1274+
Context &Ctx, const Twine &Name) {
1275+
return createWithCopiedFlags(Op, OpV, CopyFrom, InsertBefore->getIterator(),
1276+
InsertBefore->getParent(), Ctx, Name);
1277+
}
1278+
1279+
Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
1280+
Value *CopyFrom,
1281+
BasicBlock *InsertAtEnd,
1282+
Context &Ctx, const Twine &Name) {
1283+
return createWithCopiedFlags(Op, OpV, CopyFrom, InsertAtEnd->end(),
1284+
InsertAtEnd, Ctx, Name);
1285+
}
1286+
12221287
/// \Returns the LLVM opcode that corresponds to \p Opc.
12231288
static llvm::Instruction::BinaryOps getLLVMBinaryOp(Instruction::Opcode Opc) {
12241289
switch (Opc) {
@@ -1729,6 +1794,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
17291794
new GetElementPtrInst(LLVMGEP, *this));
17301795
return It->second.get();
17311796
}
1797+
case llvm::Instruction::FNeg: {
1798+
auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
1799+
It->second = std::unique_ptr<UnaryOperator>(
1800+
new UnaryOperator(LLVMUnaryOperator, *this));
1801+
return It->second.get();
1802+
}
17321803
case llvm::Instruction::Add:
17331804
case llvm::Instruction::FAdd:
17341805
case llvm::Instruction::Sub:
@@ -1875,6 +1946,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
18751946
std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
18761947
return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
18771948
}
1949+
UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
1950+
auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
1951+
return cast<UnaryOperator>(registerValue(std::move(NewPtr)));
1952+
}
18781953
BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
18791954
auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
18801955
return cast<BinaryOperator>(registerValue(std::move(NewPtr)));

llvm/unittests/SandboxIR/SandboxIRTest.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,6 +1620,132 @@ define void @foo(i32 %arg, float %farg) {
16201620
EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
16211621
}
16221622

1623+
TEST_F(SandboxIRTest, UnaryOperator) {
1624+
parseIR(C, R"IR(
1625+
define void @foo(float %arg0) {
1626+
%fneg = fneg float %arg0
1627+
%copyfrom = fadd reassoc float %arg0, 42.0
1628+
ret void
1629+
}
1630+
)IR");
1631+
Function &LLVMF = *M->getFunction("foo");
1632+
sandboxir::Context Ctx(C);
1633+
1634+
auto &F = *Ctx.createFunction(&LLVMF);
1635+
auto *Arg0 = F.getArg(0);
1636+
auto *BB = &*F.begin();
1637+
auto It = BB->begin();
1638+
auto *I = cast<sandboxir::UnaryOperator>(&*It++);
1639+
auto *CopyFrom = cast<sandboxir::BinaryOperator>(&*It++);
1640+
auto *Ret = &*It++;
1641+
EXPECT_EQ(I->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
1642+
EXPECT_EQ(I->getOperand(0), Arg0);
1643+
1644+
{
1645+
// Check create() WhereIt, WhereBB.
1646+
auto *NewI =
1647+
cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
1648+
sandboxir::Instruction::Opcode::FNeg, Arg0,
1649+
/*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
1650+
"New1"));
1651+
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
1652+
EXPECT_EQ(NewI->getOperand(0), Arg0);
1653+
#ifndef NDEBUG
1654+
EXPECT_EQ(NewI->getName(), "New1");
1655+
#endif // NDEBUG
1656+
EXPECT_EQ(NewI->getNextNode(), Ret);
1657+
}
1658+
{
1659+
// Check create() InsertBefore.
1660+
auto *NewI =
1661+
cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
1662+
sandboxir::Instruction::Opcode::FNeg, Arg0,
1663+
/*InsertBefore=*/Ret, Ctx, "New2"));
1664+
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
1665+
EXPECT_EQ(NewI->getOperand(0), Arg0);
1666+
#ifndef NDEBUG
1667+
EXPECT_EQ(NewI->getName(), "New2");
1668+
#endif // NDEBUG
1669+
EXPECT_EQ(NewI->getNextNode(), Ret);
1670+
}
1671+
{
1672+
// Check create() InsertAtEnd.
1673+
auto *NewI =
1674+
cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
1675+
sandboxir::Instruction::Opcode::FNeg, Arg0,
1676+
/*InsertAtEnd=*/BB, Ctx, "New3"));
1677+
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
1678+
EXPECT_EQ(NewI->getOperand(0), Arg0);
1679+
#ifndef NDEBUG
1680+
EXPECT_EQ(NewI->getName(), "New3");
1681+
#endif // NDEBUG
1682+
EXPECT_EQ(NewI->getParent(), BB);
1683+
EXPECT_EQ(NewI->getNextNode(), nullptr);
1684+
}
1685+
{
1686+
// Check create() when it gets folded.
1687+
auto *FortyTwo = CopyFrom->getOperand(1);
1688+
auto *NewV = sandboxir::UnaryOperator::create(
1689+
sandboxir::Instruction::Opcode::FNeg, FortyTwo,
1690+
/*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
1691+
"Folded");
1692+
EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
1693+
}
1694+
1695+
{
1696+
// Check createWithCopiedFlags() WhereIt, WhereBB.
1697+
auto *NewI = cast<sandboxir::UnaryOperator>(
1698+
sandboxir::UnaryOperator::createWithCopiedFlags(
1699+
sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
1700+
/*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
1701+
"NewCopyFrom1"));
1702+
EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
1703+
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
1704+
EXPECT_EQ(NewI->getOperand(0), Arg0);
1705+
#ifndef NDEBUG
1706+
EXPECT_EQ(NewI->getName(), "NewCopyFrom1");
1707+
#endif // NDEBUG
1708+
EXPECT_EQ(NewI->getNextNode(), Ret);
1709+
}
1710+
{
1711+
// Check createWithCopiedFlags() InsertBefore,
1712+
auto *NewI = cast<sandboxir::UnaryOperator>(
1713+
sandboxir::UnaryOperator::createWithCopiedFlags(
1714+
sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
1715+
/*InsertBefore=*/Ret, Ctx, "NewCopyFrom2"));
1716+
EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
1717+
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
1718+
EXPECT_EQ(NewI->getOperand(0), Arg0);
1719+
#ifndef NDEBUG
1720+
EXPECT_EQ(NewI->getName(), "NewCopyFrom2");
1721+
#endif // NDEBUG
1722+
EXPECT_EQ(NewI->getNextNode(), Ret);
1723+
}
1724+
{
1725+
// Check createWithCopiedFlags() InsertAtEnd,
1726+
auto *NewI = cast<sandboxir::UnaryOperator>(
1727+
sandboxir::UnaryOperator::createWithCopiedFlags(
1728+
sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
1729+
/*InsertAtEnd=*/BB, Ctx, "NewCopyFrom3"));
1730+
EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
1731+
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
1732+
EXPECT_EQ(NewI->getOperand(0), Arg0);
1733+
#ifndef NDEBUG
1734+
EXPECT_EQ(NewI->getName(), "NewCopyFrom3");
1735+
#endif // NDEBUG
1736+
EXPECT_EQ(NewI->getParent(), BB);
1737+
EXPECT_EQ(NewI->getNextNode(), nullptr);
1738+
}
1739+
{
1740+
// Check createWithCopiedFlags() when it gets folded.
1741+
auto *FortyTwo = CopyFrom->getOperand(1);
1742+
auto *NewV = sandboxir::UnaryOperator::createWithCopiedFlags(
1743+
sandboxir::Instruction::Opcode::FNeg, FortyTwo, CopyFrom,
1744+
/*InsertAtEnd=*/BB, Ctx, "Folded");
1745+
EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
1746+
}
1747+
}
1748+
16231749
TEST_F(SandboxIRTest, BinaryOperator) {
16241750
parseIR(C, R"IR(
16251751
define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {

0 commit comments

Comments
 (0)