Skip to content

Commit 47a6495

Browse files
committed
[OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode
This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call. Additionally, kernel type information is used to influence target device code generation and the `IsSPMD` flag is replaced by `ExecFlags`, which provide more granularity.
1 parent 45c6667 commit 47a6495

File tree

5 files changed

+386
-61
lines changed

5 files changed

+386
-61
lines changed

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "clang/AST/StmtVisitor.h"
2121
#include "clang/Basic/Cuda.h"
2222
#include "llvm/ADT/SmallPtrSet.h"
23+
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
2324
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
2425

2526
using namespace clang;
@@ -745,7 +746,9 @@ void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
745746
CodeGenFunction &CGF,
746747
EntryFunctionState &EST, bool IsSPMD) {
747748
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
748-
Attrs.IsSPMD = IsSPMD;
749+
Attrs.ExecFlags =
750+
IsSPMD ? llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD
751+
: llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC;
749752
computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
750753

751754
CGBuilderTy &Bld = CGF.Builder;

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,9 +1389,6 @@ class OpenMPIRBuilder {
13891389

13901390
/// Supporting functions for Reductions CodeGen.
13911391
private:
1392-
/// Emit the llvm.used metadata.
1393-
void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List);
1394-
13951392
/// Get the id of the current thread on the GPU.
13961393
Value *getGPUThreadID();
13971394

@@ -2013,6 +2010,13 @@ class OpenMPIRBuilder {
20132010
/// Value.
20142011
GlobalValue *createGlobalFlag(unsigned Value, StringRef Name);
20152012

2013+
/// Emit the llvm.used metadata.
2014+
void emitUsed(StringRef Name, ArrayRef<llvm::WeakTrackingVH> List);
2015+
2016+
/// Emit the kernel execution mode.
2017+
GlobalVariable *emitKernelExecutionMode(StringRef KernelName,
2018+
omp::OMPTgtExecModeFlags Mode);
2019+
20162020
/// Generate control flow and cleanup for cancellation.
20172021
///
20182022
/// \param CancelFlag Flag indicating if the cancellation is performed.
@@ -2233,13 +2237,34 @@ class OpenMPIRBuilder {
22332237
/// time. The number of max values will be 1 except for the case where
22342238
/// ompx_bare is set.
22352239
struct TargetKernelDefaultAttrs {
2236-
bool IsSPMD = false;
2240+
omp::OMPTgtExecModeFlags ExecFlags =
2241+
omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC;
22372242
SmallVector<int32_t, 3> MaxTeams = {-1};
22382243
int32_t MinTeams = 1;
22392244
SmallVector<int32_t, 3> MaxThreads = {-1};
22402245
int32_t MinThreads = 1;
22412246
};
22422247

2248+
/// Container to pass LLVM IR runtime values or constants related to the
2249+
/// number of teams and threads with which the kernel must be launched, as
2250+
/// well as the trip count of the loop, if it is an SPMD or Generic-SPMD
2251+
/// kernel. These must be defined in the host prior to the call to the kernel
2252+
/// launch OpenMP RTL function.
2253+
struct TargetKernelRuntimeAttrs {
2254+
SmallVector<Value *, 3> MaxTeams = {nullptr};
2255+
Value *MinTeams = nullptr;
2256+
SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
2257+
SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
2258+
2259+
/// 'parallel' construct 'num_threads' clause value, if present and it is an
2260+
/// SPMD kernel.
2261+
Value *MaxThreads = nullptr;
2262+
2263+
/// Total number of iterations of the SPMD or Generic-SPMD kernel or null if
2264+
/// it is a generic kernel.
2265+
Value *LoopTripCount = nullptr;
2266+
};
2267+
22432268
/// Data structure that contains the needed information to construct the
22442269
/// kernel args vector.
22452270
struct TargetKernelArgs {
@@ -2971,7 +2996,9 @@ class OpenMPIRBuilder {
29712996
/// \param CodeGenIP The insertion point where the call to the outlined
29722997
/// function should be emitted.
29732998
/// \param EntryInfo The entry information about the function.
2974-
/// \param DefaultAttrs Structure containing the default numbers of threads
2999+
/// \param DefaultAttrs Structure containing the default attributes, including
3000+
/// numbers of threads and teams to launch the kernel with.
3001+
/// \param RuntimeAttrs Structure containing the runtime numbers of threads
29753002
/// and teams to launch the kernel with.
29763003
/// \param Inputs The input values to the region that will be passed.
29773004
/// as arguments to the outlined function.
@@ -2987,6 +3014,7 @@ class OpenMPIRBuilder {
29873014
OpenMPIRBuilder::InsertPointTy CodeGenIP,
29883015
TargetRegionEntryInfo &EntryInfo,
29893016
const TargetKernelDefaultAttrs &DefaultAttrs,
3017+
const TargetKernelRuntimeAttrs &RuntimeAttrs,
29903018
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
29913019
TargetBodyGenCallbackTy BodyGenCB,
29923020
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 92 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
830830
return GV;
831831
}
832832

833+
void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
834+
if (List.empty())
835+
return;
836+
837+
// Convert List to what ConstantArray needs.
838+
SmallVector<Constant *, 8> UsedArray;
839+
UsedArray.resize(List.size());
840+
for (unsigned I = 0, E = List.size(); I != E; ++I)
841+
UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
842+
cast<Constant>(&*List[I]), Builder.getPtrTy());
843+
844+
if (UsedArray.empty())
845+
return;
846+
ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
847+
848+
auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
849+
ConstantArray::get(ATy, UsedArray), Name);
850+
851+
GV->setSection("llvm.metadata");
852+
}
853+
854+
GlobalVariable *
855+
OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
856+
OMPTgtExecModeFlags Mode) {
857+
auto *Int8Ty = Builder.getInt8Ty();
858+
auto *GVMode = new GlobalVariable(
859+
M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
860+
ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
861+
GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
862+
return GVMode;
863+
}
864+
833865
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
834866
uint32_t SrcLocStrSize,
835867
IdentFlag LocFlags,
@@ -2260,28 +2292,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
22602292
return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
22612293
}
22622294

2263-
void OpenMPIRBuilder::emitUsed(StringRef Name,
2264-
std::vector<WeakTrackingVH> &List) {
2265-
if (List.empty())
2266-
return;
2267-
2268-
// Convert List to what ConstantArray needs.
2269-
SmallVector<Constant *, 8> UsedArray;
2270-
UsedArray.resize(List.size());
2271-
for (unsigned I = 0, E = List.size(); I != E; ++I)
2272-
UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2273-
cast<Constant>(&*List[I]), Builder.getPtrTy());
2274-
2275-
if (UsedArray.empty())
2276-
return;
2277-
ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
2278-
2279-
auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
2280-
ConstantArray::get(ATy, UsedArray), Name);
2281-
2282-
GV->setSection("llvm.metadata");
2283-
}
2284-
22852295
Value *OpenMPIRBuilder::getGPUThreadID() {
22862296
return Builder.CreateCall(
22872297
getOrCreateRuntimeFunction(M,
@@ -6140,10 +6150,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
61406150
uint32_t SrcLocStrSize;
61416151
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
61426152
Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6143-
Constant *IsSPMDVal = ConstantInt::getSigned(
6144-
Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6145-
Constant *UseGenericStateMachineVal =
6146-
ConstantInt::getSigned(Int8, !Attrs.IsSPMD);
6153+
Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags);
6154+
Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
6155+
Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
61476156
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
61486157
Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
61496158

@@ -6778,6 +6787,12 @@ static Expected<Function *> createOutlinedFunction(
67786787
auto Func =
67796788
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
67806789

6790+
if (OMPBuilder.Config.isTargetDevice()) {
6791+
Value *ExecMode =
6792+
OMPBuilder.emitKernelExecutionMode(FuncName, DefaultAttrs.ExecFlags);
6793+
OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
6794+
}
6795+
67816796
// Save insert point.
67826797
IRBuilder<>::InsertPointGuard IPG(Builder);
67836798
// If there's a DISubprogram associated with current function, then
@@ -7325,6 +7340,7 @@ static void
73257340
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73267341
OpenMPIRBuilder::InsertPointTy AllocaIP,
73277342
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7343+
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
73287344
Function *OutlinedFn, Constant *OutlinedFnID,
73297345
SmallVectorImpl<Value *> &Args,
73307346
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7406,11 +7422,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74067422
/*ForEndCall=*/false);
74077423

74087424
SmallVector<Value *, 3> NumTeamsC;
7425+
for (auto [DefaultVal, RuntimeVal] :
7426+
zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
7427+
NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
7428+
7429+
// Calculate number of threads: 0 if no clauses specified, otherwise it is the
7430+
// minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7431+
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7432+
if (Clause)
7433+
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7434+
/*isSigned=*/false);
7435+
return Clause;
7436+
};
7437+
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7438+
if (Clause)
7439+
Result = Result
7440+
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7441+
Result, Clause)
7442+
: Clause;
7443+
};
7444+
7445+
// If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7446+
// the NUM_THREADS clause is overriden by THREAD_LIMIT.
74097447
SmallVector<Value *, 3> NumThreadsC;
7410-
for (auto V : DefaultAttrs.MaxTeams)
7411-
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7412-
for (auto V : DefaultAttrs.MaxThreads)
7413-
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7448+
Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
7449+
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
7450+
: nullptr;
7451+
7452+
for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
7453+
RuntimeAttrs.TargetThreadLimit)) {
7454+
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
7455+
Value *NumThreads = InitMaxThreadsClause(TargetVal);
7456+
7457+
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7458+
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7459+
7460+
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7461+
}
74147462

74157463
unsigned NumTargetItems = Info.NumberOfPtrs;
74167464
// TODO: Use correct device ID
@@ -7419,14 +7467,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74197467
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
74207468
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
74217469
llvm::omp::IdentFlag(0), 0);
7422-
// TODO: Use correct NumIterations
7423-
Value *NumIterations = Builder.getInt64(0);
7470+
7471+
Value *TripCount = RuntimeAttrs.LoopTripCount
7472+
? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
7473+
Builder.getInt64Ty(),
7474+
/*isSigned=*/false)
7475+
: Builder.getInt64(0);
7476+
74247477
// TODO: Use correct DynCGGroupMem
74257478
Value *DynCGGroupMem = Builder.getInt32(0);
74267479

7427-
KArgs = OpenMPIRBuilder::TargetKernelArgs(
7428-
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7429-
DynCGGroupMem, HasNoWait);
7480+
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
7481+
NumTeamsC, NumThreadsC,
7482+
DynCGGroupMem, HasNoWait);
74307483

74317484
// The presence of certain clauses on the target directive require the
74327485
// explicit generation of the target task.
@@ -7451,6 +7504,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74517504
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
74527505
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
74537506
const TargetKernelDefaultAttrs &DefaultAttrs,
7507+
const TargetKernelRuntimeAttrs &RuntimeAttrs,
74547508
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
74557509
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
74567510
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7475,8 +7529,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74757529
// to make a remote call (offload) to the previously outlined function
74767530
// that represents the target region. Do that now.
74777531
if (!Config.isTargetDevice())
7478-
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7479-
OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7532+
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7533+
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7534+
HasNowait);
74807535
return Builder.saveIP();
74817536
}
74827537

0 commit comments

Comments
 (0)