Skip to content

Commit be8ccd5

Browse files
committed
[OpenMPIRBuilder] Emit __atomic_load and __atomic_compare_exchange libcalls for complex types in atomic update
1 parent 2e8d815 commit be8ccd5

File tree

3 files changed

+252
-12
lines changed

3 files changed

+252
-12
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2418,14 +2418,25 @@ class OpenMPIRBuilder {
24182418
/// \param IsXBinopExpr true if \a X is Left H.S. in Right H.S. part of the
24192419
/// update expression, false otherwise.
24202420
/// (e.g. true for X = X BinOp Expr)
2421-
///
2421+
/// \param shouldEmitLibCall true if atomicrmw cannot be emitted for \a X
24222422
/// \returns A pair of the old value of X before the update, and the value
24232423
/// used for the update.
24242424
std::pair<Value *, Value *>
24252425
emitAtomicUpdate(InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
24262426
AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
24272427
AtomicUpdateCallbackTy &UpdateOp, bool VolatileX,
2428-
bool IsXBinopExpr);
2428+
bool IsXBinopExpr, bool shouldEmitLibCall = false);
2429+
2430+
/// Emit a call to the `__atomic_compare_exchange` library routine at the
/// position of \p I and wire its result into the surrounding retry loop.
///
/// \param I              Instruction being replaced; the call is inserted
///                       before it and its uses are rewritten to the
///                       {value, success} aggregate built from the call.
/// \param Size           Size in bytes, passed as the first libcall argument
///                       and used for the temporaries' lifetime markers and
///                       alignment.
/// \param Alignment      Not referenced by the implementation.
/// \param PointerOperand Address of the atomic object.
/// \param ValueOperand   Desired new value; spilled to a temporary.
/// \param CASExpected    Expected value; spilled to a temporary. Must not be
///                       null, the implementation reads it unconditionally.
/// \param Ordering       Success memory ordering.
/// \param Ordering2      Failure memory ordering.
/// \param PHI            Receives the value read back from the expected
///                       temporary as a new incoming value.
/// \param ContBB         Branch target when the exchange fails.
/// \param ExitBB         Branch target when the exchange succeeds.
/// \returns true once the call and the conditional branch have been emitted.
bool emitAtomicCompareExchangeLibCall(
2431+
Instruction *I, unsigned Size, Align Alignment, Value *PointerOperand,
2432+
Value *ValueOperand, Value *CASExpected, AtomicOrdering Ordering,
2433+
AtomicOrdering Ordering2, llvm::PHINode *PHI, llvm::BasicBlock *ContBB,
2434+
llvm::BasicBlock *ExitBB);
2435+
2436+
/// Emit a call to the `__atomic_load` library routine at the position of
/// \p I and replace the uses of \p I with the value it loads.
///
/// \param I              Instruction being replaced; its type is the type of
///                       the loaded value.
/// \param Size           Size in bytes, passed as the first libcall argument
///                       and used for the result temporary's lifetime
///                       markers and alignment.
/// \param Alignment      Not referenced by the implementation.
/// \param PointerOperand Address of the atomic object.
/// \param ValueOperand   Not referenced by the implementation.
/// \param CASExpected    Pass nullptr; a load has no expected value.
/// \param Ordering       Memory ordering of the load.
/// \param Ordering2      Not referenced by the implementation.
/// \param LoadedVal      [out] Set to the value read back from the result
///                       temporary after the call.
/// \returns true once the call has been emitted.
bool emitAtomicLoadLibCall(Instruction *I, unsigned Size, Align Alignment,
2437+
Value *PointerOperand, Value *ValueOperand,
2438+
Value *CASExpected, AtomicOrdering Ordering,
2439+
AtomicOrdering Ordering2, Value *&LoadedVal);
24292440

24302441
/// Emit the binary op. described by \p RMWOp, using \p Src1 and \p Src2 .
24312442
///
@@ -2487,14 +2498,15 @@ class OpenMPIRBuilder {
24872498
/// \param IsXBinopExpr true if \a X is Left H.S. in Right H.S. part of the
24882499
/// update expression, false otherwise.
24892500
/// (e.g. true for X = X BinOp Expr)
2490-
///
2501+
/// \param shouldEmitLibCall true if atomicrmw cannot be emitted for \a X
24912502
/// \return Insertion point after generated atomic update IR.
24922503
InsertPointTy createAtomicUpdate(const LocationDescription &Loc,
24932504
InsertPointTy AllocaIP, AtomicOpValue &X,
24942505
Value *Expr, AtomicOrdering AO,
24952506
AtomicRMWInst::BinOp RMWOp,
24962507
AtomicUpdateCallbackTy &UpdateOp,
2497-
bool IsXBinopExpr);
2508+
bool IsXBinopExpr,
2509+
bool shouldEmitLibCall = false);
24982510

24992511
/// Emit atomic update for constructs: --- Only Scalar data types
25002512
/// V = X; X = X BinOp Expr ,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 233 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5936,7 +5936,8 @@ OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
59365936
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
59375937
const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
59385938
Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
5939-
AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
5939+
AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr,
5940+
bool shouldEmitLibCall) {
59405941
assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
59415942
if (!updateToLocation(Loc))
59425943
return Loc.IP;
@@ -5955,7 +5956,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
59555956
});
59565957

59575958
emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
5958-
X.IsVolatile, IsXBinopExpr);
5959+
X.IsVolatile, IsXBinopExpr, shouldEmitLibCall);
59595960
checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
59605961
return Builder.saveIP();
59615962
}
@@ -5993,10 +5994,180 @@ Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
59935994
llvm_unreachable("Unsupported atomic update operation");
59945995
}
59955996

5997+
/// Replace the compare-exchange \p I with a call to the generic
/// `__atomic_compare_exchange(size, ptr, expected, desired, success, failure)`
/// library routine and hook the result into the caller's retry loop.
///
/// The expected and desired values are spilled to temporaries allocated at the
/// start of the entry block of the function containing \p I, because the
/// library routine takes both by address. After the call, the (possibly
/// updated) expected slot is read back, combined with the call's boolean result
/// into a value of \p I's type, and used to replace every use of \p I. The
/// read-back value is added to \p PHI as coming from the current block, and a
/// conditional branch to \p ExitBB (success) or \p ContBB (failure) is emitted.
/// \p I itself is not erased; that is left to the caller.
///
/// \param I              Instruction to replace; new IR is inserted before it.
/// \param Size           Size in bytes of the exchanged object.
/// \param Alignment      Currently unused; the temporaries use the preferred
///                       alignment of an integer of \p Size bytes.
/// \param PointerOperand Address of the atomic object.
/// \param ValueOperand   Desired new value. Must not be null.
/// \param CASExpected    Expected old value. Must not be null.
/// \param Ordering       Success memory ordering; must be atomic.
/// \param Ordering2      Failure memory ordering; must be atomic.
/// \param PHI            Loop-carried "old value" PHI in \p ContBB.
/// \param ContBB         Block to branch to when the exchange fails.
/// \param ExitBB         Block to branch to when the exchange succeeds.
/// \returns true if the libcall was emitted, false (without touching the IR)
///          if a required operand is missing.
bool OpenMPIRBuilder::emitAtomicCompareExchangeLibCall(
    Instruction *I, unsigned Size, Align Alignment, Value *PointerOperand,
    Value *ValueOperand, Value *CASExpected, AtomicOrdering Ordering,
    AtomicOrdering Ordering2, llvm::PHINode *PHI, llvm::BasicBlock *ContBB,
    llvm::BasicBlock *ExitBB) {
  assert(I && PointerOperand && PHI && ContBB && ExitBB &&
         "compare-exchange libcall needs an instruction, pointer and CFG");
  assert(CASExpected && ValueOperand &&
         "compare-exchange libcall needs expected and desired values");
  assert(Ordering != AtomicOrdering::NotAtomic && "expect atomic MO");
  assert(Ordering2 != AtomicOrdering::NotAtomic && "expect atomic MO");

  // The original code guarded the spills with null checks but then read the
  // expected slot unconditionally; make the precondition explicit instead of
  // dereferencing a null pointer in builds without assertions.
  if (!CASExpected || !ValueOperand)
    return false;

  LLVMContext &Ctx = I->getContext();
  Module *Mod = I->getModule();
  const DataLayout &DL = Mod->getDataLayout();

  // Local builders (named so they do not shadow the class' own Builder):
  // one at the instruction being replaced, one at the top of the entry block
  // for the temporaries.
  IRBuilder<> IRB(I);
  IRBuilder<> EntryIRB(&I->getFunction()->getEntryBlock().front());

  const Align SlotAlign = DL.getPrefTypeAlign(Type::getIntNTy(Ctx, Size * 8));
  ConstantInt *LifetimeSize = ConstantInt::get(Type::getInt64Ty(Ctx), Size);
  Type *Int32Ty = Type::getInt32Ty(Ctx);

  // Allocate an entry-block temporary for V, open its lifetime at the call
  // site and store V into it.
  auto SpillToSlot = [&](Value *V) -> AllocaInst * {
    AllocaInst *Slot = EntryIRB.CreateAlloca(V->getType());
    Slot->setAlignment(SlotAlign);
    IRB.CreateLifetimeStart(Slot, LifetimeSize);
    IRB.CreateAlignedStore(V, Slot, SlotAlign);
    return Slot;
  };

  // Argument order follows the generic libcall:
  //   size, ptr, expected*, desired*, success order, failure order.
  SmallVector<Value *, 6> Args;
  Args.push_back(ConstantInt::get(DL.getIntPtrType(Ctx), Size));
  Args.push_back(
      IRB.CreateAddrSpaceCast(PointerOperand, PointerType::getUnqual(Ctx)));

  AllocaInst *ExpectedSlot = SpillToSlot(CASExpected);
  Args.push_back(ExpectedSlot);

  AllocaInst *DesiredSlot = SpillToSlot(ValueOperand);
  Args.push_back(DesiredSlot);

  Args.push_back(
      ConstantInt::get(Int32Ty, static_cast<int>(toCABI(Ordering))));
  Args.push_back(
      ConstantInt::get(Int32Ty, static_cast<int>(toCABI(Ordering2))));

  SmallVector<Type *, 6> ArgTys;
  for (Value *Arg : Args)
    ArgTys.push_back(Arg->getType());

  // The routine returns a boolean success flag, modelled as a zero-extended
  // i1.
  AttributeList Attr;
  Attr = Attr.addRetAttribute(Ctx, Attribute::ZExt);
  FunctionType *FnType =
      FunctionType::get(Type::getInt1Ty(Ctx), ArgTys, /*isVarArg=*/false);
  FunctionCallee LibcallFn =
      Mod->getOrInsertFunction("__atomic_compare_exchange", FnType, Attr);
  CallInst *Succeeded = IRB.CreateCall(LibcallFn, Args);
  Succeeded->setAttributes(Attr);

  IRB.CreateLifetimeEnd(DesiredSlot, LifetimeSize);

  // On failure the routine writes the value it observed into the expected
  // slot, so read it back before closing the slot's lifetime.
  Value *Observed =
      IRB.CreateAlignedLoad(CASExpected->getType(), ExpectedSlot, SlotAlign);
  IRB.CreateLifetimeEnd(ExpectedSlot, LifetimeSize);

  // Rebuild the {old value, success} aggregate that I would have produced.
  Value *Aggregate = PoisonValue::get(I->getType());
  Aggregate = IRB.CreateInsertValue(Aggregate, Observed, 0);
  Aggregate = IRB.CreateInsertValue(Aggregate, Succeeded, 1);
  I->replaceAllUsesWith(Aggregate);

  Value *PreviousVal = IRB.CreateExtractValue(Aggregate, /*Idxs=*/0);
  Value *SuccessFailureVal = IRB.CreateExtractValue(Aggregate, /*Idxs=*/1);
  PHI->addIncoming(PreviousVal, IRB.GetInsertBlock());
  IRB.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);
  return true;
}
6098+
6099+
/// Replace the atomic load \p I with a call to the generic
/// `__atomic_load(size, ptr, ret, order)` library routine.
///
/// A result temporary of \p I's type is allocated at the start of the entry
/// block of the function containing \p I; the routine writes the loaded value
/// into it, and the value is then read back into \p LoadedVal, which replaces
/// every use of \p I. \p I itself is not erased; that is left to the caller.
///
/// \param I              Instruction to replace; new IR is inserted before it.
///                       Its type must not be void.
/// \param Size           Size in bytes of the loaded object.
/// \param Alignment      Currently unused; the temporary uses the preferred
///                       alignment of an integer of \p Size bytes.
/// \param PointerOperand Address of the atomic object.
/// \param ValueOperand   Unused; kept for signature parity with the
///                       compare-exchange helper.
/// \param CASExpected    Unused; a load has no expected value.
/// \param Ordering       Memory ordering of the load; must be atomic.
/// \param Ordering2      Unused.
/// \param LoadedVal      [out] The value read back from the result temporary,
///                       or nullptr if nothing was emitted.
/// \returns true if the libcall was emitted, false (without touching the IR)
///          if \p I produces no value.
bool OpenMPIRBuilder::emitAtomicLoadLibCall(
    Instruction *I, unsigned Size, Align Alignment, Value *PointerOperand,
    Value *ValueOperand, Value *CASExpected, AtomicOrdering Ordering,
    AtomicOrdering Ordering2, Value *&LoadedVal) {
  assert(I && PointerOperand && "load libcall needs an instruction and pointer");
  assert(!I->getType()->isVoidTy() && "atomic load must produce a value");
  assert(Ordering != AtomicOrdering::NotAtomic && "expect atomic MO");

  // The original code only created the result slot under a condition but
  // loaded from it unconditionally; refuse the unsupported case up front
  // instead of dereferencing a null alloca.
  if (I->getType()->isVoidTy()) {
    LoadedVal = nullptr;
    return false;
  }

  LLVMContext &Ctx = I->getContext();
  Module *Mod = I->getModule();
  const DataLayout &DL = Mod->getDataLayout();

  // Local builders (named so they do not shadow the class' own Builder):
  // one at the instruction being replaced, one at the top of the entry block
  // for the result temporary.
  IRBuilder<> IRB(I);
  IRBuilder<> EntryIRB(&I->getFunction()->getEntryBlock().front());

  const Align SlotAlign = DL.getPrefTypeAlign(Type::getIntNTy(Ctx, Size * 8));
  ConstantInt *LifetimeSize = ConstantInt::get(Type::getInt64Ty(Ctx), Size);

  // Argument order follows the generic libcall: size, ptr, ret*, order.
  SmallVector<Value *, 4> Args;
  Args.push_back(ConstantInt::get(DL.getIntPtrType(Ctx), Size));
  Args.push_back(
      IRB.CreateAddrSpaceCast(PointerOperand, PointerType::getUnqual(Ctx)));

  AllocaInst *ResultSlot = EntryIRB.CreateAlloca(I->getType());
  ResultSlot->setAlignment(SlotAlign);
  IRB.CreateLifetimeStart(ResultSlot, LifetimeSize);
  Args.push_back(ResultSlot);

  Args.push_back(ConstantInt::get(Type::getInt32Ty(Ctx),
                                  static_cast<int>(toCABI(Ordering))));

  SmallVector<Type *, 4> ArgTys;
  for (Value *Arg : Args)
    ArgTys.push_back(Arg->getType());

  // The routine returns nothing; the loaded value comes back through the
  // result slot.
  FunctionType *FnType =
      FunctionType::get(Type::getVoidTy(Ctx), ArgTys, /*isVarArg=*/false);
  FunctionCallee LibcallFn = Mod->getOrInsertFunction("__atomic_load", FnType);
  IRB.CreateCall(LibcallFn, Args);

  LoadedVal = IRB.CreateAlignedLoad(I->getType(), ResultSlot, SlotAlign);
  IRB.CreateLifetimeEnd(ResultSlot, LifetimeSize);
  I->replaceAllUsesWith(LoadedVal);
  return true;
}
6165+
59966166
std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
59976167
InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
59986168
AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
5999-
AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
6169+
AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr,
6170+
bool shouldEmitLibCall) {
60006171
// TODO: handle the case where XElemTy is not byte-sized or not a power of 2
60016172
// or a complex datatype.
60026173
bool emitRMWOp = false;
@@ -6018,7 +6189,7 @@ std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
60186189
emitRMWOp &= XElemTy->isIntegerTy();
60196190

60206191
std::pair<Value *, Value *> Res;
6021-
if (emitRMWOp) {
6192+
if (emitRMWOp && !shouldEmitLibCall) {
60226193
Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
60236194
// not needed except in case of postfix captures. Generate anyway for
60246195
// consistency with the else part. Will be removed with any DCE pass.
@@ -6027,6 +6198,64 @@ std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
60276198
Res.second = Res.first;
60286199
else
60296200
Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
6201+
} else if (shouldEmitLibCall) {
6202+
LoadInst *OldVal =
6203+
Builder.CreateLoad(XElemTy, X, X->getName() + ".atomic.load");
6204+
OldVal->setAtomic(AO);
6205+
const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
6206+
unsigned LoadSize =
6207+
LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
6208+
6209+
Value *LoadedVal = nullptr;
6210+
emitAtomicLoadLibCall(OldVal, LoadSize, OldVal->getAlign(),
6211+
OldVal->getPointerOperand(), nullptr, nullptr,
6212+
OldVal->getOrdering(), AtomicOrdering::NotAtomic,
6213+
LoadedVal);
6214+
6215+
BasicBlock *CurBB = Builder.GetInsertBlock();
6216+
Instruction *CurBBTI = CurBB->getTerminator();
6217+
CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
6218+
BasicBlock *ExitBB =
6219+
CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
6220+
BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
6221+
X->getName() + ".atomic.cont");
6222+
ContBB->getTerminator()->eraseFromParent();
6223+
Builder.restoreIP(AllocaIP);
6224+
AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
6225+
NewAtomicAddr->setName(X->getName() + "x.new.val");
6226+
Builder.SetInsertPoint(ContBB);
6227+
llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
6228+
PHI->addIncoming(LoadedVal, CurBB);
6229+
Value *OldExprVal = PHI;
6230+
6231+
Value *Upd = UpdateOp(OldExprVal, Builder);
6232+
Builder.CreateStore(Upd, NewAtomicAddr);
6233+
LoadInst *DesiredVal = Builder.CreateLoad(XElemTy, NewAtomicAddr);
6234+
AtomicOrdering Failure =
6235+
llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
6236+
AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
6237+
X, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
6238+
6239+
const DataLayout &DL = Result->getModule()->getDataLayout();
6240+
unsigned Size = DL.getTypeStoreSize(Result->getCompareOperand()->getType());
6241+
6242+
emitAtomicCompareExchangeLibCall(
6243+
Result, Size, Result->getAlign(), Result->getPointerOperand(),
6244+
Result->getNewValOperand(), Result->getCompareOperand(),
6245+
Result->getSuccessOrdering(), Result->getFailureOrdering(), PHI, ContBB,
6246+
ExitBB);
6247+
6248+
Result->eraseFromParent();
6249+
OldVal->eraseFromParent();
6250+
Res.first = OldExprVal;
6251+
Res.second = Upd;
6252+
if (UnreachableInst *ExitTI =
6253+
dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
6254+
CurBBTI->eraseFromParent();
6255+
Builder.SetInsertPoint(ExitBB);
6256+
} else {
6257+
Builder.SetInsertPoint(ExitTI);
6258+
}
60306259
} else {
60316260
IntegerType *IntCastTy =
60326261
IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,7 +1621,7 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
16211621

16221622
// Convert values and types.
16231623
auto &innerOpList = opInst.getRegion().front().getOperations();
1624-
bool isRegionArgUsed{false}, isXBinopExpr{false};
1624+
bool isRegionArgUsed{false}, isXBinopExpr{false}, shouldEmitLibCall{false};
16251625
llvm::AtomicRMWInst::BinOp binop;
16261626
mlir::Value mlirExpr;
16271627
// Find the binary update operation that uses the region argument
@@ -1640,8 +1640,7 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
16401640
}
16411641
}
16421642
if (!isRegionArgUsed)
1643-
return opInst.emitError("no atomic update operation with region argument"
1644-
" as operand found inside atomic.update region");
1643+
shouldEmitLibCall = true;
16451644

16461645
llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
16471646
llvm::Value *llvmX = moduleTranslation.lookupValue(opInst.getX());
@@ -1679,7 +1678,7 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
16791678
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
16801679
builder.restoreIP(ompBuilder->createAtomicUpdate(
16811680
ompLoc, allocaIP, llvmAtomicX, llvmExpr, atomicOrdering, binop, updateFn,
1682-
isXBinopExpr));
1681+
isXBinopExpr, shouldEmitLibCall));
16831682
return updateGenStatus;
16841683
}
16851684

0 commit comments

Comments
 (0)