@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6727
6727
return getOrCreateRuntimeFunction (M, omp::OMPRTL___kmpc_dispatch_deinit);
6728
6728
}
6729
6729
6730
+ static void emitUsed (StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
6731
+ Module &M) {
6732
+ if (List.empty ())
6733
+ return ;
6734
+
6735
+ Type *PtrTy = PointerType::get (M.getContext (), /* AddressSpace=*/ 0 );
6736
+
6737
+ // Convert List to what ConstantArray needs.
6738
+ SmallVector<Constant *, 8 > UsedArray;
6739
+ UsedArray.reserve (List.size ());
6740
+ for (auto Item : List)
6741
+ UsedArray.push_back (ConstantExpr::getPointerBitCastOrAddrSpaceCast (
6742
+ cast<Constant>(&*Item), PtrTy));
6743
+
6744
+ ArrayType *ArrTy = ArrayType::get (PtrTy, UsedArray.size ());
6745
+ auto *GV =
6746
+ new GlobalVariable (M, ArrTy, false , llvm::GlobalValue::AppendingLinkage,
6747
+ llvm::ConstantArray::get (ArrTy, UsedArray), Name);
6748
+
6749
+ GV->setSection (" llvm.metadata" );
6750
+ }
6751
+
6752
+ static void
6753
+ emitExecutionMode (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6754
+ StringRef FunctionName, OMPTgtExecModeFlags Mode,
6755
+ std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
6756
+ auto *Int8Ty = Type::getInt8Ty (Builder.getContext ());
6757
+ auto *GVMode = new llvm::GlobalVariable (
6758
+ OMPBuilder.M , Int8Ty, /* isConstant=*/ true ,
6759
+ llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get (Int8Ty, Mode),
6760
+ Twine (FunctionName, " _exec_mode" ));
6761
+ GVMode->setVisibility (llvm::GlobalVariable::ProtectedVisibility);
6762
+ LLVMCompilerUsed.emplace_back (GVMode);
6763
+ }
6764
+
6730
6765
static Expected<Function *> createOutlinedFunction (
6731
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6766
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
6732
6767
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6733
6768
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
6734
6769
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,15 @@ static Expected<Function *> createOutlinedFunction(
6758
6793
auto Func =
6759
6794
Function::Create (FuncType, GlobalValue::InternalLinkage, FuncName, M);
6760
6795
6796
+ if (OMPBuilder.Config .isTargetDevice ()) {
6797
+ std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
6798
+ emitExecutionMode (OMPBuilder, Builder, FuncName,
6799
+ IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
6800
+ : OMP_TGT_EXEC_MODE_GENERIC,
6801
+ LLVMCompilerUsed);
6802
+ emitUsed (" llvm.compiler.used" , LLVMCompilerUsed, OMPBuilder.M );
6803
+ }
6804
+
6761
6805
// Save insert point.
6762
6806
IRBuilder<>::InsertPointGuard IPG (Builder);
6763
6807
// If there's a DISubprogram associated with current function, then
@@ -6798,7 +6842,7 @@ static Expected<Function *> createOutlinedFunction(
6798
6842
// Insert target init call in the device compilation pass.
6799
6843
if (OMPBuilder.Config .isTargetDevice ())
6800
6844
Builder.restoreIP (
6801
- OMPBuilder.createTargetInit (Builder, /* IsSPMD= */ false , DefaultAttrs));
6845
+ OMPBuilder.createTargetInit (Builder, IsSPMD, DefaultAttrs));
6802
6846
6803
6847
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock ();
6804
6848
@@ -6995,7 +7039,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
6995
7039
6996
7040
static Error emitTargetOutlinedFunction (
6997
7041
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
6998
- TargetRegionEntryInfo &EntryInfo,
7042
+ bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
6999
7043
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7000
7044
Function *&OutlinedFn, Constant *&OutlinedFnID,
7001
7045
SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7048,7 @@ static Error emitTargetOutlinedFunction(
7004
7048
7005
7049
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7006
7050
[&](StringRef EntryFnName) {
7007
- return createOutlinedFunction (OMPBuilder, Builder, DefaultAttrs,
7051
+ return createOutlinedFunction (OMPBuilder, Builder, IsSPMD, DefaultAttrs,
7008
7052
EntryFnName, Inputs, CBFunc,
7009
7053
ArgAccessorFuncCB);
7010
7054
};
@@ -7304,6 +7348,7 @@ static void
7304
7348
emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7305
7349
OpenMPIRBuilder::InsertPointTy AllocaIP,
7306
7350
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7351
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7307
7352
Function *OutlinedFn, Constant *OutlinedFnID,
7308
7353
SmallVectorImpl<Value *> &Args,
7309
7354
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7430,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7385
7430
/* ForEndCall=*/ false );
7386
7431
7387
7432
SmallVector<Value *, 3 > NumTeamsC;
7433
+ for (auto [DefaultVal, RuntimeVal] :
7434
+ zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7435
+ NumTeamsC.push_back (RuntimeVal ? RuntimeVal : Builder.getInt32 (DefaultVal));
7436
+
7437
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7438
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7439
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7440
+ if (Clause)
7441
+ Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7442
+ /* isSigned=*/ false );
7443
+ return Clause;
7444
+ };
7445
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7446
+ if (Clause)
7447
+ Result = Result
7448
+ ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7449
+ Result, Clause)
7450
+ : Clause;
7451
+ };
7452
+
7453
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7454
+ // the NUM_THREADS clause is overridden by THREAD_LIMIT.
7388
7455
SmallVector<Value *, 3 > NumThreadsC;
7389
- for (auto V : DefaultAttrs.MaxTeams )
7390
- NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7391
- for (auto V : DefaultAttrs.MaxThreads )
7392
- NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7456
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit .size () == 1
7457
+ ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7458
+ : nullptr ;
7459
+
7460
+ for (auto [TeamsVal, TargetVal] : llvm::zip_equal (
7461
+ RuntimeAttrs.TeamsThreadLimit , RuntimeAttrs.TargetThreadLimit )) {
7462
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7463
+ Value *NumThreads = InitMaxThreadsClause (TargetVal);
7464
+
7465
+ CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7466
+ CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7467
+
7468
+ NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7469
+ }
7393
7470
7394
7471
unsigned NumTargetItems = Info.NumberOfPtrs ;
7395
7472
// TODO: Use correct device ID
@@ -7398,14 +7475,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7398
7475
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7399
7476
Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7400
7477
llvm::omp::IdentFlag (0 ), 0 );
7401
- // TODO: Use correct NumIterations
7402
- Value *NumIterations = Builder.getInt64 (0 );
7478
+
7479
+ Value *TripCount = RuntimeAttrs.LoopTripCount
7480
+ ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7481
+ Builder.getInt64Ty (),
7482
+ /* isSigned=*/ false )
7483
+ : Builder.getInt64 (0 );
7484
+
7403
7485
// TODO: Use correct DynCGGroupMem
7404
7486
Value *DynCGGroupMem = Builder.getInt32 (0 );
7405
7487
7406
- KArgs = OpenMPIRBuilder::TargetKernelArgs (
7407
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7408
- DynCGGroupMem, HasNoWait);
7488
+ KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7489
+ NumTeamsC, NumThreadsC,
7490
+ DynCGGroupMem, HasNoWait);
7409
7491
7410
7492
// The presence of certain clauses on the target directive require the
7411
7493
// explicit generation of the target task.
@@ -7427,13 +7509,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7427
7509
}
7428
7510
7429
7511
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
7430
- const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7431
- InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7512
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
7513
+ InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
7514
+ TargetRegionEntryInfo &EntryInfo,
7432
7515
const TargetKernelDefaultAttrs &DefaultAttrs,
7516
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
7433
7517
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7434
7518
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7435
7519
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7436
7520
SmallVector<DependData> Dependencies, bool HasNowait) {
7521
+ assert ((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
7522
+ " trip count not expected if IsSPMD=false" );
7437
7523
7438
7524
if (!updateToLocation (Loc))
7439
7525
return InsertPointTy ();
@@ -7446,16 +7532,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7446
7532
// the target region itself is generated using the callbacks CBFunc
7447
7533
// and ArgAccessorFuncCB
7448
7534
if (Error Err = emitTargetOutlinedFunction (
7449
- *this , Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn ,
7450
- OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7535
+ *this , Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs ,
7536
+ OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7451
7537
return Err;
7452
7538
7453
7539
// If we are not on the target device, then we need to generate code
7454
7540
// to make a remote call (offload) to the previously outlined function
7455
7541
// that represents the target region. Do that now.
7456
7542
if (!Config.isTargetDevice ())
7457
- emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7458
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7543
+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7544
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7545
+ HasNowait);
7459
7546
return Builder.saveIP ();
7460
7547
}
7461
7548
0 commit comments