@@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
830
830
return GV;
831
831
}
832
832
833
+ void OpenMPIRBuilder::emitUsed (StringRef Name, ArrayRef<WeakTrackingVH> List) {
834
+ if (List.empty ())
835
+ return ;
836
+
837
+ // Convert List to what ConstantArray needs.
838
+ SmallVector<Constant *, 8 > UsedArray;
839
+ UsedArray.resize (List.size ());
840
+ for (unsigned I = 0 , E = List.size (); I != E; ++I)
841
+ UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast (
842
+ cast<Constant>(&*List[I]), Builder.getPtrTy ());
843
+
844
+ if (UsedArray.empty ())
845
+ return ;
846
+ ArrayType *ATy = ArrayType::get (Builder.getPtrTy (), UsedArray.size ());
847
+
848
+ auto *GV = new GlobalVariable (M, ATy, false , GlobalValue::AppendingLinkage,
849
+ ConstantArray::get (ATy, UsedArray), Name);
850
+
851
+ GV->setSection (" llvm.metadata" );
852
+ }
853
+
854
+ GlobalVariable *
855
+ OpenMPIRBuilder::emitKernelExecutionMode (StringRef KernelName,
856
+ OMPTgtExecModeFlags Mode) {
857
+ auto *Int8Ty = Builder.getInt8Ty ();
858
+ auto *GVMode = new GlobalVariable (
859
+ M, Int8Ty, /* isConstant=*/ true , GlobalValue::WeakAnyLinkage,
860
+ ConstantInt::get (Int8Ty, Mode), Twine (KernelName, " _exec_mode" ));
861
+ GVMode->setVisibility (GlobalVariable::ProtectedVisibility);
862
+ return GVMode;
863
+ }
864
+
833
865
Constant *OpenMPIRBuilder::getOrCreateIdent (Constant *SrcLocStr,
834
866
uint32_t SrcLocStrSize,
835
867
IdentFlag LocFlags,
@@ -2260,28 +2292,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2260
2292
return OpenMPIRBuilder::InsertPointTy (I->getParent (), IT);
2261
2293
}
2262
2294
2263
- void OpenMPIRBuilder::emitUsed (StringRef Name,
2264
- std::vector<WeakTrackingVH> &List) {
2265
- if (List.empty ())
2266
- return ;
2267
-
2268
- // Convert List to what ConstantArray needs.
2269
- SmallVector<Constant *, 8 > UsedArray;
2270
- UsedArray.resize (List.size ());
2271
- for (unsigned I = 0 , E = List.size (); I != E; ++I)
2272
- UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast (
2273
- cast<Constant>(&*List[I]), Builder.getPtrTy ());
2274
-
2275
- if (UsedArray.empty ())
2276
- return ;
2277
- ArrayType *ATy = ArrayType::get (Builder.getPtrTy (), UsedArray.size ());
2278
-
2279
- auto *GV = new GlobalVariable (M, ATy, false , GlobalValue::AppendingLinkage,
2280
- ConstantArray::get (ATy, UsedArray), Name);
2281
-
2282
- GV->setSection (" llvm.metadata" );
2283
- }
2284
-
2285
2295
Value *OpenMPIRBuilder::getGPUThreadID () {
2286
2296
return Builder.CreateCall (
2287
2297
getOrCreateRuntimeFunction (M,
@@ -6140,10 +6150,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
6140
6150
uint32_t SrcLocStrSize;
6141
6151
Constant *SrcLocStr = getOrCreateSrcLocStr (Loc, SrcLocStrSize);
6142
6152
Constant *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
6143
- Constant *IsSPMDVal = ConstantInt::getSigned (
6144
- Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6145
- Constant *UseGenericStateMachineVal =
6146
- ConstantInt::getSigned (Int8, !Attrs.IsSPMD );
6153
+ Constant *IsSPMDVal = ConstantInt::getSigned (Int8, Attrs.ExecFlags );
6154
+ Constant *UseGenericStateMachineVal = ConstantInt::getSigned (
6155
+ Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
6147
6156
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned (Int8, true );
6148
6157
Constant *DebugIndentionLevelVal = ConstantInt::getSigned (Int16, 0 );
6149
6158
@@ -6778,6 +6787,12 @@ static Expected<Function *> createOutlinedFunction(
6778
6787
auto Func =
6779
6788
Function::Create (FuncType, GlobalValue::InternalLinkage, FuncName, M);
6780
6789
6790
+ if (OMPBuilder.Config .isTargetDevice ()) {
6791
+ Value *ExecMode =
6792
+ OMPBuilder.emitKernelExecutionMode (FuncName, DefaultAttrs.ExecFlags );
6793
+ OMPBuilder.emitUsed (" llvm.compiler.used" , {ExecMode});
6794
+ }
6795
+
6781
6796
// Save insert point.
6782
6797
IRBuilder<>::InsertPointGuard IPG (Builder);
6783
6798
// If there's a DISubprogram associated with current function, then
@@ -7325,6 +7340,7 @@ static void
7325
7340
emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7326
7341
OpenMPIRBuilder::InsertPointTy AllocaIP,
7327
7342
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7343
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7328
7344
Function *OutlinedFn, Constant *OutlinedFnID,
7329
7345
SmallVectorImpl<Value *> &Args,
7330
7346
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7406,11 +7422,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7406
7422
/* ForEndCall=*/ false );
7407
7423
7408
7424
SmallVector<Value *, 3 > NumTeamsC;
7425
+ for (auto [DefaultVal, RuntimeVal] :
7426
+ zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7427
+ NumTeamsC.push_back (RuntimeVal ? RuntimeVal : Builder.getInt32 (DefaultVal));
7428
+
7429
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7430
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7431
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7432
+ if (Clause)
7433
+ Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7434
+ /* isSigned=*/ false );
7435
+ return Clause;
7436
+ };
7437
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7438
+ if (Clause)
7439
+ Result = Result
7440
+ ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7441
+ Result, Clause)
7442
+ : Clause;
7443
+ };
7444
+
7445
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7446
+ // the NUM_THREADS clause is overridden by THREAD_LIMIT.
7409
7447
SmallVector<Value *, 3 > NumThreadsC;
7410
- for (auto V : DefaultAttrs.MaxTeams )
7411
- NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7412
- for (auto V : DefaultAttrs.MaxThreads )
7413
- NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7448
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit .size () == 1
7449
+ ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7450
+ : nullptr ;
7451
+
7452
+ for (auto [TeamsVal, TargetVal] : zip_equal (RuntimeAttrs.TeamsThreadLimit ,
7453
+ RuntimeAttrs.TargetThreadLimit )) {
7454
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7455
+ Value *NumThreads = InitMaxThreadsClause (TargetVal);
7456
+
7457
+ CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7458
+ CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7459
+
7460
+ NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7461
+ }
7414
7462
7415
7463
unsigned NumTargetItems = Info.NumberOfPtrs ;
7416
7464
// TODO: Use correct device ID
@@ -7419,14 +7467,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7419
7467
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7420
7468
Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7421
7469
llvm::omp::IdentFlag (0 ), 0 );
7422
- // TODO: Use correct NumIterations
7423
- Value *NumIterations = Builder.getInt64 (0 );
7470
+
7471
+ Value *TripCount = RuntimeAttrs.LoopTripCount
7472
+ ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7473
+ Builder.getInt64Ty (),
7474
+ /* isSigned=*/ false )
7475
+ : Builder.getInt64 (0 );
7476
+
7424
7477
// TODO: Use correct DynCGGroupMem
7425
7478
Value *DynCGGroupMem = Builder.getInt32 (0 );
7426
7479
7427
- KArgs = OpenMPIRBuilder::TargetKernelArgs (
7428
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7429
- DynCGGroupMem, HasNoWait);
7480
+ KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7481
+ NumTeamsC, NumThreadsC,
7482
+ DynCGGroupMem, HasNoWait);
7430
7483
7431
7484
// The presence of certain clauses on the target directive require the
7432
7485
// explicit generation of the target task.
@@ -7451,6 +7504,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7451
7504
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7452
7505
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7453
7506
const TargetKernelDefaultAttrs &DefaultAttrs,
7507
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
7454
7508
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7455
7509
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7456
7510
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7475,8 +7529,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7475
7529
// to make a remote call (offload) to the previously outlined function
7476
7530
// that represents the target region. Do that now.
7477
7531
if (!Config.isTargetDevice ())
7478
- emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7479
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7532
+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7533
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7534
+ HasNowait);
7480
7535
return Builder.saveIP ();
7481
7536
}
7482
7537
0 commit comments