Skip to content

Commit 0c19f71

Browse files
committed
[OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode
This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call. Additionally, kernel type information is used to influence target device code generation and the `IsSPMD` flag is replaced by `ExecFlags`, which provide more granularity.
1 parent 27bc6bd commit 0c19f71

File tree

5 files changed

+398
-67
lines changed

5 files changed

+398
-67
lines changed

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "clang/AST/StmtVisitor.h"
2121
#include "clang/Basic/Cuda.h"
2222
#include "llvm/ADT/SmallPtrSet.h"
23+
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
2324
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
2425

2526
using namespace clang;
@@ -745,7 +746,9 @@ void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
745746
CodeGenFunction &CGF,
746747
EntryFunctionState &EST, bool IsSPMD) {
747748
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
748-
Attrs.IsSPMD = IsSPMD;
749+
Attrs.ExecFlags =
750+
IsSPMD ? llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD
751+
: llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC;
749752
computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
750753

751754
CGBuilderTy &Bld = CGF.Builder;

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,9 +1389,6 @@ class OpenMPIRBuilder {
13891389

13901390
/// Supporting functions for Reductions CodeGen.
13911391
private:
1392-
/// Emit the llvm.used metadata.
1393-
void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List);
1394-
13951392
/// Get the id of the current thread on the GPU.
13961393
Value *getGPUThreadID();
13971394

@@ -2013,6 +2010,13 @@ class OpenMPIRBuilder {
20132010
/// Value.
20142011
GlobalValue *createGlobalFlag(unsigned Value, StringRef Name);
20152012

2013+
/// Emit the llvm.used metadata.
2014+
void emitUsed(StringRef Name, ArrayRef<llvm::WeakTrackingVH> List);
2015+
2016+
/// Emit the kernel execution mode.
2017+
GlobalVariable *emitKernelExecutionMode(StringRef KernelName,
2018+
omp::OMPTgtExecModeFlags Mode);
2019+
20162020
/// Generate control flow and cleanup for cancellation.
20172021
///
20182022
/// \param CancelFlag Flag indicating if the cancellation is performed.
@@ -2233,13 +2237,34 @@ class OpenMPIRBuilder {
22332237
/// time. The number of max values will be 1 except for the case where
22342238
/// ompx_bare is set.
22352239
struct TargetKernelDefaultAttrs {
2236-
bool IsSPMD = false;
2240+
omp::OMPTgtExecModeFlags ExecFlags =
2241+
omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC;
22372242
SmallVector<int32_t, 3> MaxTeams = {-1};
22382243
int32_t MinTeams = 1;
22392244
SmallVector<int32_t, 3> MaxThreads = {-1};
22402245
int32_t MinThreads = 1;
22412246
};
22422247

2248+
/// Container to pass LLVM IR runtime values or constants related to the
2249+
/// number of teams and threads with which the kernel must be launched, as
2250+
/// well as the trip count of the loop, if it is an SPMD or Generic-SPMD
2251+
/// kernel. These must be defined in the host prior to the call to the kernel
2252+
/// launch OpenMP RTL function.
2253+
struct TargetKernelRuntimeAttrs {
2254+
SmallVector<Value *, 3> MaxTeams = {nullptr};
2255+
Value *MinTeams = nullptr;
2256+
SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
2257+
SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
2258+
2259+
/// 'parallel' construct 'num_threads' clause value, if present and it is an
2260+
/// SPMD kernel.
2261+
Value *MaxThreads = nullptr;
2262+
2263+
/// Total number of iterations of the SPMD or Generic-SPMD kernel or null if
2264+
/// it is a generic kernel.
2265+
Value *LoopTripCount = nullptr;
2266+
};
2267+
22432268
/// Data structure that contains the needed information to construct the
22442269
/// kernel args vector.
22452270
struct TargetKernelArgs {
@@ -2971,7 +2996,9 @@ class OpenMPIRBuilder {
29712996
/// \param CodeGenIP The insertion point where the call to the outlined
29722997
/// function should be emitted.
29732998
/// \param EntryInfo The entry information about the function.
2974-
/// \param DefaultAttrs Structure containing the default numbers of threads
2999+
/// \param DefaultAttrs Structure containing the default attributes, including
3000+
/// numbers of threads and teams to launch the kernel with.
3001+
/// \param RuntimeAttrs Structure containing the runtime numbers of threads
29753002
/// and teams to launch the kernel with.
29763003
/// \param Inputs The input values to the region that will be passed.
29773004
/// as arguments to the outlined function.
@@ -2987,6 +3014,7 @@ class OpenMPIRBuilder {
29873014
OpenMPIRBuilder::InsertPointTy CodeGenIP,
29883015
TargetRegionEntryInfo &EntryInfo,
29893016
const TargetKernelDefaultAttrs &DefaultAttrs,
3017+
const TargetKernelRuntimeAttrs &RuntimeAttrs,
29903018
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
29913019
TargetBodyGenCallbackTy BodyGenCB,
29923020
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 92 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
830830
return GV;
831831
}
832832

833+
void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
834+
if (List.empty())
835+
return;
836+
837+
// Convert List to what ConstantArray needs.
838+
SmallVector<Constant *, 8> UsedArray;
839+
UsedArray.resize(List.size());
840+
for (unsigned I = 0, E = List.size(); I != E; ++I)
841+
UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
842+
cast<Constant>(&*List[I]), Builder.getPtrTy());
843+
844+
if (UsedArray.empty())
845+
return;
846+
ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
847+
848+
auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
849+
ConstantArray::get(ATy, UsedArray), Name);
850+
851+
GV->setSection("llvm.metadata");
852+
}
853+
854+
GlobalVariable *
855+
OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
856+
OMPTgtExecModeFlags Mode) {
857+
auto *Int8Ty = Builder.getInt8Ty();
858+
auto *GVMode = new GlobalVariable(
859+
M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
860+
ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
861+
GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
862+
return GVMode;
863+
}
864+
833865
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
834866
uint32_t SrcLocStrSize,
835867
IdentFlag LocFlags,
@@ -2260,28 +2292,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
22602292
return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
22612293
}
22622294

2263-
void OpenMPIRBuilder::emitUsed(StringRef Name,
2264-
std::vector<WeakTrackingVH> &List) {
2265-
if (List.empty())
2266-
return;
2267-
2268-
// Convert List to what ConstantArray needs.
2269-
SmallVector<Constant *, 8> UsedArray;
2270-
UsedArray.resize(List.size());
2271-
for (unsigned I = 0, E = List.size(); I != E; ++I)
2272-
UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2273-
cast<Constant>(&*List[I]), Builder.getPtrTy());
2274-
2275-
if (UsedArray.empty())
2276-
return;
2277-
ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
2278-
2279-
auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
2280-
ConstantArray::get(ATy, UsedArray), Name);
2281-
2282-
GV->setSection("llvm.metadata");
2283-
}
2284-
22852295
Value *OpenMPIRBuilder::getGPUThreadID() {
22862296
return Builder.CreateCall(
22872297
getOrCreateRuntimeFunction(M,
@@ -6131,10 +6141,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
61316141
uint32_t SrcLocStrSize;
61326142
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
61336143
Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6134-
Constant *IsSPMDVal = ConstantInt::getSigned(
6135-
Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6136-
Constant *UseGenericStateMachineVal =
6137-
ConstantInt::getSigned(Int8, !Attrs.IsSPMD);
6144+
Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags);
6145+
Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
6146+
Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
61386147
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
61396148
Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
61406149

@@ -6765,6 +6774,12 @@ static Expected<Function *> createOutlinedFunction(
67656774
auto Func =
67666775
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
67676776

6777+
if (OMPBuilder.Config.isTargetDevice()) {
6778+
Value *ExecMode =
6779+
OMPBuilder.emitKernelExecutionMode(FuncName, DefaultAttrs.ExecFlags);
6780+
OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
6781+
}
6782+
67686783
// Save insert point.
67696784
IRBuilder<>::InsertPointGuard IPG(Builder);
67706785
// If there's a DISubprogram associated with current function, then
@@ -7312,6 +7327,7 @@ static void
73127327
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73137328
OpenMPIRBuilder::InsertPointTy AllocaIP,
73147329
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7330+
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
73157331
Function *OutlinedFn, Constant *OutlinedFnID,
73167332
SmallVectorImpl<Value *> &Args,
73177333
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7393,11 +7409,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73937409
/*ForEndCall=*/false);
73947410

73957411
SmallVector<Value *, 3> NumTeamsC;
7412+
for (auto [DefaultVal, RuntimeVal] :
7413+
zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
7414+
NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
7415+
7416+
// Calculate number of threads: 0 if no clauses specified, otherwise it is the
7417+
// minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7418+
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7419+
if (Clause)
7420+
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7421+
/*isSigned=*/false);
7422+
return Clause;
7423+
};
7424+
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7425+
if (Clause)
7426+
Result = Result
7427+
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7428+
Result, Clause)
7429+
: Clause;
7430+
};
7431+
7432+
// If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7433+
// the NUM_THREADS clause is overriden by THREAD_LIMIT.
73967434
SmallVector<Value *, 3> NumThreadsC;
7397-
for (auto V : DefaultAttrs.MaxTeams)
7398-
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7399-
for (auto V : DefaultAttrs.MaxThreads)
7400-
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7435+
Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
7436+
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
7437+
: nullptr;
7438+
7439+
for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
7440+
RuntimeAttrs.TargetThreadLimit)) {
7441+
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
7442+
Value *NumThreads = InitMaxThreadsClause(TargetVal);
7443+
7444+
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7445+
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7446+
7447+
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7448+
}
74017449

74027450
unsigned NumTargetItems = Info.NumberOfPtrs;
74037451
// TODO: Use correct device ID
@@ -7406,14 +7454,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74067454
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
74077455
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
74087456
llvm::omp::IdentFlag(0), 0);
7409-
// TODO: Use correct NumIterations
7410-
Value *NumIterations = Builder.getInt64(0);
7457+
7458+
Value *TripCount = RuntimeAttrs.LoopTripCount
7459+
? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
7460+
Builder.getInt64Ty(),
7461+
/*isSigned=*/false)
7462+
: Builder.getInt64(0);
7463+
74117464
// TODO: Use correct DynCGGroupMem
74127465
Value *DynCGGroupMem = Builder.getInt32(0);
74137466

7414-
KArgs = OpenMPIRBuilder::TargetKernelArgs(
7415-
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7416-
DynCGGroupMem, HasNoWait);
7467+
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
7468+
NumTeamsC, NumThreadsC,
7469+
DynCGGroupMem, HasNoWait);
74177470

74187471
// The presence of certain clauses on the target directive require the
74197472
// explicit generation of the target task.
@@ -7438,6 +7491,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74387491
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
74397492
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
74407493
const TargetKernelDefaultAttrs &DefaultAttrs,
7494+
const TargetKernelRuntimeAttrs &RuntimeAttrs,
74417495
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
74427496
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
74437497
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7462,8 +7516,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74627516
// to make a remote call (offload) to the previously outlined function
74637517
// that represents the target region. Do that now.
74647518
if (!Config.isTargetDevice())
7465-
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7466-
OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7519+
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7520+
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7521+
HasNowait);
74677522
return Builder.saveIP();
74687523
}
74697524

0 commit comments

Comments
 (0)