@@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
830
830
return GV;
831
831
}
832
832
833
+ void OpenMPIRBuilder::emitUsed (StringRef Name, ArrayRef<WeakTrackingVH> List) {
834
+ if (List.empty ())
835
+ return ;
836
+
837
+ // Convert List to what ConstantArray needs.
838
+ SmallVector<Constant *, 8 > UsedArray;
839
+ UsedArray.resize (List.size ());
840
+ for (unsigned I = 0 , E = List.size (); I != E; ++I)
841
+ UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast (
842
+ cast<Constant>(&*List[I]), Builder.getPtrTy ());
843
+
844
+ if (UsedArray.empty ())
845
+ return ;
846
+ ArrayType *ATy = ArrayType::get (Builder.getPtrTy (), UsedArray.size ());
847
+
848
+ auto *GV = new GlobalVariable (M, ATy, false , GlobalValue::AppendingLinkage,
849
+ ConstantArray::get (ATy, UsedArray), Name);
850
+
851
+ GV->setSection (" llvm.metadata" );
852
+ }
853
+
854
+ GlobalVariable *
855
+ OpenMPIRBuilder::emitKernelExecutionMode (StringRef KernelName,
856
+ OMPTgtExecModeFlags Mode) {
857
+ auto *Int8Ty = Builder.getInt8Ty ();
858
+ auto *GVMode = new GlobalVariable (
859
+ M, Int8Ty, /* isConstant=*/ true , GlobalValue::WeakAnyLinkage,
860
+ ConstantInt::get (Int8Ty, Mode), Twine (KernelName, " _exec_mode" ));
861
+ GVMode->setVisibility (GlobalVariable::ProtectedVisibility);
862
+ return GVMode;
863
+ }
864
+
833
865
Constant *OpenMPIRBuilder::getOrCreateIdent (Constant *SrcLocStr,
834
866
uint32_t SrcLocStrSize,
835
867
IdentFlag LocFlags,
@@ -2260,28 +2292,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2260
2292
return OpenMPIRBuilder::InsertPointTy (I->getParent (), IT);
2261
2293
}
2262
2294
2263
- void OpenMPIRBuilder::emitUsed (StringRef Name,
2264
- std::vector<WeakTrackingVH> &List) {
2265
- if (List.empty ())
2266
- return ;
2267
-
2268
- // Convert List to what ConstantArray needs.
2269
- SmallVector<Constant *, 8 > UsedArray;
2270
- UsedArray.resize (List.size ());
2271
- for (unsigned I = 0 , E = List.size (); I != E; ++I)
2272
- UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast (
2273
- cast<Constant>(&*List[I]), Builder.getPtrTy ());
2274
-
2275
- if (UsedArray.empty ())
2276
- return ;
2277
- ArrayType *ATy = ArrayType::get (Builder.getPtrTy (), UsedArray.size ());
2278
-
2279
- auto *GV = new GlobalVariable (M, ATy, false , GlobalValue::AppendingLinkage,
2280
- ConstantArray::get (ATy, UsedArray), Name);
2281
-
2282
- GV->setSection (" llvm.metadata" );
2283
- }
2284
-
2285
2295
Value *OpenMPIRBuilder::getGPUThreadID () {
2286
2296
return Builder.CreateCall (
2287
2297
getOrCreateRuntimeFunction (M,
@@ -6131,10 +6141,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
6131
6141
uint32_t SrcLocStrSize;
6132
6142
Constant *SrcLocStr = getOrCreateSrcLocStr (Loc, SrcLocStrSize);
6133
6143
Constant *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
6134
- Constant *IsSPMDVal = ConstantInt::getSigned (
6135
- Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6136
- Constant *UseGenericStateMachineVal =
6137
- ConstantInt::getSigned (Int8, !Attrs.IsSPMD );
6144
+ Constant *IsSPMDVal = ConstantInt::getSigned (Int8, Attrs.ExecFlags );
6145
+ Constant *UseGenericStateMachineVal = ConstantInt::getSigned (
6146
+ Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
6138
6147
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned (Int8, true );
6139
6148
Constant *DebugIndentionLevelVal = ConstantInt::getSigned (Int16, 0 );
6140
6149
@@ -6765,6 +6774,12 @@ static Expected<Function *> createOutlinedFunction(
6765
6774
auto Func =
6766
6775
Function::Create (FuncType, GlobalValue::InternalLinkage, FuncName, M);
6767
6776
6777
+ if (OMPBuilder.Config .isTargetDevice ()) {
6778
+ Value *ExecMode =
6779
+ OMPBuilder.emitKernelExecutionMode (FuncName, DefaultAttrs.ExecFlags );
6780
+ OMPBuilder.emitUsed (" llvm.compiler.used" , {ExecMode});
6781
+ }
6782
+
6768
6783
// Save insert point.
6769
6784
IRBuilder<>::InsertPointGuard IPG (Builder);
6770
6785
// If there's a DISubprogram associated with current function, then
@@ -7312,6 +7327,7 @@ static void
7312
7327
emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7313
7328
OpenMPIRBuilder::InsertPointTy AllocaIP,
7314
7329
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7330
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7315
7331
Function *OutlinedFn, Constant *OutlinedFnID,
7316
7332
SmallVectorImpl<Value *> &Args,
7317
7333
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7393,11 +7409,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7393
7409
/* ForEndCall=*/ false );
7394
7410
7395
7411
SmallVector<Value *, 3 > NumTeamsC;
7412
+ for (auto [DefaultVal, RuntimeVal] :
7413
+ zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7414
+ NumTeamsC.push_back (RuntimeVal ? RuntimeVal : Builder.getInt32 (DefaultVal));
7415
+
7416
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7417
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7418
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7419
+ if (Clause)
7420
+ Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7421
+ /* isSigned=*/ false );
7422
+ return Clause;
7423
+ };
7424
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7425
+ if (Clause)
7426
+ Result = Result
7427
+ ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7428
+ Result, Clause)
7429
+ : Clause;
7430
+ };
7431
+
7432
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7433
+ // the NUM_THREADS clause is overriden by THREAD_LIMIT.
7396
7434
SmallVector<Value *, 3 > NumThreadsC;
7397
- for (auto V : DefaultAttrs.MaxTeams )
7398
- NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7399
- for (auto V : DefaultAttrs.MaxThreads )
7400
- NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7435
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit .size () == 1
7436
+ ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7437
+ : nullptr ;
7438
+
7439
+ for (auto [TeamsVal, TargetVal] : zip_equal (RuntimeAttrs.TeamsThreadLimit ,
7440
+ RuntimeAttrs.TargetThreadLimit )) {
7441
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7442
+ Value *NumThreads = InitMaxThreadsClause (TargetVal);
7443
+
7444
+ CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7445
+ CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7446
+
7447
+ NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7448
+ }
7401
7449
7402
7450
unsigned NumTargetItems = Info.NumberOfPtrs ;
7403
7451
// TODO: Use correct device ID
@@ -7406,14 +7454,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7406
7454
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7407
7455
Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7408
7456
llvm::omp::IdentFlag (0 ), 0 );
7409
- // TODO: Use correct NumIterations
7410
- Value *NumIterations = Builder.getInt64 (0 );
7457
+
7458
+ Value *TripCount = RuntimeAttrs.LoopTripCount
7459
+ ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7460
+ Builder.getInt64Ty (),
7461
+ /* isSigned=*/ false )
7462
+ : Builder.getInt64 (0 );
7463
+
7411
7464
// TODO: Use correct DynCGGroupMem
7412
7465
Value *DynCGGroupMem = Builder.getInt32 (0 );
7413
7466
7414
- KArgs = OpenMPIRBuilder::TargetKernelArgs (
7415
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7416
- DynCGGroupMem, HasNoWait);
7467
+ KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7468
+ NumTeamsC, NumThreadsC,
7469
+ DynCGGroupMem, HasNoWait);
7417
7470
7418
7471
// The presence of certain clauses on the target directive require the
7419
7472
// explicit generation of the target task.
@@ -7438,6 +7491,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7438
7491
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7439
7492
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7440
7493
const TargetKernelDefaultAttrs &DefaultAttrs,
7494
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
7441
7495
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7442
7496
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7443
7497
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7462,8 +7516,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7462
7516
// to make a remote call (offload) to the previously outlined function
7463
7517
// that represents the target region. Do that now.
7464
7518
if (!Config.isTargetDevice ())
7465
- emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7466
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7519
+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7520
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7521
+ HasNowait);
7467
7522
return Builder.saveIP ();
7468
7523
}
7469
7524
0 commit comments