Skip to content

[MLIR][OpenMP] Add codegen for teams reductions #133310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,6 @@ void CGOpenMPRuntimeGPU::emitReduction(
return;

bool ParallelReduction = isOpenMPParallelDirective(Options.ReductionKind);
bool DistributeReduction = isOpenMPDistributeDirective(Options.ReductionKind);
bool TeamsReduction = isOpenMPTeamsDirective(Options.ReductionKind);

ASTContext &C = CGM.getContext();
Expand Down Expand Up @@ -1757,7 +1756,7 @@ void CGOpenMPRuntimeGPU::emitReduction(
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
cantFail(OMPBuilder.createReductionsGPU(
OmpLoc, AllocaIP, CodeGenIP, ReductionInfos, false, TeamsReduction,
DistributeReduction, llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang,
llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang,
CGF.getTarget().getGridValue(),
C.getLangOpts().OpenMPCUDAReductionBufNum, RTLoc));
CGF.Builder.restoreIP(AfterIP);
Expand Down
10 changes: 6 additions & 4 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1907,8 +1907,6 @@ class OpenMPIRBuilder {
/// nowait.
/// \param IsTeamsReduction Optional flag set if it is a teams
/// reduction.
/// \param HasDistribute Optional flag set if it is a
/// distribute reduction.
/// \param GridValue Optional GPU grid value.
/// \param ReductionBufNum Optional OpenMPCUDAReductionBufNumValue to be
/// used for teams reduction.
Expand All @@ -1917,7 +1915,6 @@ class OpenMPIRBuilder {
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
bool IsNoWait = false, bool IsTeamsReduction = false,
bool HasDistribute = false,
ReductionGenCBKind ReductionGenCBKind = ReductionGenCBKind::MLIR,
std::optional<omp::GV> GridValue = {}, unsigned ReductionBufNum = 1024,
Value *SrcLocInfo = nullptr);
Expand Down Expand Up @@ -1985,11 +1982,14 @@ class OpenMPIRBuilder {
/// \param IsNoWait A flag set if the reduction is marked as nowait.
/// \param IsByRef A flag set if the reduction is using reference
/// or direct value.
/// \param IsTeamsReduction Optional flag set if it is a teams
/// reduction.
InsertPointOrErrorTy createReductions(const LocationDescription &Loc,
InsertPointTy AllocaIP,
ArrayRef<ReductionInfo> ReductionInfos,
ArrayRef<bool> IsByRef,
bool IsNoWait = false);
bool IsNoWait = false,
bool IsTeamsReduction = false);

///}

Expand Down Expand Up @@ -2273,6 +2273,8 @@ class OpenMPIRBuilder {
int32_t MinTeams = 1;
SmallVector<int32_t, 3> MaxThreads = {-1};
int32_t MinThreads = 1;
int32_t ReductionDataSize = 0;
int32_t ReductionBufferLength = 0;
};

/// Container to pass LLVM IR runtime values or constants related to the
Expand Down
173 changes: 119 additions & 54 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3495,9 +3495,9 @@ checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
unsigned ReductionBufNum, Value *SrcLocInfo) {
bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind,
std::optional<omp::GV> GridValue, unsigned ReductionBufNum,
Value *SrcLocInfo) {
if (!updateToLocation(Loc))
return InsertPointTy();
Builder.restoreIP(CodeGenIP);
Expand All @@ -3514,6 +3514,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
if (ReductionInfos.size() == 0)
return Builder.saveIP();

BasicBlock *ContinuationBlock = nullptr;
if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
// Copied code from createReductions
BasicBlock *InsertBlock = Loc.IP.getBlock();
ContinuationBlock =
InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
InsertBlock->getTerminator()->eraseFromParent();
Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
}

Function *CurFunc = Builder.GetInsertBlock()->getParent();
AttributeList FuncAttrs;
AttrBuilder AttrBldr(Ctx);
Expand Down Expand Up @@ -3669,11 +3679,21 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
ReductionFunc;
});
} else {
assert(false && "Unhandled ReductionGenCBKind");
Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs");
Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs");
Value *Reduced;
InsertPointOrErrorTy AfterIP =
RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
if (!AfterIP)
return AfterIP.takeError();
Builder.CreateStore(Reduced, LHS, false);
}
}
emitBlock(ExitBB, CurFunc);

if (ContinuationBlock) {
Builder.CreateBr(ContinuationBlock);
Builder.SetInsertPoint(ContinuationBlock);
}
Config.setEmitLLVMUsed();

return Builder.saveIP();
Expand All @@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) {
".omp.reduction.func", &M);
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
InsertPointTy AllocaIP,
ArrayRef<ReductionInfo> ReductionInfos,
ArrayRef<bool> IsByRef, bool IsNoWait) {
assert(ReductionInfos.size() == IsByRef.size());
for (const ReductionInfo &RI : ReductionInfos) {
(void)RI;
assert(RI.Variable && "expected non-null variable");
assert(RI.PrivateVariable && "expected non-null private variable");
assert(RI.ReductionGen && "expected non-null reduction generator callback");
assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
"expected variables and their private equivalents to have the same "
"type");
assert(RI.Variable->getType()->isPointerTy() &&
"expected variables to be pointers");
/// Emit the body of the outlined reduction function \p ReductionFunc.
///
/// A single unnamed entry block is created in \p ReductionFunc and \p Builder
/// is positioned in it. Arguments 0 and 1 of \p ReductionFunc are then treated
/// as pointers to arrays of `ReductionInfos.size()` pointers (LHS and RHS
/// respectively). For every reduction, the element pointers at the matching
/// index are loaded from both arrays, the pointed-to values are loaded using
/// the reduction's element type, and the reduction's `ReductionGen` callback
/// is invoked to combine them. The function body is terminated with a
/// `ret void`.
///
/// \param ReductionFunc  Function to populate; must have at least two
///                       arguments.
/// \param ReductionInfos Descriptions of the reductions, one per array slot.
/// \param Builder        IR builder used for emission; its insertion point is
///                       overwritten.
/// \param IsByRef        Per-reduction flag, indexed like \p ReductionInfos.
///                       When set, no store of the reduced value is emitted
///                       here.
/// \param IsGPU          When true, both arguments are first stored into
///                       fresh allocas and reloaded before being used.
///
/// \returns The error produced by a `ReductionGen` callback, otherwise
/// success. NOTE: if the builder has no insertion block after a callback,
/// success is returned immediately, without processing the remaining
/// reductions and without emitting the final `ret void`.
static Error populateReductionFunction(
Function *ReductionFunc,
ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
Module *Module = ReductionFunc->getParent();
BasicBlock *ReductionFuncBlock =
BasicBlock::Create(Module->getContext(), "", ReductionFunc);
Builder.SetInsertPoint(ReductionFuncBlock);
Value *LHSArrayPtr = nullptr;
Value *RHSArrayPtr = nullptr;
if (IsGPU) {
// Allocate stack slots for both arguments, store the incoming pointers into
// them and reload, so the LHS/RHS array pointers used below come from those
// slots rather than directly from the arguments.
// NOTE(review): the casts of the allocas to the argument pointer types are
// presumably needed because allocas may live in a different address space
// on GPU targets -- confirm.
Argument *Arg0 = ReductionFunc->getArg(0);
Argument *Arg1 = ReductionFunc->getArg(1);
Type *Arg0Type = Arg0->getType();
Type *Arg1Type = Arg1->getType();

Value *LHSAlloca =
Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
Value *RHSAlloca =
Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
Value *LHSAddrCast =
Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type);
Value *RHSAddrCast =
Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type);
Builder.CreateStore(Arg0, LHSAddrCast);
Builder.CreateStore(Arg1, RHSAddrCast);
LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
} else {
// Non-GPU path: use the two function arguments directly.
LHSArrayPtr = ReductionFunc->getArg(0);
RHSArrayPtr = ReductionFunc->getArg(1);
}

// Both arguments are indexed as arrays holding one pointer per reduction.
unsigned NumReductions = ReductionInfos.size();
Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);

for (auto En : enumerate(ReductionInfos)) {
const OpenMPIRBuilder::ReductionInfo &RI = En.value();
// Load the LHS element pointer from its array slot, cast it to the type of
// the reduction variable, and load the current LHS value.
Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
RedArrayTy, LHSArrayPtr, 0, En.index());
Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
LHSI8Ptr, RI.Variable->getType());
Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
// Same for the RHS, using the type of the private variable.
Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
RedArrayTy, RHSArrayPtr, 0, En.index());
Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
RHSI8Ptr, RI.PrivateVariable->getType());
Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
// Let the callback emit the combination of LHS and RHS; it reports the
// result through Reduced and returns the insertion point to continue from.
Value *Reduced;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
if (!AfterIP)
return AfterIP.takeError();

Builder.restoreIP(*AfterIP);
// TODO: Consider flagging an error.
if (!Builder.GetInsertBlock())
return Error::success();
Comment on lines +3771 to +3773
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the reduction callback returns an invalid IP, that seems like an implementation bug rather than some error condition based on the input the user should know about. Maybe we should just assert in this case, since there are plenty of other places where we assume the returned IP is valid if the callback didn't return an llvm::Error.

Having said that, I can see there are also a few other places where the IP is checked, so feel free to leave this if you think it's a better option. At some point, we should probably decide on one single approach and apply it everywhere, though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to skip this for now. I have some thoughts how to improve the IP handling, but we can discuss that separately.


// For by-ref reductions the store is emitted inside the reduction region, so
// only by-value reductions store the combined result back through LHSPtr.
if (!IsByRef[En.index()])
Builder.CreateStore(Reduced, LHSPtr);
}
Builder.CreateRetVoid();
return Error::success();
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
const LocationDescription &Loc, InsertPointTy AllocaIP,
ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
bool IsNoWait, bool IsTeamsReduction) {
assert(ReductionInfos.size() == IsByRef.size());
if (Config.isGPU())
return createReductionsGPU(Loc, AllocaIP, Builder.saveIP(), ReductionInfos,
IsNoWait, IsTeamsReduction);

checkReductionInfos(ReductionInfos, /*IsGPU*/ false);

if (!updateToLocation(Loc))
return InsertPointTy();

if (ReductionInfos.size() == 0)
return Builder.saveIP();

BasicBlock *InsertBlock = Loc.IP.getBlock();
BasicBlock *ContinuationBlock =
InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
Expand Down Expand Up @@ -3832,38 +3920,13 @@ OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
// Populate the outlined reduction function using the elementwise reduction
// function. Partial values are extracted from the type-erased array of
// pointers to private variables.
BasicBlock *ReductionFuncBlock =
BasicBlock::Create(Module->getContext(), "", ReductionFunc);
Builder.SetInsertPoint(ReductionFuncBlock);
Value *LHSArrayPtr = ReductionFunc->getArg(0);
Value *RHSArrayPtr = ReductionFunc->getArg(1);
Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
IsByRef, /*isGPU=*/false);
if (Err)
return Err;

for (auto En : enumerate(ReductionInfos)) {
const ReductionInfo &RI = En.value();
Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
RedArrayTy, LHSArrayPtr, 0, En.index());
Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
RedArrayTy, RHSArrayPtr, 0, En.index());
Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
Value *RHSPtr =
Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
Value *Reduced;
InsertPointOrErrorTy AfterIP =
RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
if (!AfterIP)
return AfterIP.takeError();
Builder.restoreIP(*AfterIP);
if (!Builder.GetInsertBlock())
return InsertPointTy();
// store is inside of the reduction region when using by-ref
if (!IsByRef[En.index()])
Builder.CreateStore(Reduced, LHSPtr);
}
Builder.CreateRetVoid();
if (!Builder.GetInsertBlock())
return InsertPointTy();

Builder.SetInsertPoint(ContinuationBlock);
return Builder.saveIP();
Expand Down Expand Up @@ -6239,8 +6302,10 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
Constant *ReductionDataSize =
ConstantInt::getSigned(Int32, Attrs.ReductionDataSize);
Constant *ReductionBufferLength =
ConstantInt::getSigned(Int32, Attrs.ReductionBufferLength);

Function *Fn = getOrCreateRuntimeFunctionPtr(
omp::RuntimeFunction::OMPRTL___kmpc_target_init);
Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2354,6 +2354,7 @@ TEST_F(OpenMPIRBuilderTest, StaticWorkshareLoopTarget) {
"256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8");
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.Config.IsTargetDevice = true;
OMPBuilder.Config.setIsGPU(false);
OMPBuilder.initialize();
IRBuilder<> Builder(BB);
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
Expand Down
Loading