Skip to content

Commit 3ec01da

Browse files
committed
[Clang][OMPX] Add the code generation for multi-dim thread_limit clause
1 parent fd59f45 commit 3ec01da

File tree

4 files changed

+54
-44
lines changed

4 files changed

+54
-44
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9588,15 +9588,17 @@ static void genMapInfo(const OMPExecutableDirective &D, CodeGenFunction &CGF,
95889588
genMapInfo(MEHandler, CGF, CombinedInfo, OMPBuilder, MappedVarSet);
95899589
}
95909590

9591-
static void emitNumTeamsForBareTargetDirective(
9591+
template <typename ClauseTy>
9592+
static void emitClauseForBareTargetDirective(
95929593
CodeGenFunction &CGF, const OMPExecutableDirective &D,
9593-
llvm::SmallVectorImpl<llvm::Value *> &NumTeams) {
9594-
const auto *C = D.getSingleClause<OMPNumTeamsClause>();
9595-
assert(!C->varlist_empty() && "ompx_bare requires explicit num_teams");
9596-
CodeGenFunction::RunCleanupsScope NumTeamsScope(CGF);
9597-
for (auto *E : C->getNumTeams()) {
9594+
llvm::SmallVectorImpl<llvm::Value *> &Values) {
9595+
const auto *C = D.getSingleClause<ClauseTy>();
9596+
assert(!C->varlist_empty() &&
9597+
"ompx_bare requires explicit num_teams and thread_limit");
9598+
CodeGenFunction::RunCleanupsScope Scope(CGF);
9599+
for (auto *E : C->varlist()) {
95989600
llvm::Value *V = CGF.EmitScalarExpr(E);
9599-
NumTeams.push_back(
9601+
Values.push_back(
96009602
CGF.Builder.CreateIntCast(V, CGF.Int32Ty, /*isSigned=*/true));
96019603
}
96029604
}
@@ -9672,14 +9674,17 @@ static void emitTargetCallKernelLaunch(
96729674

96739675
bool IsBare = D.hasClausesOfKind<OMPXBareClause>();
96749676
SmallVector<llvm::Value *, 3> NumTeams;
9675-
if (IsBare)
9676-
emitNumTeamsForBareTargetDirective(CGF, D, NumTeams);
9677-
else
9677+
SmallVector<llvm::Value *, 3> NumThreads;
9678+
if (IsBare) {
9679+
emitClauseForBareTargetDirective<OMPNumTeamsClause>(CGF, D, NumTeams);
9680+
emitClauseForBareTargetDirective<OMPThreadLimitClause>(CGF, D,
9681+
NumThreads);
9682+
} else {
96789683
NumTeams.push_back(OMPRuntime->emitNumTeamsForTargetDirective(CGF, D));
9684+
NumThreads.push_back(OMPRuntime->emitNumThreadsForTargetDirective(CGF, D));
9685+
}
96799686

96809687
llvm::Value *DeviceID = emitDeviceID(Device, CGF);
9681-
llvm::Value *NumThreads =
9682-
OMPRuntime->emitNumThreadsForTargetDirective(CGF, D);
96839688
llvm::Value *RTLoc = OMPRuntime->emitUpdateLocation(CGF, D.getBeginLoc());
96849689
llvm::Value *NumIterations =
96859690
OMPRuntime->emitTargetNumIterationsCall(CGF, D, SizeEmitter);

clang/test/OpenMP/target_teams_codegen.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,13 @@ int foo(int n) {
127127
aa += 1;
128128
}
129129

130-
#pragma omp target teams ompx_bare num_teams(1, 2) thread_limit(1)
130+
#pragma omp target teams ompx_bare num_teams(1, 2) thread_limit(1, 2)
131131
{
132132
a += 1;
133133
aa += 1;
134134
}
135135

136-
#pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(1)
136+
#pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(1, 2, 3)
137137
{
138138
a += 1;
139139
aa += 1;
@@ -667,7 +667,7 @@ int bar(int n){
667667
// CHECK1-NEXT: [[TMP144:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 10
668668
// CHECK1-NEXT: store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP144]], align 4
669669
// CHECK1-NEXT: [[TMP145:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 11
670-
// CHECK1-NEXT: store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP145]], align 4
670+
// CHECK1-NEXT: store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP145]], align 4
671671
// CHECK1-NEXT: [[TMP146:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 12
672672
// CHECK1-NEXT: store i32 0, ptr [[TMP146]], align 4
673673
// CHECK1-NEXT: [[TMP147:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l130.region_id, ptr [[KERNEL_ARGS29]])
@@ -720,7 +720,7 @@ int bar(int n){
720720
// CHECK1-NEXT: [[TMP171:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 10
721721
// CHECK1-NEXT: store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP171]], align 4
722722
// CHECK1-NEXT: [[TMP172:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 11
723-
// CHECK1-NEXT: store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP172]], align 4
723+
// CHECK1-NEXT: store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP172]], align 4
724724
// CHECK1-NEXT: [[TMP173:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 12
725725
// CHECK1-NEXT: store i32 0, ptr [[TMP173]], align 4
726726
// CHECK1-NEXT: [[TMP174:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l136.region_id, ptr [[KERNEL_ARGS37]])
@@ -2458,7 +2458,7 @@ int bar(int n){
24582458
// CHECK3-NEXT: [[TMP142:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 10
24592459
// CHECK3-NEXT: store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP142]], align 4
24602460
// CHECK3-NEXT: [[TMP143:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 11
2461-
// CHECK3-NEXT: store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP143]], align 4
2461+
// CHECK3-NEXT: store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP143]], align 4
24622462
// CHECK3-NEXT: [[TMP144:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 12
24632463
// CHECK3-NEXT: store i32 0, ptr [[TMP144]], align 4
24642464
// CHECK3-NEXT: [[TMP145:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l130.region_id, ptr [[KERNEL_ARGS29]])
@@ -2511,7 +2511,7 @@ int bar(int n){
25112511
// CHECK3-NEXT: [[TMP169:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 10
25122512
// CHECK3-NEXT: store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP169]], align 4
25132513
// CHECK3-NEXT: [[TMP170:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 11
2514-
// CHECK3-NEXT: store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP170]], align 4
2514+
// CHECK3-NEXT: store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP170]], align 4
25152515
// CHECK3-NEXT: [[TMP171:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 12
25162516
// CHECK3-NEXT: store i32 0, ptr [[TMP171]], align 4
25172517
// CHECK3-NEXT: [[TMP172:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l136.region_id, ptr [[KERNEL_ARGS37]])

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2195,7 +2195,7 @@ class OpenMPIRBuilder {
21952195
/// The number of teams.
21962196
ArrayRef<Value *> NumTeams;
21972197
/// The number of threads.
2198-
Value *NumThreads = nullptr;
2198+
ArrayRef<Value *> NumThreads;
21992199
/// The size of the dynamic shared memory.
22002200
Value *DynCGGroupMem = nullptr;
22012201
/// True if the kernel has 'no wait' clause.
@@ -2205,7 +2205,8 @@ class OpenMPIRBuilder {
22052205
TargetKernelArgs() {}
22062206
TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
22072207
Value *NumIterations, ArrayRef<Value *> NumTeams,
2208-
Value *NumThreads, Value *DynCGGroupMem, bool HasNoWait)
2208+
ArrayRef<Value *> NumThreads, Value *DynCGGroupMem,
2209+
bool HasNoWait)
22092210
: NumTargetItems(NumTargetItems), RTArgs(RTArgs),
22102211
NumIterations(NumIterations), NumTeams(NumTeams),
22112212
NumThreads(NumThreads), DynCGGroupMem(DynCGGroupMem),
@@ -2852,17 +2853,16 @@ class OpenMPIRBuilder {
28522853
/// instructions for passed in target arguments where necessary
28532854
/// \param Dependencies A vector of DependData objects that carry
28542855
// dependency information as passed in the depend clause
2855-
InsertPointTy createTarget(const LocationDescription &Loc,
2856-
bool IsOffloadEntry,
2857-
OpenMPIRBuilder::InsertPointTy AllocaIP,
2858-
OpenMPIRBuilder::InsertPointTy CodeGenIP,
2859-
TargetRegionEntryInfo &EntryInfo,
2860-
ArrayRef<int32_t> NumTeams, int32_t NumThreads,
2861-
SmallVectorImpl<Value *> &Inputs,
2862-
GenMapInfoCallbackTy GenMapInfoCB,
2863-
TargetBodyGenCallbackTy BodyGenCB,
2864-
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
2865-
SmallVector<DependData> Dependencies = {});
2856+
InsertPointTy
2857+
createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
2858+
OpenMPIRBuilder::InsertPointTy AllocaIP,
2859+
OpenMPIRBuilder::InsertPointTy CodeGenIP,
2860+
TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
2861+
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
2862+
GenMapInfoCallbackTy GenMapInfoCB,
2863+
TargetBodyGenCallbackTy BodyGenCB,
2864+
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
2865+
SmallVector<DependData> Dependencies = {});
28662866

28672867
/// Returns __kmpc_for_static_init_* runtime function for the specified
28682868
/// size \a IVSize and sign \a IVSigned. Will create a distribute call

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -505,11 +505,14 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
505505

506506
Value *NumTeams3D =
507507
Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams[0], {0});
508+
Value *NumThreads3D =
509+
Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads[0], {0});
508510
for (unsigned I = 1; I < std::min(KernelArgs.NumTeams.size(), MaxDim); ++I)
509511
NumTeams3D =
510512
Builder.CreateInsertValue(NumTeams3D, KernelArgs.NumTeams[I], {I});
511-
Value *NumThreads3D =
512-
Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});
513+
for (unsigned I = 1; I < std::min(KernelArgs.NumThreads.size(), MaxDim); ++I)
514+
NumThreads3D =
515+
Builder.CreateInsertValue(NumThreads3D, KernelArgs.NumThreads[I], {I});
513516

514517
ArgsVector = {Version,
515518
PointerNum,
@@ -1114,9 +1117,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
11141117
// __tgt_target_teams() launches a GPU kernel with the requested number
11151118
// of teams and threads so no additional calls to the runtime are required.
11161119
// Check the error code and execute the host version if required.
1117-
Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
1118-
Args.NumTeams.front(), Args.NumThreads,
1119-
OutlinedFnID, ArgsVector));
1120+
Builder.restoreIP(emitTargetKernel(
1121+
Builder, AllocaIP, Return, RTLoc, DeviceID, Args.NumTeams.front(),
1122+
Args.NumThreads.front(), OutlinedFnID, ArgsVector));
11201123

11211124
BasicBlock *OffloadFailedBlock =
11221125
BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
@@ -7075,8 +7078,8 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
70757078
static void emitTargetCall(
70767079
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
70777080
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7078-
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams, int32_t NumThreads,
7079-
SmallVectorImpl<Value *> &Args,
7081+
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
7082+
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
70807083
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
70817084
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
70827085
// Generate a function call to the host fallback implementation of the target
@@ -7123,13 +7126,15 @@ static void emitTargetCall(
71237126
/*ForEndCall=*/false);
71247127

71257128
SmallVector<Value *, 3> NumTeamsC;
7129+
SmallVector<Value *, 3> NumThreadsC;
71267130
for (auto V : NumTeams)
71277131
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7132+
for (auto V : NumThreads)
7133+
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
71287134

71297135
unsigned NumTargetItems = Info.NumberOfPtrs;
71307136
// TODO: Use correct device ID
71317137
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7132-
Value *NumThreadsVal = Builder.getInt32(NumThreads);
71337138
uint32_t SrcLocStrSize;
71347139
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
71357140
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
@@ -7140,8 +7145,8 @@ static void emitTargetCall(
71407145
Value *DynCGGroupMem = Builder.getInt32(0);
71417146

71427147
OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
7143-
NumTeamsC, NumThreadsVal,
7144-
DynCGGroupMem, HasNoWait);
7148+
NumTeamsC, NumThreadsC, DynCGGroupMem,
7149+
HasNoWait);
71457150

71467151
// The presence of certain clauses on the target directive require the
71477152
// explicit generation of the target task.
@@ -7159,11 +7164,11 @@ static void emitTargetCall(
71597164
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
71607165
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
71617166
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7162-
ArrayRef<int32_t> NumTeams, int32_t NumThreads,
7167+
ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
71637168
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
71647169
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
71657170
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7166-
SmallVector<DependData> Dependenciess) {
7171+
SmallVector<DependData> Dependencies) {
71677172

71687173
if (!updateToLocation(Loc))
71697174
return InsertPointTy();
@@ -7184,7 +7189,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
71847189
// that represents the target region. Do that now.
71857190
if (!Config.isTargetDevice())
71867191
emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7187-
NumThreads, Args, GenMapInfoCB, Dependenciess);
7192+
NumThreads, Args, GenMapInfoCB, Dependencies);
71887193
return Builder.saveIP();
71897194
}
71907195

0 commit comments

Comments
 (0)