Skip to content

[SandboxIR] Implement UnaryOperator #104509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class CastInst;
class PtrToIntInst;
class BitCastInst;
class AllocaInst;
class UnaryOperator;
class BinaryOperator;
class AtomicCmpXchgInst;

Expand Down Expand Up @@ -250,6 +251,7 @@ class Value {
friend class InvokeInst; // For getting `Val`.
friend class CallBrInst; // For getting `Val`.
friend class GetElementPtrInst; // For getting `Val`.
friend class UnaryOperator; // For getting `Val`.
friend class BinaryOperator; // For getting `Val`.
friend class AtomicCmpXchgInst; // For getting `Val`.
friend class AllocaInst; // For getting `Val`.
Expand Down Expand Up @@ -632,6 +634,7 @@ class Instruction : public sandboxir::User {
friend class InvokeInst; // For getTopmostLLVMInstruction().
friend class CallBrInst; // For getTopmostLLVMInstruction().
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
friend class UnaryOperator; // For getTopmostLLVMInstruction().
friend class BinaryOperator; // For getTopmostLLVMInstruction().
friend class AtomicCmpXchgInst; // For getTopmostLLVMInstruction().
friend class AllocaInst; // For getTopmostLLVMInstruction().
Expand Down Expand Up @@ -1435,6 +1438,47 @@ class GetElementPtrInst final
// TODO: Add missing member functions.
};

class UnaryOperator : public UnaryInstruction {
static Opcode getUnaryOpcode(llvm::Instruction::UnaryOps UnOp) {
switch (UnOp) {
case llvm::Instruction::FNeg:
return Opcode::FNeg;
case llvm::Instruction::UnaryOpsEnd:
llvm_unreachable("Bad UnOp!");
}
llvm_unreachable("Unhandled UnOp!");
}
UnaryOperator(llvm::UnaryOperator *UO, Context &Ctx)
: UnaryInstruction(ClassID::UnOp, getUnaryOpcode(UO->getOpcode()), UO,
Ctx) {}
friend Context; // for constructor.
public:
static Value *create(Instruction::Opcode Op, Value *OpV, BBIterator WhereIt,
BasicBlock *WhereBB, Context &Ctx,
const Twine &Name = "");
static Value *create(Instruction::Opcode Op, Value *OpV,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name = "");
static Value *create(Instruction::Opcode Op, Value *OpV,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name = "");
static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
Value *CopyFrom, BBIterator WhereIt,
BasicBlock *WhereBB, Context &Ctx,
const Twine &Name = "");
static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
Value *CopyFrom,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name = "");
static Value *createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
Value *CopyFrom, BasicBlock *InsertAtEnd,
Context &Ctx, const Twine &Name = "");
/// For isa/dyn_cast.
static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::UnOp;
}
};

class BinaryOperator : public SingleLLVMInstructionImpl<llvm::BinaryOperator> {
static Opcode getBinOpOpcode(llvm::Instruction::BinaryOps BinOp) {
switch (BinOp) {
Expand Down Expand Up @@ -1959,6 +2003,8 @@ class Context {
friend CallBrInst; // For createCallBrInst()
GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
friend GetElementPtrInst; // For createGetElementPtrInst()
UnaryOperator *createUnaryOperator(llvm::UnaryOperator *I);
friend UnaryOperator; // For createUnaryOperator()
BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
friend BinaryOperator; // For createBinaryOperator()
AtomicCmpXchgInst *createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I);
Expand Down
5 changes: 4 additions & 1 deletion llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ DEF_INSTR(Call, OP(Call), CallInst)
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
DEF_INSTR(BinaryOperator, OPCODES( \
DEF_INSTR(UnOp, OPCODES( \
OP(FNeg) \
), UnaryOperator)
DEF_INSTR(BinaryOperator, OPCODES(\
OP(Add) \
OP(FAdd) \
OP(Sub) \
Expand Down
75 changes: 75 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,71 @@ static llvm::Instruction::CastOps getLLVMCastOp(Instruction::Opcode Opc) {
}
}

/// \Returns the LLVM opcode that corresponds to \p Opc.
static llvm::Instruction::UnaryOps getLLVMUnaryOp(Instruction::Opcode Opc) {
switch (Opc) {
case Instruction::Opcode::FNeg:
return static_cast<llvm::Instruction::UnaryOps>(llvm::Instruction::FNeg);
default:
llvm_unreachable("Not a unary op!");
}
}

Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
BBIterator WhereIt, BasicBlock *WhereBB,
Context &Ctx, const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
if (WhereIt == WhereBB->end())
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
else
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
auto *NewLLVMV = Builder.CreateUnOp(getLLVMUnaryOp(Op), OpV->Val, Name);
if (auto *NewUnOpV = dyn_cast<llvm::UnaryOperator>(NewLLVMV)) {
return Ctx.createUnaryOperator(NewUnOpV);
}
assert(isa<llvm::Constant>(NewLLVMV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewLLVMV));
}

Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
return create(Op, OpV, InsertBefore->getIterator(), InsertBefore->getParent(),
Ctx, Name);
}

Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
BasicBlock *InsertAfter, Context &Ctx,
const Twine &Name) {
return create(Op, OpV, InsertAfter->end(), InsertAfter, Ctx, Name);
}

Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
Value *CopyFrom, BBIterator WhereIt,
BasicBlock *WhereBB, Context &Ctx,
const Twine &Name) {
auto *NewV = create(Op, OpV, WhereIt, WhereBB, Ctx, Name);
if (auto *UnI = dyn_cast<llvm::UnaryOperator>(NewV->Val))
UnI->copyIRFlags(CopyFrom->Val);
return NewV;
}

Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
Value *CopyFrom,
Instruction *InsertBefore,
Context &Ctx, const Twine &Name) {
return createWithCopiedFlags(Op, OpV, CopyFrom, InsertBefore->getIterator(),
InsertBefore->getParent(), Ctx, Name);
}

Value *UnaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *OpV,
Value *CopyFrom,
BasicBlock *InsertAtEnd,
Context &Ctx, const Twine &Name) {
return createWithCopiedFlags(Op, OpV, CopyFrom, InsertAtEnd->end(),
InsertAtEnd, Ctx, Name);
}

/// \Returns the LLVM opcode that corresponds to \p Opc.
static llvm::Instruction::BinaryOps getLLVMBinaryOp(Instruction::Opcode Opc) {
switch (Opc) {
Expand Down Expand Up @@ -1729,6 +1794,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
new GetElementPtrInst(LLVMGEP, *this));
return It->second.get();
}
case llvm::Instruction::FNeg: {
auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
It->second = std::unique_ptr<UnaryOperator>(
new UnaryOperator(LLVMUnaryOperator, *this));
return It->second.get();
}
case llvm::Instruction::Add:
case llvm::Instruction::FAdd:
case llvm::Instruction::Sub:
Expand Down Expand Up @@ -1875,6 +1946,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
}
UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
return cast<UnaryOperator>(registerValue(std::move(NewPtr)));
}
BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
Expand Down
126 changes: 126 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,132 @@ define void @foo(i32 %arg, float %farg) {
EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
}

TEST_F(SandboxIRTest, UnaryOperator) {
parseIR(C, R"IR(
define void @foo(float %arg0) {
%fneg = fneg float %arg0
%copyfrom = fadd reassoc float %arg0, 42.0
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto *Arg0 = F.getArg(0);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *I = cast<sandboxir::UnaryOperator>(&*It++);
auto *CopyFrom = cast<sandboxir::BinaryOperator>(&*It++);
auto *Ret = &*It++;
EXPECT_EQ(I->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
EXPECT_EQ(I->getOperand(0), Arg0);

{
// Check create() WhereIt, WhereBB.
auto *NewI =
cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
sandboxir::Instruction::Opcode::FNeg, Arg0,
/*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
"New1"));
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
EXPECT_EQ(NewI->getOperand(0), Arg0);
#ifndef NDEBUG
EXPECT_EQ(NewI->getName(), "New1");
#endif // NDEBUG
EXPECT_EQ(NewI->getNextNode(), Ret);
}
{
// Check create() InsertBefore.
auto *NewI =
cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
sandboxir::Instruction::Opcode::FNeg, Arg0,
/*InsertBefore=*/Ret, Ctx, "New2"));
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
EXPECT_EQ(NewI->getOperand(0), Arg0);
#ifndef NDEBUG
EXPECT_EQ(NewI->getName(), "New2");
#endif // NDEBUG
EXPECT_EQ(NewI->getNextNode(), Ret);
}
{
// Check create() InsertAtEnd.
auto *NewI =
cast<sandboxir::UnaryOperator>(sandboxir::UnaryOperator::create(
sandboxir::Instruction::Opcode::FNeg, Arg0,
/*InsertAtEnd=*/BB, Ctx, "New3"));
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
EXPECT_EQ(NewI->getOperand(0), Arg0);
#ifndef NDEBUG
EXPECT_EQ(NewI->getName(), "New3");
#endif // NDEBUG
EXPECT_EQ(NewI->getParent(), BB);
EXPECT_EQ(NewI->getNextNode(), nullptr);
}
{
// Check create() when it gets folded.
auto *FortyTwo = CopyFrom->getOperand(1);
auto *NewV = sandboxir::UnaryOperator::create(
sandboxir::Instruction::Opcode::FNeg, FortyTwo,
/*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
"Folded");
EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
}

{
// Check createWithCopiedFlags() WhereIt, WhereBB.
auto *NewI = cast<sandboxir::UnaryOperator>(
sandboxir::UnaryOperator::createWithCopiedFlags(
sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
/*WhereIt=*/Ret->getIterator(), /*WhereBB=*/Ret->getParent(), Ctx,
"NewCopyFrom1"));
EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
EXPECT_EQ(NewI->getOperand(0), Arg0);
#ifndef NDEBUG
EXPECT_EQ(NewI->getName(), "NewCopyFrom1");
#endif // NDEBUG
EXPECT_EQ(NewI->getNextNode(), Ret);
}
{
// Check createWithCopiedFlags() InsertBefore,
auto *NewI = cast<sandboxir::UnaryOperator>(
sandboxir::UnaryOperator::createWithCopiedFlags(
sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
/*InsertBefore=*/Ret, Ctx, "NewCopyFrom2"));
EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
EXPECT_EQ(NewI->getOperand(0), Arg0);
#ifndef NDEBUG
EXPECT_EQ(NewI->getName(), "NewCopyFrom2");
#endif // NDEBUG
EXPECT_EQ(NewI->getNextNode(), Ret);
}
{
// Check createWithCopiedFlags() InsertAtEnd,
auto *NewI = cast<sandboxir::UnaryOperator>(
sandboxir::UnaryOperator::createWithCopiedFlags(
sandboxir::Instruction::Opcode::FNeg, Arg0, CopyFrom,
/*InsertAtEnd=*/BB, Ctx, "NewCopyFrom3"));
EXPECT_EQ(NewI->hasAllowReassoc(), CopyFrom->hasAllowReassoc());
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::FNeg);
EXPECT_EQ(NewI->getOperand(0), Arg0);
#ifndef NDEBUG
EXPECT_EQ(NewI->getName(), "NewCopyFrom3");
#endif // NDEBUG
EXPECT_EQ(NewI->getParent(), BB);
EXPECT_EQ(NewI->getNextNode(), nullptr);
}
{
// Check createWithCopiedFlags() when it gets folded.
auto *FortyTwo = CopyFrom->getOperand(1);
auto *NewV = sandboxir::UnaryOperator::createWithCopiedFlags(
sandboxir::Instruction::Opcode::FNeg, FortyTwo, CopyFrom,
/*InsertAtEnd=*/BB, Ctx, "Folded");
EXPECT_TRUE(isa<sandboxir::Constant>(NewV));
}
}

TEST_F(SandboxIRTest, BinaryOperator) {
parseIR(C, R"IR(
define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {
Expand Down
Loading