Skip to content

Commit f120456

Browse files
committed
[OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode
This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call. Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to influence target device code generation.
1 parent 1fcfe48 commit f120456

File tree

4 files changed

+383
-34
lines changed

4 files changed

+383
-34
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
22372237
int32_t MinThreads = 1;
22382238
};
22392239

2240+
/// Container to pass LLVM IR runtime values or constants related to the
2241+
/// number of teams and threads with which the kernel must be launched, as
2242+
/// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
2243+
/// must be defined in the host prior to the call to the kernel launch OpenMP
2244+
/// RTL function.
2245+
struct TargetKernelRuntimeAttrs {
2246+
SmallVector<Value *, 3> MaxTeams = {nullptr};
2247+
Value *MinTeams = nullptr;
2248+
SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
2249+
SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
2250+
2251+
/// 'parallel' construct 'num_threads' clause value, if present and it is a
2252+
/// target SPMD kernel.
2253+
Value *MaxThreads = nullptr;
2254+
2255+
/// Total number of iterations of the target SPMD kernel or null if it is a
2256+
/// generic kernel.
2257+
Value *LoopTripCount = nullptr;
2258+
};
2259+
22402260
/// Data structure that contains the needed information to construct the
22412261
/// kernel args vector.
22422262
struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
29052925
///
29062926
/// \param Loc where the target data construct was encountered.
29072927
/// \param IsOffloadEntry whether it is an offload entry.
2928+
/// \param IsSPMD whether it is a target SPMD kernel.
29082929
/// \param CodeGenIP The insertion point where the call to the outlined
29092930
/// function should be emitted.
29102931
/// \param EntryInfo The entry information about the function.
29112932
/// \param DefaultAttrs Structure containing the default numbers of threads
29122933
/// and teams to launch the kernel with.
2934+
/// \param RuntimeAttrs Structure containing the runtime numbers of threads
2935+
/// and teams to launch the kernel with.
29132936
/// \param Inputs The input values to the region that will be passed.
29142937
/// as arguments to the outlined function.
29152938
/// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
29192942
// dependency information as passed in the depend clause
29202943
// \param HasNowait Whether the target construct has a `nowait` clause or not.
29212944
InsertPointOrErrorTy createTarget(
2922-
const LocationDescription &Loc, bool IsOffloadEntry,
2945+
const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
29232946
OpenMPIRBuilder::InsertPointTy AllocaIP,
29242947
OpenMPIRBuilder::InsertPointTy CodeGenIP,
29252948
TargetRegionEntryInfo &EntryInfo,
29262949
const TargetKernelDefaultAttrs &DefaultAttrs,
2950+
const TargetKernelRuntimeAttrs &RuntimeAttrs,
29272951
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
29282952
TargetBodyGenCallbackTy BodyGenCB,
29292953
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 106 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
67276727
return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
67286728
}
67296729

6730+
/// Emit a global array named \p Name that references every value in \p List.
///
/// Does nothing when \p List is empty. Otherwise each entry is cast to a
/// pointer in address space 0 and collected into a constant array, which
/// becomes the initializer of a newly created, non-constant global variable
/// with appending linkage placed in the "llvm.metadata" section.
///
/// A new global is created on every call; this function does not look up or
/// extend an existing variable called \p Name.
///
/// \param Name Name of the global to create (e.g. "llvm.compiler.used").
/// \param List Values to reference; each entry must refer to a Constant,
///             as enforced by the cast<Constant> below.
/// \param M    Module that receives the new global.
static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
                     Module &M) {
  if (List.empty())
    return;

  Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);

  // Convert List to what ConstantArray needs.
  SmallVector<Constant *, 8> UsedArray;
  UsedArray.reserve(List.size());
  // Iterate by reference: copying a WeakTrackingVH per element is
  // unnecessary work, the handle is only dereferenced here.
  for (const llvm::WeakTrackingVH &Item : List)
    UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        cast<Constant>(&*Item), PtrTy));

  ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
  auto *GV =
      new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
                         llvm::ConstantArray::get(ArrTy, UsedArray), Name);

  GV->setSection("llvm.metadata");
}
6751+
6752+
static void
6753+
emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6754+
StringRef FunctionName, OMPTgtExecModeFlags Mode,
6755+
std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
6756+
auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
6757+
auto *GVMode = new llvm::GlobalVariable(
6758+
OMPBuilder.M, Int8Ty, /*isConstant=*/true,
6759+
llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
6760+
Twine(FunctionName, "_exec_mode"));
6761+
GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
6762+
LLVMCompilerUsed.emplace_back(GVMode);
6763+
}
6764+
67306765
static Expected<Function *> createOutlinedFunction(
6731-
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6766+
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
67326767
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
67336768
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
67346769
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,15 @@ static Expected<Function *> createOutlinedFunction(
67586793
auto Func =
67596794
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
67606795

6796+
if (OMPBuilder.Config.isTargetDevice()) {
6797+
std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
6798+
emitExecutionMode(OMPBuilder, Builder, FuncName,
6799+
IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
6800+
: OMP_TGT_EXEC_MODE_GENERIC,
6801+
LLVMCompilerUsed);
6802+
emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
6803+
}
6804+
67616805
// Save insert point.
67626806
IRBuilder<>::InsertPointGuard IPG(Builder);
67636807
// If there's a DISubprogram associated with current function, then
@@ -6798,7 +6842,7 @@ static Expected<Function *> createOutlinedFunction(
67986842
// Insert target init call in the device compilation pass.
67996843
if (OMPBuilder.Config.isTargetDevice())
68006844
Builder.restoreIP(
6801-
OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
6845+
OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
68026846

68036847
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
68046848

@@ -6995,7 +7039,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
69957039

69967040
static Error emitTargetOutlinedFunction(
69977041
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
6998-
TargetRegionEntryInfo &EntryInfo,
7042+
bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
69997043
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
70007044
Function *&OutlinedFn, Constant *&OutlinedFnID,
70017045
SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7048,7 @@ static Error emitTargetOutlinedFunction(
70047048

70057049
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
70067050
[&](StringRef EntryFnName) {
7007-
return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
7051+
return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
70087052
EntryFnName, Inputs, CBFunc,
70097053
ArgAccessorFuncCB);
70107054
};
@@ -7304,6 +7348,7 @@ static void
73047348
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73057349
OpenMPIRBuilder::InsertPointTy AllocaIP,
73067350
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7351+
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
73077352
Function *OutlinedFn, Constant *OutlinedFnID,
73087353
SmallVectorImpl<Value *> &Args,
73097354
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7430,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73857430
/*ForEndCall=*/false);
73867431

73877432
SmallVector<Value *, 3> NumTeamsC;
7433+
for (auto [DefaultVal, RuntimeVal] :
7434+
zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
7435+
NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
7436+
7437+
// Calculate number of threads: 0 if no clauses specified, otherwise it is the
7438+
// minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7439+
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7440+
if (Clause)
7441+
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7442+
/*isSigned=*/false);
7443+
return Clause;
7444+
};
7445+
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7446+
if (Clause)
7447+
Result = Result
7448+
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7449+
Result, Clause)
7450+
: Clause;
7451+
};
7452+
7453+
// If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7454+
// the NUM_THREADS clause is overridden by THREAD_LIMIT.
73887455
SmallVector<Value *, 3> NumThreadsC;
7389-
for (auto V : DefaultAttrs.MaxTeams)
7390-
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7391-
for (auto V : DefaultAttrs.MaxThreads)
7392-
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7456+
Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
7457+
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
7458+
: nullptr;
7459+
7460+
for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
7461+
RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
7462+
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
7463+
Value *NumThreads = InitMaxThreadsClause(TargetVal);
7464+
7465+
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7466+
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7467+
7468+
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7469+
}
73937470

73947471
unsigned NumTargetItems = Info.NumberOfPtrs;
73957472
// TODO: Use correct device ID
@@ -7398,14 +7475,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73987475
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
73997476
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
74007477
llvm::omp::IdentFlag(0), 0);
7401-
// TODO: Use correct NumIterations
7402-
Value *NumIterations = Builder.getInt64(0);
7478+
7479+
Value *TripCount = RuntimeAttrs.LoopTripCount
7480+
? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
7481+
Builder.getInt64Ty(),
7482+
/*isSigned=*/false)
7483+
: Builder.getInt64(0);
7484+
74037485
// TODO: Use correct DynCGGroupMem
74047486
Value *DynCGGroupMem = Builder.getInt32(0);
74057487

7406-
KArgs = OpenMPIRBuilder::TargetKernelArgs(
7407-
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7408-
DynCGGroupMem, HasNoWait);
7488+
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
7489+
NumTeamsC, NumThreadsC,
7490+
DynCGGroupMem, HasNoWait);
74097491

74107492
// The presence of certain clauses on the target directive require the
74117493
// explicit generation of the target task.
@@ -7427,13 +7509,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74277509
}
74287510

74297511
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7430-
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7431-
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7512+
const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
7513+
InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
7514+
TargetRegionEntryInfo &EntryInfo,
74327515
const TargetKernelDefaultAttrs &DefaultAttrs,
7516+
const TargetKernelRuntimeAttrs &RuntimeAttrs,
74337517
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
74347518
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
74357519
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
74367520
SmallVector<DependData> Dependencies, bool HasNowait) {
7521+
assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
7522+
"trip count not expected if IsSPMD=false");
74377523

74387524
if (!updateToLocation(Loc))
74397525
return InsertPointTy();
@@ -7446,16 +7532,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74467532
// the target region itself is generated using the callbacks CBFunc
74477533
// and ArgAccessorFuncCB
74487534
if (Error Err = emitTargetOutlinedFunction(
7449-
*this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
7450-
OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7535+
*this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
7536+
OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
74517537
return Err;
74527538

74537539
// If we are not on the target device, then we need to generate code
74547540
// to make a remote call (offload) to the previously outlined function
74557541
// that represents the target region. Do that now.
74567542
if (!Config.isTargetDevice())
7457-
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7458-
OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7543+
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7544+
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7545+
HasNowait);
74597546
return Builder.saveIP();
74607547
}
74617548

0 commit comments

Comments
 (0)