Skip to content

[SandboxIR] Implement AtomicRMWInst #104529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class BitCastInst;
class AllocaInst;
class UnaryOperator;
class BinaryOperator;
class AtomicRMWInst;
class AtomicCmpXchgInst;

/// Iterator for the `Use` edges of a User's operands.
Expand Down Expand Up @@ -253,6 +254,7 @@ class Value {
friend class GetElementPtrInst; // For getting `Val`.
friend class UnaryOperator; // For getting `Val`.
friend class BinaryOperator; // For getting `Val`.
friend class AtomicRMWInst; // For getting `Val`.
friend class AtomicCmpXchgInst; // For getting `Val`.
friend class AllocaInst; // For getting `Val`.
friend class CastInst; // For getting `Val`.
Expand Down Expand Up @@ -636,6 +638,7 @@ class Instruction : public sandboxir::User {
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
friend class UnaryOperator; // For getTopmostLLVMInstruction().
friend class BinaryOperator; // For getTopmostLLVMInstruction().
friend class AtomicRMWInst; // For getTopmostLLVMInstruction().
friend class AtomicCmpXchgInst; // For getTopmostLLVMInstruction().
friend class AllocaInst; // For getTopmostLLVMInstruction().
friend class CastInst; // For getTopmostLLVMInstruction().
Expand Down Expand Up @@ -1559,6 +1562,77 @@ class BinaryOperator : public SingleLLVMInstructionImpl<llvm::BinaryOperator> {
void swapOperands() { swapOperandsInternal(0, 1); }
};

class AtomicRMWInst : public SingleLLVMInstructionImpl<llvm::AtomicRMWInst> {
AtomicRMWInst(llvm::AtomicRMWInst *Atomic, Context &Ctx)
: SingleLLVMInstructionImpl(ClassID::AtomicRMW,
Instruction::Opcode::AtomicRMW, Atomic, Ctx) {
}
friend class Context; // For constructor.

public:
using BinOp = llvm::AtomicRMWInst::BinOp;
BinOp getOperation() const {
return cast<llvm::AtomicRMWInst>(Val)->getOperation();
}
static StringRef getOperationName(BinOp Op) {
return llvm::AtomicRMWInst::getOperationName(Op);
}
static bool isFPOperation(BinOp Op) {
return llvm::AtomicRMWInst::isFPOperation(Op);
}
void setOperation(BinOp Op) {
cast<llvm::AtomicRMWInst>(Val)->setOperation(Op);
}
Align getAlign() const { return cast<llvm::AtomicRMWInst>(Val)->getAlign(); }
void setAlignment(Align Align);
bool isVolatile() const {
return cast<llvm::AtomicRMWInst>(Val)->isVolatile();
}
void setVolatile(bool V);
AtomicOrdering getOrdering() const {
return cast<llvm::AtomicRMWInst>(Val)->getOrdering();
}
void setOrdering(AtomicOrdering Ordering);
SyncScope::ID getSyncScopeID() const {
return cast<llvm::AtomicRMWInst>(Val)->getSyncScopeID();
}
void setSyncScopeID(SyncScope::ID SSID);
Value *getPointerOperand();
const Value *getPointerOperand() const {
return const_cast<AtomicRMWInst *>(this)->getPointerOperand();
}
Value *getValOperand();
const Value *getValOperand() const {
return const_cast<AtomicRMWInst *>(this)->getValOperand();
}
unsigned getPointerAddressSpace() const {
return cast<llvm::AtomicRMWInst>(Val)->getPointerAddressSpace();
}
bool isFloatingPointOperation() const {
return cast<llvm::AtomicRMWInst>(Val)->isFloatingPointOperation();
}
static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::AtomicRMW;
}

static AtomicRMWInst *create(BinOp Op, Value *Ptr, Value *Val,
MaybeAlign Align, AtomicOrdering Ordering,
BBIterator WhereIt, BasicBlock *WhereBB,
Context &Ctx,
SyncScope::ID SSID = SyncScope::System,
const Twine &Name = "");
static AtomicRMWInst *create(BinOp Op, Value *Ptr, Value *Val,
MaybeAlign Align, AtomicOrdering Ordering,
Instruction *InsertBefore, Context &Ctx,
SyncScope::ID SSID = SyncScope::System,
const Twine &Name = "");
static AtomicRMWInst *create(BinOp Op, Value *Ptr, Value *Val,
MaybeAlign Align, AtomicOrdering Ordering,
BasicBlock *InsertAtEnd, Context &Ctx,
SyncScope::ID SSID = SyncScope::System,
const Twine &Name = "");
};

class AtomicCmpXchgInst
: public SingleLLVMInstructionImpl<llvm::AtomicCmpXchgInst> {
AtomicCmpXchgInst(llvm::AtomicCmpXchgInst *Atomic, Context &Ctx)
Expand Down Expand Up @@ -2007,6 +2081,8 @@ class Context {
friend UnaryOperator; // For createUnaryOperator()
BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
friend BinaryOperator; // For createBinaryOperator()
AtomicRMWInst *createAtomicRMWInst(llvm::AtomicRMWInst *I);
friend AtomicRMWInst; // For createAtomicRMWInst()
AtomicCmpXchgInst *createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I);
friend AtomicCmpXchgInst; // For createAtomicCmpXchgInst()
AllocaInst *createAllocaInst(llvm::AllocaInst *I);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ DEF_INSTR(BinaryOperator, OPCODES(\
OP(Or) \
OP(Xor) \
), BinaryOperator)
DEF_INSTR(AtomicRMW, OP(AtomicRMW), AtomicRMWInst)
DEF_INSTR(AtomicCmpXchg, OP(AtomicCmpXchg), AtomicCmpXchgInst)
DEF_INSTR(Alloca, OP(Alloca), AllocaInst)
DEF_INSTR(Cast, OPCODES(\
Expand Down
78 changes: 78 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,74 @@ Value *BinaryOperator::createWithCopiedFlags(Instruction::Opcode Op, Value *LHS,
InsertAtEnd, Ctx, Name);
}

void AtomicRMWInst::setAlignment(Align Align) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&AtomicRMWInst::getAlign,
&AtomicRMWInst::setAlignment>>(this);
cast<llvm::AtomicRMWInst>(Val)->setAlignment(Align);
}

void AtomicRMWInst::setVolatile(bool V) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&AtomicRMWInst::isVolatile,
&AtomicRMWInst::setVolatile>>(this);
cast<llvm::AtomicRMWInst>(Val)->setVolatile(V);
}

void AtomicRMWInst::setOrdering(AtomicOrdering Ordering) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&AtomicRMWInst::getOrdering,
&AtomicRMWInst::setOrdering>>(this);
cast<llvm::AtomicRMWInst>(Val)->setOrdering(Ordering);
}

void AtomicRMWInst::setSyncScopeID(SyncScope::ID SSID) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&AtomicRMWInst::getSyncScopeID,
&AtomicRMWInst::setSyncScopeID>>(this);
cast<llvm::AtomicRMWInst>(Val)->setSyncScopeID(SSID);
}

Value *AtomicRMWInst::getPointerOperand() {
return Ctx.getValue(cast<llvm::AtomicRMWInst>(Val)->getPointerOperand());
}

Value *AtomicRMWInst::getValOperand() {
return Ctx.getValue(cast<llvm::AtomicRMWInst>(Val)->getValOperand());
}

AtomicRMWInst *AtomicRMWInst::create(BinOp Op, Value *Ptr, Value *Val,
MaybeAlign Align, AtomicOrdering Ordering,
BBIterator WhereIt, BasicBlock *WhereBB,
Context &Ctx, SyncScope::ID SSID,
const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
if (WhereIt == WhereBB->end())
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
else
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
auto *LLVMAtomicRMW =
Builder.CreateAtomicRMW(Op, Ptr->Val, Val->Val, Align, Ordering, SSID);
LLVMAtomicRMW->setName(Name);
return Ctx.createAtomicRMWInst(LLVMAtomicRMW);
}

AtomicRMWInst *AtomicRMWInst::create(BinOp Op, Value *Ptr, Value *Val,
MaybeAlign Align, AtomicOrdering Ordering,
Instruction *InsertBefore, Context &Ctx,
SyncScope::ID SSID, const Twine &Name) {
return create(Op, Ptr, Val, Align, Ordering, InsertBefore->getIterator(),
InsertBefore->getParent(), Ctx, SSID, Name);
}

AtomicRMWInst *AtomicRMWInst::create(BinOp Op, Value *Ptr, Value *Val,
MaybeAlign Align, AtomicOrdering Ordering,
BasicBlock *InsertAtEnd, Context &Ctx,
SyncScope::ID SSID, const Twine &Name) {
return create(Op, Ptr, Val, Align, Ordering, InsertAtEnd->end(), InsertAtEnd,
Ctx, SSID, Name);
}

void AtomicCmpXchgInst::setSyncScopeID(SyncScope::ID SSID) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&AtomicCmpXchgInst::getSyncScopeID,
Expand Down Expand Up @@ -1823,6 +1891,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
new BinaryOperator(LLVMBinaryOperator, *this));
return It->second.get();
}
case llvm::Instruction::AtomicRMW: {
auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(LLVMV);
It->second =
std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(LLVMAtomicRMW, *this));
return It->second.get();
}
case llvm::Instruction::AtomicCmpXchg: {
auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(LLVMV);
It->second = std::unique_ptr<AtomicCmpXchgInst>(
Expand Down Expand Up @@ -1954,6 +2028,10 @@ BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
}
AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) {
auto NewPtr = std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(I, *this));
return cast<AtomicRMWInst>(registerValue(std::move(NewPtr)));
}
AtomicCmpXchgInst *
Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
auto NewPtr =
Expand Down
163 changes: 163 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1932,6 +1932,169 @@ define void @foo(i8 %arg0, i8 %arg1, float %farg0, float %farg1) {
}
}

TEST_F(SandboxIRTest, AtomicRMWInst) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %arg) {
%atomicrmw = atomicrmw add ptr %ptr, i8 %arg acquire, align 128
ret void
}
)IR");
llvm::Function &LLVMF = *M->getFunction("foo");
llvm::BasicBlock *LLVMBB = &*LLVMF.begin();
auto LLVMIt = LLVMBB->begin();
auto *LLVMRMW = cast<llvm::AtomicRMWInst>(&*LLVMIt++);

sandboxir::Context Ctx(C);
sandboxir::Function *F = Ctx.createFunction(&LLVMF);
auto *Ptr = F->getArg(0);
auto *Arg = F->getArg(1);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *RMW = cast<sandboxir::AtomicRMWInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

// Check getOperationName().
EXPECT_EQ(
sandboxir::AtomicRMWInst::getOperationName(
sandboxir::AtomicRMWInst::BinOp::Add),
llvm::AtomicRMWInst::getOperationName(llvm::AtomicRMWInst::BinOp::Add));
// Check isFPOperation().
EXPECT_EQ(
sandboxir::AtomicRMWInst::isFPOperation(
sandboxir::AtomicRMWInst::BinOp::Add),
llvm::AtomicRMWInst::isFPOperation(llvm::AtomicRMWInst::BinOp::Add));
EXPECT_FALSE(sandboxir::AtomicRMWInst::isFPOperation(
sandboxir::AtomicRMWInst::BinOp::Add));
EXPECT_TRUE(sandboxir::AtomicRMWInst::isFPOperation(
sandboxir::AtomicRMWInst::BinOp::FAdd));
// Check setOperation(), getOperation().
EXPECT_EQ(RMW->getOperation(), LLVMRMW->getOperation());
RMW->setOperation(sandboxir::AtomicRMWInst::BinOp::Sub);
EXPECT_EQ(RMW->getOperation(), sandboxir::AtomicRMWInst::BinOp::Sub);
RMW->setOperation(sandboxir::AtomicRMWInst::BinOp::Add);
// Check getAlign().
EXPECT_EQ(RMW->getAlign(), LLVMRMW->getAlign());
auto OrigAlign = RMW->getAlign();
Align NewAlign(256);
EXPECT_NE(NewAlign, OrigAlign);
RMW->setAlignment(NewAlign);
EXPECT_EQ(RMW->getAlign(), NewAlign);
RMW->setAlignment(OrigAlign);
EXPECT_EQ(RMW->getAlign(), OrigAlign);
// Check isVolatile(), setVolatile().
EXPECT_EQ(RMW->isVolatile(), LLVMRMW->isVolatile());
bool OrigV = RMW->isVolatile();
bool NewV = true;
EXPECT_NE(NewV, OrigV);
RMW->setVolatile(NewV);
EXPECT_EQ(RMW->isVolatile(), NewV);
RMW->setVolatile(OrigV);
EXPECT_EQ(RMW->isVolatile(), OrigV);
// Check getOrdering(), setOrdering().
EXPECT_EQ(RMW->getOrdering(), LLVMRMW->getOrdering());
auto OldOrdering = RMW->getOrdering();
auto NewOrdering = AtomicOrdering::Monotonic;
EXPECT_NE(NewOrdering, OldOrdering);
RMW->setOrdering(NewOrdering);
EXPECT_EQ(RMW->getOrdering(), NewOrdering);
RMW->setOrdering(OldOrdering);
EXPECT_EQ(RMW->getOrdering(), OldOrdering);
// Check getSyncScopeID(), setSyncScopeID().
EXPECT_EQ(RMW->getSyncScopeID(), LLVMRMW->getSyncScopeID());
auto OrigSSID = RMW->getSyncScopeID();
SyncScope::ID NewSSID = SyncScope::SingleThread;
EXPECT_NE(NewSSID, OrigSSID);
RMW->setSyncScopeID(NewSSID);
EXPECT_EQ(RMW->getSyncScopeID(), NewSSID);
RMW->setSyncScopeID(OrigSSID);
EXPECT_EQ(RMW->getSyncScopeID(), OrigSSID);
// Check getPointerOperand().
EXPECT_EQ(RMW->getPointerOperand(),
Ctx.getValue(LLVMRMW->getPointerOperand()));
// Check getValOperand().
EXPECT_EQ(RMW->getValOperand(), Ctx.getValue(LLVMRMW->getValOperand()));
// Check getPointerAddressSpace().
EXPECT_EQ(RMW->getPointerAddressSpace(), LLVMRMW->getPointerAddressSpace());
// Check isFloatingPointOperation().
EXPECT_EQ(RMW->isFloatingPointOperation(),
LLVMRMW->isFloatingPointOperation());

Align Align(1024);
auto Ordering = AtomicOrdering::Acquire;
auto SSID = SyncScope::System;
{
// Check create() WhereIt, WhereBB.
auto *NewI =
cast<sandboxir::AtomicRMWInst>(sandboxir::AtomicRMWInst::create(
sandboxir::AtomicRMWInst::BinOp::Sub, Ptr, Arg, Align, Ordering,
/*WhereIt=*/Ret->getIterator(),
/*WhereBB=*/Ret->getParent(), Ctx, SSID, "NewAtomicRMW1"));
// Check getOpcode().
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::AtomicRMW);
// Check getAlign().
EXPECT_EQ(NewI->getAlign(), Align);
// Check getSuccessOrdering().
EXPECT_EQ(NewI->getOrdering(), Ordering);
// Check instr position.
EXPECT_EQ(NewI->getNextNode(), Ret);
// Check getPointerOperand().
EXPECT_EQ(NewI->getPointerOperand(), Ptr);
// Check getValOperand().
EXPECT_EQ(NewI->getValOperand(), Arg);
#ifndef NDEBUG
// Check getName().
EXPECT_EQ(NewI->getName(), "NewAtomicRMW1");
#endif // NDEBUG
}
{
// Check create() InsertBefore.
auto *NewI =
cast<sandboxir::AtomicRMWInst>(sandboxir::AtomicRMWInst::create(
sandboxir::AtomicRMWInst::BinOp::Sub, Ptr, Arg, Align, Ordering,
/*InsertBefore=*/Ret, Ctx, SSID, "NewAtomicRMW2"));
// Check getOpcode().
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::AtomicRMW);
// Check getAlign().
EXPECT_EQ(NewI->getAlign(), Align);
// Check getSuccessOrdering().
EXPECT_EQ(NewI->getOrdering(), Ordering);
// Check instr position.
EXPECT_EQ(NewI->getNextNode(), Ret);
// Check getPointerOperand().
EXPECT_EQ(NewI->getPointerOperand(), Ptr);
// Check getValOperand().
EXPECT_EQ(NewI->getValOperand(), Arg);
#ifndef NDEBUG
// Check getName().
EXPECT_EQ(NewI->getName(), "NewAtomicRMW2");
#endif // NDEBUG
}
{
// Check create() InsertAtEnd.
auto *NewI =
cast<sandboxir::AtomicRMWInst>(sandboxir::AtomicRMWInst::create(
sandboxir::AtomicRMWInst::BinOp::Sub, Ptr, Arg, Align, Ordering,
/*InsertAtEnd=*/BB, Ctx, SSID, "NewAtomicRMW3"));
// Check getOpcode().
EXPECT_EQ(NewI->getOpcode(), sandboxir::Instruction::Opcode::AtomicRMW);
// Check getAlign().
EXPECT_EQ(NewI->getAlign(), Align);
// Check getSuccessOrdering().
EXPECT_EQ(NewI->getOrdering(), Ordering);
// Check instr position.
EXPECT_EQ(NewI->getParent(), BB);
EXPECT_EQ(NewI->getNextNode(), nullptr);
// Check getPointerOperand().
EXPECT_EQ(NewI->getPointerOperand(), Ptr);
// Check getValOperand().
EXPECT_EQ(NewI->getValOperand(), Arg);
#ifndef NDEBUG
// Check getName().
EXPECT_EQ(NewI->getName(), "NewAtomicRMW3");
#endif // NDEBUG
}
}

TEST_F(SandboxIRTest, AtomicCmpXchgInst) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %cmp, i8 %new) {
Expand Down
Loading
Loading