Skip to content

Commit b1e4eb5

Browse files
committed
Fine-grained control of kernel execution mode
1 parent e9ea3a5 commit b1e4eb5

File tree

5 files changed: 47 additions and 34 deletions.

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "clang/AST/StmtVisitor.h"
2121
#include "clang/Basic/Cuda.h"
2222
#include "llvm/ADT/SmallPtrSet.h"
23+
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
2324
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
2425

2526
using namespace clang;
@@ -748,7 +749,11 @@ void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
748749
computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
749750

750751
CGBuilderTy &Bld = CGF.Builder;
751-
Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, Attrs));
752+
Bld.restoreIP(OMPBuilder.createTargetInit(
753+
Bld,
754+
IsSPMD ? llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD
755+
: llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
756+
Attrs));
752757
if (!IsSPMD)
753758
emitGenericVarsProlog(CGF, EST.Loc);
754759
}

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2243,21 +2243,21 @@ class OpenMPIRBuilder {
22432243

22442244
/// Container to pass LLVM IR runtime values or constants related to the
22452245
/// number of teams and threads with which the kernel must be launched, as
2246-
/// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
2247-
/// must be defined in the host prior to the call to the kernel launch OpenMP
2248-
/// RTL function.
2246+
/// well as the trip count of the loop, if it is an SPMD or Generic-SPMD
2247+
/// kernel. These must be defined in the host prior to the call to the kernel
2248+
/// launch OpenMP RTL function.
22492249
struct TargetKernelRuntimeAttrs {
22502250
SmallVector<Value *, 3> MaxTeams = {nullptr};
22512251
Value *MinTeams = nullptr;
22522252
SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
22532253
SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
22542254

2255-
/// 'parallel' construct 'num_threads' clause value, if present and it is a
2256-
/// target SPMD kernel.
2255+
/// 'parallel' construct 'num_threads' clause value, if present and it is an
2256+
/// SPMD kernel.
22572257
Value *MaxThreads = nullptr;
22582258

2259-
/// Total number of iterations of the target SPMD kernel or null if it is a
2260-
/// generic kernel.
2259+
/// Total number of iterations of the SPMD or Generic-SPMD kernel or null if
2260+
/// it is a generic kernel.
22612261
Value *LoopTripCount = nullptr;
22622262
};
22632263

@@ -2763,11 +2763,12 @@ class OpenMPIRBuilder {
27632763
/// Create a runtime call for kmpc_target_init
27642764
///
27652765
/// \param Loc The insert and source location description.
2766+
/// \param ExecFlags Kernel execution mode flags.
27662767
///        e.g. OMP_TGT_EXEC_MODE_SPMD or OMP_TGT_EXEC_MODE_GENERIC.
27672768
/// \param Attrs Structure containing the default numbers of threads and teams
27682769
/// to launch the kernel with.
27692770
InsertPointTy createTargetInit(
2770-
const LocationDescription &Loc, bool IsSPMD,
2771+
const LocationDescription &Loc, omp::OMPTgtExecModeFlags ExecFlags,
27712772
const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
27722773

27732774
/// Create a runtime call for kmpc_target_deinit
@@ -2929,7 +2930,7 @@ class OpenMPIRBuilder {
29292930
///
29302931
/// \param Loc where the target data construct was encountered.
29312932
/// \param IsOffloadEntry whether it is an offload entry.
2932-
/// \param IsSPMD whether it is a target SPMD kernel.
2933+
/// \param ExecFlags kernel execution mode flags.
29332934
/// \param CodeGenIP The insertion point where the call to the outlined
29342935
/// function should be emitted.
29352936
/// \param EntryInfo The entry information about the function.
@@ -2946,7 +2947,8 @@ class OpenMPIRBuilder {
29462947
/// dependency information as passed in the depend clause
29472948
/// \param HasNowait Whether the target construct has a `nowait` clause or not.
29482949
InsertPointOrErrorTy createTarget(
2949-
const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
2950+
const LocationDescription &Loc, bool IsOffloadEntry,
2951+
omp::OMPTgtExecModeFlags ExecFlags,
29502952
OpenMPIRBuilder::InsertPointTy AllocaIP,
29512953
OpenMPIRBuilder::InsertPointTy CodeGenIP,
29522954
TargetRegionEntryInfo &EntryInfo,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6124,7 +6124,7 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
61246124
}
61256125

61266126
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
6127-
const LocationDescription &Loc, bool IsSPMD,
6127+
const LocationDescription &Loc, omp::OMPTgtExecModeFlags ExecFlags,
61286128
const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
61296129
assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
61306130
"expected num_threads and num_teams to be specified");
@@ -6135,9 +6135,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
61356135
uint32_t SrcLocStrSize;
61366136
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
61376137
Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6138-
Constant *IsSPMDVal = ConstantInt::getSigned(
6139-
Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6140-
Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Int8, !IsSPMD);
6138+
Constant *IsSPMDVal = ConstantInt::getSigned(Int8, ExecFlags);
6139+
Constant *UseGenericStateMachineVal =
6140+
ConstantInt::getSigned(Int8, ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
61416141
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
61426142
Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
61436143

@@ -6742,7 +6742,8 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
67426742
}
67436743

67446744
static Expected<Function *> createOutlinedFunction(
6745-
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
6745+
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6746+
omp::OMPTgtExecModeFlags ExecFlags,
67466747
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
67476748
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
67486749
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6773,8 +6774,7 @@ static Expected<Function *> createOutlinedFunction(
67736774
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
67746775

67756776
if (OMPBuilder.Config.isTargetDevice()) {
6776-
Value *ExecMode = OMPBuilder.emitKernelExecutionMode(
6777-
FuncName, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6777+
Value *ExecMode = OMPBuilder.emitKernelExecutionMode(FuncName, ExecFlags);
67786778
OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
67796779
}
67806780

@@ -6818,7 +6818,7 @@ static Expected<Function *> createOutlinedFunction(
68186818
// Insert target init call in the device compilation pass.
68196819
if (OMPBuilder.Config.isTargetDevice())
68206820
Builder.restoreIP(
6821-
OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
6821+
OMPBuilder.createTargetInit(Builder, ExecFlags, DefaultAttrs));
68226822

68236823
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
68246824

@@ -7014,7 +7014,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
70147014

70157015
static Error emitTargetOutlinedFunction(
70167016
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
7017-
bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
7017+
omp::OMPTgtExecModeFlags ExecFlags, TargetRegionEntryInfo &EntryInfo,
70187018
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
70197019
Function *&OutlinedFn, Constant *&OutlinedFnID,
70207020
SmallVectorImpl<Value *> &Inputs,
@@ -7023,8 +7023,8 @@ static Error emitTargetOutlinedFunction(
70237023

70247024
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
70257025
[&](StringRef EntryFnName) {
7026-
return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
7027-
EntryFnName, Inputs, CBFunc,
7026+
return createOutlinedFunction(OMPBuilder, Builder, ExecFlags,
7027+
DefaultAttrs, EntryFnName, Inputs, CBFunc,
70287028
ArgAccessorFuncCB);
70297029
};
70307030

@@ -7484,9 +7484,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74847484
}
74857485

74867486
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7487-
const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
7488-
InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
7489-
TargetRegionEntryInfo &EntryInfo,
7487+
const LocationDescription &Loc, bool IsOffloadEntry,
7488+
omp::OMPTgtExecModeFlags ExecFlags, InsertPointTy AllocaIP,
7489+
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
74907490
const TargetKernelDefaultAttrs &DefaultAttrs,
74917491
const TargetKernelRuntimeAttrs &RuntimeAttrs,
74927492
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
@@ -7505,7 +7505,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
75057505
// the target region itself is generated using the callbacks CBFunc
75067506
// and ArgAccessorFuncCB
75077507
if (Error Err = emitTargetOutlinedFunction(
7508-
*this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
7508+
*this, Builder, IsOffloadEntry, ExecFlags, EntryInfo, DefaultAttrs,
75097509
OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
75107510
return Err;
75117511

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6189,7 +6189,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
61896189
RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
61906190
RuntimeAttrs.MaxThreads = Builder.getInt32(40);
61916191
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
6192-
OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
6192+
OmpLoc, /*IsOffloadEntry=*/true,
6193+
omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, Builder.saveIP(),
61936194
Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
61946195
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
61956196
assert(AfterIP && "unexpected error");
@@ -6340,7 +6341,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
63406341
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
63416342
OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
63426343
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
6343-
Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
6344+
Loc, /*IsOffloadEntry=*/true,
6345+
omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, EntryIP, EntryIP,
63446346
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
63456347
BodyGenCB, SimpleArgAccessorCB);
63466348
assert(AfterIP && "unexpected error");
@@ -6480,7 +6482,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
64806482
OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
64816483
RuntimeAttrs.LoopTripCount = Builder.getInt64(1000);
64826484
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
6483-
OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, Builder.saveIP(),
6485+
OmpLoc, /*IsOffloadEntry=*/true,
6486+
omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, Builder.saveIP(),
64846487
Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
64856488
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
64866489
assert(AfterIP && "unexpected error");
@@ -6580,7 +6583,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
65806583
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
65816584
OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
65826585
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
6583-
Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, EntryIP, EntryIP,
6586+
Loc, /*IsOffloadEntry=*/true,
6587+
omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, EntryIP, EntryIP,
65846588
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
65856589
BodyGenCB, SimpleArgAccessorCB);
65866590
assert(AfterIP && "unexpected error");
@@ -6686,7 +6690,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
66866690
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
66876691
OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
66886692
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
6689-
Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
6693+
Loc, /*IsOffloadEntry=*/true,
6694+
omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, EntryIP, EntryIP,
66906695
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
66916696
BodyGenCB, SimpleArgAccessorCB);
66926697
assert(AfterIP && "unexpected error");

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3975,9 +3975,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39753975

39763976
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
39773977
moduleTranslation.getOpenMPBuilder()->createTarget(
3978-
ompLoc, isOffloadEntry, /*IsSPMD=*/false, allocaIP, builder.saveIP(),
3979-
entryInfo, defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB,
3980-
bodyCB, argAccessorCB, dds, targetOp.getNowait());
3978+
ompLoc, isOffloadEntry, llvm::omp::OMP_TGT_EXEC_MODE_GENERIC,
3979+
allocaIP, builder.saveIP(), entryInfo, defaultAttrs, runtimeAttrs,
3980+
kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
3981+
targetOp.getNowait());
39813982

39823983
if (failed(handleError(afterIP, opInst)))
39833984
return failure();

0 commit comments

Comments
 (0)