-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][OpenMP] Add codegen for teams reductions #133310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
78f53d7
2c58b5f
717a6ec
d793df5
44f0e7a
fcb0e90
f1aa930
623188b
d5eb4a2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3495,9 +3495,9 @@ checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos, | |
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU( | ||
const LocationDescription &Loc, InsertPointTy AllocaIP, | ||
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos, | ||
bool IsNoWait, bool IsTeamsReduction, bool HasDistribute, | ||
ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue, | ||
unsigned ReductionBufNum, Value *SrcLocInfo) { | ||
bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind, | ||
std::optional<omp::GV> GridValue, unsigned ReductionBufNum, | ||
Value *SrcLocInfo) { | ||
if (!updateToLocation(Loc)) | ||
return InsertPointTy(); | ||
Builder.restoreIP(CodeGenIP); | ||
|
@@ -3514,6 +3514,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU( | |
if (ReductionInfos.size() == 0) | ||
return Builder.saveIP(); | ||
|
||
BasicBlock *ContinuationBlock = nullptr; | ||
if (ReductionGenCBKind != ReductionGenCBKind::Clang) { | ||
// Copied code from createReductions | ||
BasicBlock *InsertBlock = Loc.IP.getBlock(); | ||
ContinuationBlock = | ||
InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize"); | ||
InsertBlock->getTerminator()->eraseFromParent(); | ||
Builder.SetInsertPoint(InsertBlock, InsertBlock->end()); | ||
} | ||
|
||
Function *CurFunc = Builder.GetInsertBlock()->getParent(); | ||
AttributeList FuncAttrs; | ||
AttrBuilder AttrBldr(Ctx); | ||
|
@@ -3669,11 +3679,21 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU( | |
ReductionFunc; | ||
}); | ||
} else { | ||
assert(false && "Unhandled ReductionGenCBKind"); | ||
Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs"); | ||
Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs"); | ||
Value *Reduced; | ||
InsertPointOrErrorTy AfterIP = | ||
RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced); | ||
if (!AfterIP) | ||
return AfterIP.takeError(); | ||
Builder.CreateStore(Reduced, LHS, false); | ||
} | ||
} | ||
emitBlock(ExitBB, CurFunc); | ||
|
||
if (ContinuationBlock) { | ||
Builder.CreateBr(ContinuationBlock); | ||
Builder.SetInsertPoint(ContinuationBlock); | ||
} | ||
Config.setEmitLLVMUsed(); | ||
|
||
return Builder.saveIP(); | ||
|
@@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) { | |
".omp.reduction.func", &M); | ||
} | ||
|
||
OpenMPIRBuilder::InsertPointOrErrorTy | ||
OpenMPIRBuilder::createReductions(const LocationDescription &Loc, | ||
InsertPointTy AllocaIP, | ||
ArrayRef<ReductionInfo> ReductionInfos, | ||
ArrayRef<bool> IsByRef, bool IsNoWait) { | ||
assert(ReductionInfos.size() == IsByRef.size()); | ||
for (const ReductionInfo &RI : ReductionInfos) { | ||
(void)RI; | ||
assert(RI.Variable && "expected non-null variable"); | ||
assert(RI.PrivateVariable && "expected non-null private variable"); | ||
assert(RI.ReductionGen && "expected non-null reduction generator callback"); | ||
assert(RI.Variable->getType() == RI.PrivateVariable->getType() && | ||
"expected variables and their private equivalents to have the same " | ||
"type"); | ||
assert(RI.Variable->getType()->isPointerTy() && | ||
"expected variables to be pointers"); | ||
static Error populateReductionFunction( | ||
Function *ReductionFunc, | ||
ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos, | ||
IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) { | ||
Module *Module = ReductionFunc->getParent(); | ||
BasicBlock *ReductionFuncBlock = | ||
BasicBlock::Create(Module->getContext(), "", ReductionFunc); | ||
Builder.SetInsertPoint(ReductionFuncBlock); | ||
Value *LHSArrayPtr = nullptr; | ||
Value *RHSArrayPtr = nullptr; | ||
if (IsGPU) { | ||
// Need to alloca memory here and deal with the pointers before getting | ||
// LHS/RHS pointers out | ||
// | ||
Argument *Arg0 = ReductionFunc->getArg(0); | ||
Argument *Arg1 = ReductionFunc->getArg(1); | ||
Type *Arg0Type = Arg0->getType(); | ||
Type *Arg1Type = Arg1->getType(); | ||
|
||
Value *LHSAlloca = | ||
Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr"); | ||
Value *RHSAlloca = | ||
Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr"); | ||
Value *LHSAddrCast = | ||
Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type); | ||
Value *RHSAddrCast = | ||
Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type); | ||
Builder.CreateStore(Arg0, LHSAddrCast); | ||
Builder.CreateStore(Arg1, RHSAddrCast); | ||
LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast); | ||
RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast); | ||
} else { | ||
LHSArrayPtr = ReductionFunc->getArg(0); | ||
RHSArrayPtr = ReductionFunc->getArg(1); | ||
} | ||
|
||
unsigned NumReductions = ReductionInfos.size(); | ||
Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions); | ||
|
||
for (auto En : enumerate(ReductionInfos)) { | ||
const OpenMPIRBuilder::ReductionInfo &RI = En.value(); | ||
Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64( | ||
RedArrayTy, LHSArrayPtr, 0, En.index()); | ||
Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr); | ||
Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast( | ||
LHSI8Ptr, RI.Variable->getType()); | ||
Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr); | ||
Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64( | ||
RedArrayTy, RHSArrayPtr, 0, En.index()); | ||
Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr); | ||
Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast( | ||
RHSI8Ptr, RI.PrivateVariable->getType()); | ||
Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr); | ||
Value *Reduced; | ||
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = | ||
RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced); | ||
if (!AfterIP) | ||
return AfterIP.takeError(); | ||
|
||
Builder.restoreIP(*AfterIP); | ||
// TODO: Consider flagging an error. | ||
if (!Builder.GetInsertBlock()) | ||
return Error::success(); | ||
Comment on lines
+3771
to
+3773
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the reduction callback returns an invalid IP, that seems like an implementation bug rather than some error condition based on the input the user should know about. Maybe we should just assert in this case, since there are plenty of other places where we assume the returned IP is valid if the callback didn't return an Having said that, I can see there are also a few other places where the IP is checked, so feel free to leave this if you think it's a better option. At some point, we should probably decide on one single approach and apply it everywhere, though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm going to skip this for now. I have some thoughts how to improve the IP handling, but we can discuss that separately. |
||
|
||
// store is inside of the reduction region when using by-ref | ||
if (!IsByRef[En.index()]) | ||
Builder.CreateStore(Reduced, LHSPtr); | ||
} | ||
Builder.CreateRetVoid(); | ||
return Error::success(); | ||
} | ||
|
||
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions( | ||
const LocationDescription &Loc, InsertPointTy AllocaIP, | ||
ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef, | ||
bool IsNoWait, bool IsTeamsReduction) { | ||
assert(ReductionInfos.size() == IsByRef.size()); | ||
if (Config.isGPU()) | ||
return createReductionsGPU(Loc, AllocaIP, Builder.saveIP(), ReductionInfos, | ||
IsNoWait, IsTeamsReduction); | ||
|
||
checkReductionInfos(ReductionInfos, /*IsGPU*/ false); | ||
|
||
if (!updateToLocation(Loc)) | ||
return InsertPointTy(); | ||
|
||
if (ReductionInfos.size() == 0) | ||
return Builder.saveIP(); | ||
|
||
BasicBlock *InsertBlock = Loc.IP.getBlock(); | ||
BasicBlock *ContinuationBlock = | ||
InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize"); | ||
|
@@ -3832,38 +3920,13 @@ OpenMPIRBuilder::createReductions(const LocationDescription &Loc, | |
// Populate the outlined reduction function using the elementwise reduction | ||
// function. Partial values are extracted from the type-erased array of | ||
// pointers to private variables. | ||
BasicBlock *ReductionFuncBlock = | ||
BasicBlock::Create(Module->getContext(), "", ReductionFunc); | ||
Builder.SetInsertPoint(ReductionFuncBlock); | ||
Value *LHSArrayPtr = ReductionFunc->getArg(0); | ||
Value *RHSArrayPtr = ReductionFunc->getArg(1); | ||
Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder, | ||
IsByRef, /*isGPU=*/false); | ||
if (Err) | ||
return Err; | ||
|
||
for (auto En : enumerate(ReductionInfos)) { | ||
const ReductionInfo &RI = En.value(); | ||
Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64( | ||
RedArrayTy, LHSArrayPtr, 0, En.index()); | ||
Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr); | ||
Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType()); | ||
Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr); | ||
Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64( | ||
RedArrayTy, RHSArrayPtr, 0, En.index()); | ||
Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr); | ||
Value *RHSPtr = | ||
Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType()); | ||
Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr); | ||
Value *Reduced; | ||
InsertPointOrErrorTy AfterIP = | ||
RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced); | ||
if (!AfterIP) | ||
return AfterIP.takeError(); | ||
Builder.restoreIP(*AfterIP); | ||
if (!Builder.GetInsertBlock()) | ||
return InsertPointTy(); | ||
// store is inside of the reduction region when using by-ref | ||
if (!IsByRef[En.index()]) | ||
Builder.CreateStore(Reduced, LHSPtr); | ||
} | ||
Builder.CreateRetVoid(); | ||
if (!Builder.GetInsertBlock()) | ||
return InsertPointTy(); | ||
|
||
Builder.SetInsertPoint(ContinuationBlock); | ||
return Builder.saveIP(); | ||
|
@@ -6239,8 +6302,10 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit( | |
Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal); | ||
Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams); | ||
Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front()); | ||
Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0); | ||
Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0); | ||
Constant *ReductionDataSize = | ||
ConstantInt::getSigned(Int32, Attrs.ReductionDataSize); | ||
Constant *ReductionBufferLength = | ||
ConstantInt::getSigned(Int32, Attrs.ReductionBufferLength); | ||
|
||
Function *Fn = getOrCreateRuntimeFunctionPtr( | ||
omp::RuntimeFunction::OMPRTL___kmpc_target_init); | ||
|
Uh oh!
There was an error while loading. Please reload this page.