@@ -6731,8 +6731,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6731
6731
return getOrCreateRuntimeFunction (M, omp::OMPRTL___kmpc_dispatch_deinit);
6732
6732
}
6733
6733
6734
+ static void emitUsed (StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
6735
+ Module &M) {
6736
+ if (List.empty ())
6737
+ return ;
6738
+
6739
+ Type *PtrTy = PointerType::get (M.getContext (), /* AddressSpace=*/ 0 );
6740
+
6741
+ // Convert List to what ConstantArray needs.
6742
+ SmallVector<Constant *, 8 > UsedArray;
6743
+ UsedArray.reserve (List.size ());
6744
+ for (auto Item : List)
6745
+ UsedArray.push_back (ConstantExpr::getPointerBitCastOrAddrSpaceCast (
6746
+ cast<Constant>(&*Item), PtrTy));
6747
+
6748
+ ArrayType *ArrTy = ArrayType::get (PtrTy, UsedArray.size ());
6749
+ auto *GV =
6750
+ new GlobalVariable (M, ArrTy, false , llvm::GlobalValue::AppendingLinkage,
6751
+ llvm::ConstantArray::get (ArrTy, UsedArray), Name);
6752
+
6753
+ GV->setSection (" llvm.metadata" );
6754
+ }
6755
+
6756
+ static void
6757
+ emitExecutionMode (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6758
+ StringRef FunctionName, OMPTgtExecModeFlags Mode,
6759
+ std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
6760
+ auto *Int8Ty = Type::getInt8Ty (Builder.getContext ());
6761
+ auto *GVMode = new llvm::GlobalVariable (
6762
+ OMPBuilder.M , Int8Ty, /* isConstant=*/ true ,
6763
+ llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get (Int8Ty, Mode),
6764
+ Twine (FunctionName, " _exec_mode" ));
6765
+ GVMode->setVisibility (llvm::GlobalVariable::ProtectedVisibility);
6766
+ LLVMCompilerUsed.emplace_back (GVMode);
6767
+ }
6768
+
6734
6769
static Expected<Function *> createOutlinedFunction (
6735
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6770
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
6736
6771
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6737
6772
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
6738
6773
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6762,6 +6797,15 @@ static Expected<Function *> createOutlinedFunction(
6762
6797
auto Func =
6763
6798
Function::Create (FuncType, GlobalValue::InternalLinkage, FuncName, M);
6764
6799
6800
+ if (OMPBuilder.Config .isTargetDevice ()) {
6801
+ std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
6802
+ emitExecutionMode (OMPBuilder, Builder, FuncName,
6803
+ IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
6804
+ : OMP_TGT_EXEC_MODE_GENERIC,
6805
+ LLVMCompilerUsed);
6806
+ emitUsed (" llvm.compiler.used" , LLVMCompilerUsed, OMPBuilder.M );
6807
+ }
6808
+
6765
6809
// Save insert point.
6766
6810
IRBuilder<>::InsertPointGuard IPG (Builder);
6767
6811
// If there's a DISubprogram associated with current function, then
@@ -6802,7 +6846,7 @@ static Expected<Function *> createOutlinedFunction(
6802
6846
// Insert target init call in the device compilation pass.
6803
6847
if (OMPBuilder.Config .isTargetDevice ())
6804
6848
Builder.restoreIP (
6805
- OMPBuilder.createTargetInit (Builder, /* IsSPMD= */ false , DefaultAttrs));
6849
+ OMPBuilder.createTargetInit (Builder, IsSPMD, DefaultAttrs));
6806
6850
6807
6851
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock ();
6808
6852
@@ -6998,7 +7042,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
6998
7042
6999
7043
static Error emitTargetOutlinedFunction (
7000
7044
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
7001
- TargetRegionEntryInfo &EntryInfo,
7045
+ bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
7002
7046
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7003
7047
Function *&OutlinedFn, Constant *&OutlinedFnID,
7004
7048
SmallVectorImpl<Value *> &Inputs,
@@ -7007,7 +7051,7 @@ static Error emitTargetOutlinedFunction(
7007
7051
7008
7052
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7009
7053
[&](StringRef EntryFnName) {
7010
- return createOutlinedFunction (OMPBuilder, Builder, DefaultAttrs,
7054
+ return createOutlinedFunction (OMPBuilder, Builder, IsSPMD, DefaultAttrs,
7011
7055
EntryFnName, Inputs, CBFunc,
7012
7056
ArgAccessorFuncCB);
7013
7057
};
@@ -7307,6 +7351,7 @@ static void
7307
7351
emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7308
7352
OpenMPIRBuilder::InsertPointTy AllocaIP,
7309
7353
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7354
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7310
7355
Function *OutlinedFn, Constant *OutlinedFnID,
7311
7356
SmallVectorImpl<Value *> &Args,
7312
7357
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7388,11 +7433,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7388
7433
/* ForEndCall=*/ false );
7389
7434
7390
7435
SmallVector<Value *, 3 > NumTeamsC;
7436
+ for (auto [DefaultVal, RuntimeVal] :
7437
+ zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7438
+ NumTeamsC.push_back (RuntimeVal ? RuntimeVal : Builder.getInt32 (DefaultVal));
7439
+
7440
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7441
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7442
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7443
+ if (Clause)
7444
+ Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7445
+ /* isSigned=*/ false );
7446
+ return Clause;
7447
+ };
7448
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7449
+ if (Clause)
7450
+ Result = Result
7451
+ ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7452
+ Result, Clause)
7453
+ : Clause;
7454
+ };
7455
+
7456
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7457
+ // the NUM_THREADS clause is overriden by THREAD_LIMIT.
7391
7458
SmallVector<Value *, 3 > NumThreadsC;
7392
- for (auto V : DefaultAttrs.MaxTeams )
7393
- NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7394
- for (auto V : DefaultAttrs.MaxThreads )
7395
- NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7459
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit .size () == 1
7460
+ ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7461
+ : nullptr ;
7462
+
7463
+ for (auto [TeamsVal, TargetVal] : llvm::zip_equal (
7464
+ RuntimeAttrs.TeamsThreadLimit , RuntimeAttrs.TargetThreadLimit )) {
7465
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7466
+ Value *NumThreads = InitMaxThreadsClause (TargetVal);
7467
+
7468
+ CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7469
+ CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7470
+
7471
+ NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7472
+ }
7396
7473
7397
7474
unsigned NumTargetItems = Info.NumberOfPtrs ;
7398
7475
// TODO: Use correct device ID
@@ -7401,14 +7478,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7401
7478
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7402
7479
Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7403
7480
llvm::omp::IdentFlag (0 ), 0 );
7404
- // TODO: Use correct NumIterations
7405
- Value *NumIterations = Builder.getInt64 (0 );
7481
+
7482
+ Value *TripCount = RuntimeAttrs.LoopTripCount
7483
+ ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7484
+ Builder.getInt64Ty (),
7485
+ /* isSigned=*/ false )
7486
+ : Builder.getInt64 (0 );
7487
+
7406
7488
// TODO: Use correct DynCGGroupMem
7407
7489
Value *DynCGGroupMem = Builder.getInt32 (0 );
7408
7490
7409
- KArgs = OpenMPIRBuilder::TargetKernelArgs (
7410
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7411
- DynCGGroupMem, HasNoWait);
7491
+ KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7492
+ NumTeamsC, NumThreadsC,
7493
+ DynCGGroupMem, HasNoWait);
7412
7494
7413
7495
// The presence of certain clauses on the target directive require the
7414
7496
// explicit generation of the target task.
@@ -7430,13 +7512,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7430
7512
}
7431
7513
7432
7514
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
7433
- const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7434
- InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7515
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
7516
+ InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
7517
+ TargetRegionEntryInfo &EntryInfo,
7435
7518
const TargetKernelDefaultAttrs &DefaultAttrs,
7519
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
7436
7520
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7437
7521
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7438
7522
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7439
7523
SmallVector<DependData> Dependencies, bool HasNowait) {
7524
+ assert ((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
7525
+ " trip count not expected if IsSPMD=false" );
7440
7526
7441
7527
if (!updateToLocation (Loc))
7442
7528
return InsertPointTy ();
@@ -7449,16 +7535,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7449
7535
// the target region itself is generated using the callbacks CBFunc
7450
7536
// and ArgAccessorFuncCB
7451
7537
if (Error Err = emitTargetOutlinedFunction (
7452
- *this , Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn ,
7453
- OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7538
+ *this , Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs ,
7539
+ OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7454
7540
return Err;
7455
7541
7456
7542
// If we are not on the target device, then we need to generate code
7457
7543
// to make a remote call (offload) to the previously outlined function
7458
7544
// that represents the target region. Do that now.
7459
7545
if (!Config.isTargetDevice ())
7460
- emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7461
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7546
+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7547
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7548
+ HasNowait);
7462
7549
return Builder.saveIP ();
7463
7550
}
7464
7551
0 commit comments