@@ -6113,10 +6113,12 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
6113
6113
return Builder.CreateCall (Fn, Args);
6114
6114
}
6115
6115
6116
- OpenMPIRBuilder::InsertPointTy
6117
- OpenMPIRBuilder::createTargetInit (const LocationDescription &Loc, bool IsSPMD,
6118
- int32_t MinThreadsVal, int32_t MaxThreadsVal,
6119
- int32_t MinTeamsVal, int32_t MaxTeamsVal) {
6116
+ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit (
6117
+ const LocationDescription &Loc, bool IsSPMD,
6118
+ const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
6119
+ assert (!Attrs.MaxThreads .empty () && !Attrs.MaxTeams .empty () &&
6120
+ " expected num_threads and num_teams to be specified" );
6121
+
6120
6122
if (!updateToLocation (Loc))
6121
6123
return Loc.IP ;
6122
6124
@@ -6143,21 +6145,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
6143
6145
6144
6146
// Manifest the launch configuration in the metadata matching the kernel
6145
6147
// environment.
6146
- if (MinTeamsVal > 1 || MaxTeamsVal > 0 )
6147
- writeTeamsForKernel (T, *Kernel, MinTeamsVal, MaxTeamsVal );
6148
+ if (Attrs. MinTeams > 1 || Attrs. MaxTeams . front () > 0 )
6149
+ writeTeamsForKernel (T, *Kernel, Attrs. MinTeams , Attrs. MaxTeams . front () );
6148
6150
6149
- // For max values, < 0 means unset, == 0 means set but unknown.
6151
+ // If MaxThreads not set, select the maximum between the default workgroup
6152
+ // size and the MinThreads value.
6153
+ int32_t MaxThreadsVal = Attrs.MaxThreads .front ();
6150
6154
if (MaxThreadsVal < 0 )
6151
6155
MaxThreadsVal = std::max (
6152
- int32_t (getGridValue (T, Kernel).GV_Default_WG_Size ), MinThreadsVal );
6156
+ int32_t (getGridValue (T, Kernel).GV_Default_WG_Size ), Attrs. MinThreads );
6153
6157
6154
6158
if (MaxThreadsVal > 0 )
6155
- writeThreadBoundsForKernel (T, *Kernel, MinThreadsVal , MaxThreadsVal);
6159
+ writeThreadBoundsForKernel (T, *Kernel, Attrs. MinThreads , MaxThreadsVal);
6156
6160
6157
- Constant *MinThreads = ConstantInt::getSigned (Int32, MinThreadsVal );
6161
+ Constant *MinThreads = ConstantInt::getSigned (Int32, Attrs. MinThreads );
6158
6162
Constant *MaxThreads = ConstantInt::getSigned (Int32, MaxThreadsVal);
6159
- Constant *MinTeams = ConstantInt::getSigned (Int32, MinTeamsVal );
6160
- Constant *MaxTeams = ConstantInt::getSigned (Int32, MaxTeamsVal );
6163
+ Constant *MinTeams = ConstantInt::getSigned (Int32, Attrs. MinTeams );
6164
+ Constant *MaxTeams = ConstantInt::getSigned (Int32, Attrs. MaxTeams . front () );
6161
6165
Constant *ReductionDataSize = ConstantInt::getSigned (Int32, 0 );
6162
6166
Constant *ReductionBufferLength = ConstantInt::getSigned (Int32, 0 );
6163
6167
@@ -6728,8 +6732,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6728
6732
}
6729
6733
6730
6734
static Expected<Function *> createOutlinedFunction (
6731
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
6732
- SmallVectorImpl<Value *> &Inputs,
6735
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6736
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6737
+ StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
6733
6738
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6734
6739
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6735
6740
SmallVector<Type *> ParameterTypes;
@@ -6796,7 +6801,8 @@ static Expected<Function *> createOutlinedFunction(
6796
6801
6797
6802
// Insert target init call in the device compilation pass.
6798
6803
if (OMPBuilder.Config .isTargetDevice ())
6799
- Builder.restoreIP (OMPBuilder.createTargetInit (Builder, /* IsSPMD*/ false ));
6804
+ Builder.restoreIP (
6805
+ OMPBuilder.createTargetInit (Builder, /* IsSPMD=*/ false , DefaultAttrs));
6800
6806
6801
6807
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock ();
6802
6808
@@ -6992,16 +6998,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
6992
6998
6993
6999
static Error emitTargetOutlinedFunction (
6994
7000
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
6995
- TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
6996
- Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
7001
+ TargetRegionEntryInfo &EntryInfo,
7002
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7003
+ Function *&OutlinedFn, Constant *&OutlinedFnID,
7004
+ SmallVectorImpl<Value *> &Inputs,
6997
7005
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6998
7006
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6999
7007
7000
7008
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7001
- [&OMPBuilder, &Builder, &Inputs, &CBFunc,
7002
- &ArgAccessorFuncCB](StringRef EntryFnName) {
7003
- return createOutlinedFunction (OMPBuilder, Builder, EntryFnName, Inputs,
7004
- CBFunc, ArgAccessorFuncCB);
7009
+ [&](StringRef EntryFnName) {
7010
+ return createOutlinedFunction (OMPBuilder, Builder, DefaultAttrs,
7011
+ EntryFnName, Inputs, CBFunc ,
7012
+ ArgAccessorFuncCB);
7005
7013
};
7006
7014
7007
7015
return OMPBuilder.emitTargetRegionFunction (
@@ -7297,9 +7305,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7297
7305
7298
7306
static void
7299
7307
emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7300
- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7301
- Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7302
- ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7308
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
7309
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7310
+ Function *OutlinedFn, Constant *OutlinedFnID,
7311
+ SmallVectorImpl<Value *> &Args,
7303
7312
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7304
7313
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
7305
7314
bool HasNoWait = false ) {
@@ -7380,9 +7389,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7380
7389
7381
7390
SmallVector<Value *, 3 > NumTeamsC;
7382
7391
SmallVector<Value *, 3 > NumThreadsC;
7383
- for (auto V : NumTeams )
7392
+ for (auto V : DefaultAttrs. MaxTeams )
7384
7393
NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7385
- for (auto V : NumThreads )
7394
+ for (auto V : DefaultAttrs. MaxThreads )
7386
7395
NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7387
7396
7388
7397
unsigned NumTargetItems = Info.NumberOfPtrs ;
@@ -7423,7 +7432,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7423
7432
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
7424
7433
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7425
7434
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7426
- ArrayRef< int32_t > NumTeams, ArrayRef< int32_t > NumThreads ,
7435
+ const TargetKernelDefaultAttrs &DefaultAttrs ,
7427
7436
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7428
7437
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7429
7438
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7440,16 +7449,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7440
7449
// the target region itself is generated using the callbacks CBFunc
7441
7450
// and ArgAccessorFuncCB
7442
7451
if (Error Err = emitTargetOutlinedFunction (
7443
- *this , Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID ,
7444
- Args, CBFunc, ArgAccessorFuncCB))
7452
+ *this , Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn ,
7453
+ OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7445
7454
return Err;
7446
7455
7447
7456
// If we are not on the target device, then we need to generate code
7448
7457
// to make a remote call (offload) to the previously outlined function
7449
7458
// that represents the target region. Do that now.
7450
7459
if (!Config.isTargetDevice ())
7451
- emitTargetCall (*this , Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams ,
7452
- NumThreads , Args, GenMapInfoCB, Dependencies, HasNowait);
7460
+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, OutlinedFn ,
7461
+ OutlinedFnID , Args, GenMapInfoCB, Dependencies, HasNowait);
7453
7462
return Builder.saveIP ();
7454
7463
}
7455
7464
0 commit comments