Skip to content

Commit 2fbe762

Browse files
committed
[OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode
This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call. Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to influence target device code generation.
1 parent e3cdc93 commit 2fbe762

File tree

4 files changed

+383
-34
lines changed

4 files changed

+383
-34
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
22372237
int32_t MinThreads = 1;
22382238
};
22392239

2240+
/// Container to pass LLVM IR runtime values or constants related to the
2241+
/// number of teams and threads with which the kernel must be launched, as
2242+
/// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
2243+
/// must be defined in the host prior to the call to the kernel launch OpenMP
2244+
/// RTL function.
2245+
struct TargetKernelRuntimeAttrs {
2246+
SmallVector<Value *, 3> MaxTeams = {nullptr};
2247+
Value *MinTeams = nullptr;
2248+
SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
2249+
SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
2250+
2251+
/// 'parallel' construct 'num_threads' clause value, if present and it is a
2252+
/// target SPMD kernel.
2253+
Value *MaxThreads = nullptr;
2254+
2255+
/// Total number of iterations of the target SPMD kernel or null if it is a
2256+
/// generic kernel.
2257+
Value *LoopTripCount = nullptr;
2258+
};
2259+
22402260
/// Data structure that contains the needed information to construct the
22412261
/// kernel args vector.
22422262
struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
29052925
///
29062926
/// \param Loc where the target data construct was encountered.
29072927
/// \param IsOffloadEntry whether it is an offload entry.
2928+
/// \param IsSPMD whether it is a target SPMD kernel.
29082929
/// \param CodeGenIP The insertion point where the call to the outlined
29092930
/// function should be emitted.
29102931
/// \param EntryInfo The entry information about the function.
29112932
/// \param DefaultAttrs Structure containing the default numbers of threads
29122933
/// and teams to launch the kernel with.
2934+
/// \param RuntimeAttrs Structure containing the runtime numbers of threads
2935+
/// and teams to launch the kernel with.
29132936
/// \param Inputs The input values to the region that will be passed
29142937
/// as arguments to the outlined function.
29152938
/// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
29192942
// dependency information as passed in the depend clause
29202943
// \param HasNowait Whether the target construct has a `nowait` clause or not.
29212944
InsertPointOrErrorTy createTarget(
2922-
const LocationDescription &Loc, bool IsOffloadEntry,
2945+
const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
29232946
OpenMPIRBuilder::InsertPointTy AllocaIP,
29242947
OpenMPIRBuilder::InsertPointTy CodeGenIP,
29252948
TargetRegionEntryInfo &EntryInfo,
29262949
const TargetKernelDefaultAttrs &DefaultAttrs,
2950+
const TargetKernelRuntimeAttrs &RuntimeAttrs,
29272951
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
29282952
TargetBodyGenCallbackTy BodyGenCB,
29292953
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 106 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6731,8 +6731,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
67316731
return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
67326732
}
67336733

6734+
static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
6735+
Module &M) {
6736+
if (List.empty())
6737+
return;
6738+
6739+
Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
6740+
6741+
// Convert List to what ConstantArray needs.
6742+
SmallVector<Constant *, 8> UsedArray;
6743+
UsedArray.reserve(List.size());
6744+
for (auto Item : List)
6745+
UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
6746+
cast<Constant>(&*Item), PtrTy));
6747+
6748+
ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
6749+
auto *GV =
6750+
new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
6751+
llvm::ConstantArray::get(ArrTy, UsedArray), Name);
6752+
6753+
GV->setSection("llvm.metadata");
6754+
}
6755+
6756+
static void
6757+
emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6758+
StringRef FunctionName, OMPTgtExecModeFlags Mode,
6759+
std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
6760+
auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
6761+
auto *GVMode = new llvm::GlobalVariable(
6762+
OMPBuilder.M, Int8Ty, /*isConstant=*/true,
6763+
llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
6764+
Twine(FunctionName, "_exec_mode"));
6765+
GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
6766+
LLVMCompilerUsed.emplace_back(GVMode);
6767+
}
6768+
67346769
static Expected<Function *> createOutlinedFunction(
6735-
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6770+
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
67366771
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
67376772
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
67386773
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6762,6 +6797,15 @@ static Expected<Function *> createOutlinedFunction(
67626797
auto Func =
67636798
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
67646799

6800+
if (OMPBuilder.Config.isTargetDevice()) {
6801+
std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
6802+
emitExecutionMode(OMPBuilder, Builder, FuncName,
6803+
IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
6804+
: OMP_TGT_EXEC_MODE_GENERIC,
6805+
LLVMCompilerUsed);
6806+
emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
6807+
}
6808+
67656809
// Save insert point.
67666810
IRBuilder<>::InsertPointGuard IPG(Builder);
67676811
// If there's a DISubprogram associated with current function, then
@@ -6802,7 +6846,7 @@ static Expected<Function *> createOutlinedFunction(
68026846
// Insert target init call in the device compilation pass.
68036847
if (OMPBuilder.Config.isTargetDevice())
68046848
Builder.restoreIP(
6805-
OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
6849+
OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
68066850

68076851
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
68086852

@@ -6998,7 +7042,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
69987042

69997043
static Error emitTargetOutlinedFunction(
70007044
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
7001-
TargetRegionEntryInfo &EntryInfo,
7045+
bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
70027046
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
70037047
Function *&OutlinedFn, Constant *&OutlinedFnID,
70047048
SmallVectorImpl<Value *> &Inputs,
@@ -7007,7 +7051,7 @@ static Error emitTargetOutlinedFunction(
70077051

70087052
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
70097053
[&](StringRef EntryFnName) {
7010-
return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
7054+
return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
70117055
EntryFnName, Inputs, CBFunc,
70127056
ArgAccessorFuncCB);
70137057
};
@@ -7307,6 +7351,7 @@ static void
73077351
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73087352
OpenMPIRBuilder::InsertPointTy AllocaIP,
73097353
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7354+
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
73107355
Function *OutlinedFn, Constant *OutlinedFnID,
73117356
SmallVectorImpl<Value *> &Args,
73127357
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7388,11 +7433,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73887433
/*ForEndCall=*/false);
73897434

73907435
SmallVector<Value *, 3> NumTeamsC;
7436+
for (auto [DefaultVal, RuntimeVal] :
7437+
zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
7438+
NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
7439+
7440+
// Calculate number of threads: 0 if no clauses specified, otherwise it is the
7441+
// minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7442+
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7443+
if (Clause)
7444+
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7445+
/*isSigned=*/false);
7446+
return Clause;
7447+
};
7448+
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7449+
if (Clause)
7450+
Result = Result
7451+
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7452+
Result, Clause)
7453+
: Clause;
7454+
};
7455+
7456+
// If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7457+
// the NUM_THREADS clause is overridden by THREAD_LIMIT.
73917458
SmallVector<Value *, 3> NumThreadsC;
7392-
for (auto V : DefaultAttrs.MaxTeams)
7393-
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7394-
for (auto V : DefaultAttrs.MaxThreads)
7395-
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7459+
Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
7460+
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
7461+
: nullptr;
7462+
7463+
for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
7464+
RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
7465+
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
7466+
Value *NumThreads = InitMaxThreadsClause(TargetVal);
7467+
7468+
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7469+
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7470+
7471+
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7472+
}
73967473

73977474
unsigned NumTargetItems = Info.NumberOfPtrs;
73987475
// TODO: Use correct device ID
@@ -7401,14 +7478,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74017478
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
74027479
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
74037480
llvm::omp::IdentFlag(0), 0);
7404-
// TODO: Use correct NumIterations
7405-
Value *NumIterations = Builder.getInt64(0);
7481+
7482+
Value *TripCount = RuntimeAttrs.LoopTripCount
7483+
? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
7484+
Builder.getInt64Ty(),
7485+
/*isSigned=*/false)
7486+
: Builder.getInt64(0);
7487+
74067488
// TODO: Use correct DynCGGroupMem
74077489
Value *DynCGGroupMem = Builder.getInt32(0);
74087490

7409-
KArgs = OpenMPIRBuilder::TargetKernelArgs(
7410-
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7411-
DynCGGroupMem, HasNoWait);
7491+
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
7492+
NumTeamsC, NumThreadsC,
7493+
DynCGGroupMem, HasNoWait);
74127494

74137495
// The presence of certain clauses on the target directive require the
74147496
// explicit generation of the target task.
@@ -7430,13 +7512,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74307512
}
74317513

74327514
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7433-
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7434-
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7515+
const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
7516+
InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
7517+
TargetRegionEntryInfo &EntryInfo,
74357518
const TargetKernelDefaultAttrs &DefaultAttrs,
7519+
const TargetKernelRuntimeAttrs &RuntimeAttrs,
74367520
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
74377521
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
74387522
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
74397523
SmallVector<DependData> Dependencies, bool HasNowait) {
7524+
assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
7525+
"trip count not expected if IsSPMD=false");
74407526

74417527
if (!updateToLocation(Loc))
74427528
return InsertPointTy();
@@ -7449,16 +7535,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74497535
// the target region itself is generated using the callbacks CBFunc
74507536
// and ArgAccessorFuncCB
74517537
if (Error Err = emitTargetOutlinedFunction(
7452-
*this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
7453-
OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7538+
*this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
7539+
OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
74547540
return Err;
74557541

74567542
// If we are not on the target device, then we need to generate code
74577543
// to make a remote call (offload) to the previously outlined function
74587544
// that represents the target region. Do that now.
74597545
if (!Config.isTargetDevice())
7460-
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7461-
OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7546+
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7547+
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7548+
HasNowait);
74627549
return Builder.saveIP();
74637550
}
74647551

0 commit comments

Comments
 (0)