@@ -501,15 +501,20 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
501
501
Value *ZeroArray = Constant::getNullValue (ArrayType::get (Int32Ty, MaxDim));
502
502
Value *Flags = Builder.getInt64 (KernelArgs.HasNoWait );
503
503
504
- assert (!KernelArgs.NumTeams .empty ());
504
+ assert (!KernelArgs.NumTeams .empty () && !KernelArgs. NumThreads . empty () );
505
505
506
506
Value *NumTeams3D =
507
507
Builder.CreateInsertValue (ZeroArray, KernelArgs.NumTeams [0 ], {0 });
508
- for (unsigned I = 1 ; I < std::min (KernelArgs.NumTeams .size (), MaxDim); ++I)
508
+ Value *NumThreads3D =
509
+ Builder.CreateInsertValue (ZeroArray, KernelArgs.NumThreads [0 ], {0 });
510
+ for (unsigned I :
511
+ seq<unsigned >(1 , std::min (KernelArgs.NumTeams .size (), MaxDim)))
509
512
NumTeams3D =
510
513
Builder.CreateInsertValue (NumTeams3D, KernelArgs.NumTeams [I], {I});
511
- Value *NumThreads3D =
512
- Builder.CreateInsertValue (ZeroArray, KernelArgs.NumThreads , {0 });
514
+ for (unsigned I :
515
+ seq<unsigned >(1 , std::min (KernelArgs.NumThreads .size (), MaxDim)))
516
+ NumThreads3D =
517
+ Builder.CreateInsertValue (NumThreads3D, KernelArgs.NumThreads [I], {I});
513
518
514
519
ArgsVector = {Version,
515
520
PointerNum,
@@ -1114,9 +1119,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
1114
1119
// __tgt_target_teams() launches a GPU kernel with the requested number
1115
1120
// of teams and threads so no additional calls to the runtime are required.
1116
1121
// Check the error code and execute the host version if required.
1117
- Builder.restoreIP (emitTargetKernel (Builder, AllocaIP, Return, RTLoc, DeviceID,
1118
- Args.NumTeams .front (), Args. NumThreads ,
1119
- OutlinedFnID, ArgsVector));
1122
+ Builder.restoreIP (emitTargetKernel (
1123
+ Builder, AllocaIP, Return, RTLoc, DeviceID, Args.NumTeams .front (),
1124
+ Args. NumThreads . front (), OutlinedFnID, ArgsVector));
1120
1125
1121
1126
BasicBlock *OffloadFailedBlock =
1122
1127
BasicBlock::Create (Builder.getContext (), " omp_offload.failed" );
@@ -7075,8 +7080,8 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7075
7080
static void emitTargetCall (
7076
7081
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7077
7082
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7078
- Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams, int32_t NumThreads,
7079
- SmallVectorImpl<Value *> &Args,
7083
+ Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7084
+ ArrayRef< int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7080
7085
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7081
7086
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
7082
7087
// Generate a function call to the host fallback implementation of the target
@@ -7123,13 +7128,15 @@ static void emitTargetCall(
7123
7128
/* ForEndCall=*/ false );
7124
7129
7125
7130
SmallVector<Value *, 3 > NumTeamsC;
7131
+ SmallVector<Value *, 3 > NumThreadsC;
7126
7132
for (auto V : NumTeams)
7127
7133
NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7134
+ for (auto V : NumThreads)
7135
+ NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7128
7136
7129
7137
unsigned NumTargetItems = Info.NumberOfPtrs ;
7130
7138
// TODO: Use correct device ID
7131
7139
Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7132
- Value *NumThreadsVal = Builder.getInt32 (NumThreads);
7133
7140
uint32_t SrcLocStrSize;
7134
7141
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7135
7142
Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
@@ -7140,8 +7147,8 @@ static void emitTargetCall(
7140
7147
Value *DynCGGroupMem = Builder.getInt32 (0 );
7141
7148
7142
7149
OpenMPIRBuilder::TargetKernelArgs KArgs (NumTargetItems, RTArgs, NumIterations,
7143
- NumTeamsC, NumThreadsVal ,
7144
- DynCGGroupMem, HasNoWait);
7150
+ NumTeamsC, NumThreadsC, DynCGGroupMem ,
7151
+ HasNoWait);
7145
7152
7146
7153
// The presence of certain clauses on the target directive require the
7147
7154
// explicit generation of the target task.
@@ -7159,11 +7166,11 @@ static void emitTargetCall(
7159
7166
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget (
7160
7167
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7161
7168
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7162
- ArrayRef<int32_t > NumTeams, int32_t NumThreads,
7169
+ ArrayRef<int32_t > NumTeams, ArrayRef< int32_t > NumThreads,
7163
7170
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7164
7171
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7165
7172
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7166
- SmallVector<DependData> Dependenciess ) {
7173
+ SmallVector<DependData> Dependencies ) {
7167
7174
7168
7175
if (!updateToLocation (Loc))
7169
7176
return InsertPointTy ();
@@ -7184,7 +7191,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
7184
7191
// that represents the target region. Do that now.
7185
7192
if (!Config.isTargetDevice ())
7186
7193
emitTargetCall (*this , Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7187
- NumThreads, Args, GenMapInfoCB, Dependenciess );
7194
+ NumThreads, Args, GenMapInfoCB, Dependencies );
7188
7195
return Builder.saveIP ();
7189
7196
}
7190
7197
0 commit comments