@@ -6968,8 +6968,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
6968
6968
}
6969
6969
6970
6970
OI.ExitBB = Builder.saveIP ().getBlock ();
6971
- OI.PostOutlineCB = [this , ToBeDeleted, Dependencies,
6972
- HasNoWait ](Function &OutlinedFn) mutable {
6971
+ OI.PostOutlineCB = [this , ToBeDeleted, Dependencies, HasNoWait,
6972
+ DeviceID ](Function &OutlinedFn) mutable {
6973
6973
assert (OutlinedFn.getNumUses () == 1 &&
6974
6974
" there must be a single user for the outlined function" );
6975
6975
@@ -6989,9 +6989,14 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
6989
6989
getOrCreateSrcLocStr (LocationDescription (Builder), SrcLocStrSize);
6990
6990
Value *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
6991
6991
6992
- // @__kmpc_omp_task_alloc
6992
+ // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
6993
+ //
6994
+ // If `HasNoWait == true`, we call @__kmpc_omp_target_task_alloc to provide
6995
+ // the DeviceID to the deferred task.
6993
6996
Function *TaskAllocFn =
6994
- getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_omp_task_alloc);
6997
+ !HasNoWait ? getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_omp_task_alloc)
6998
+ : getOrCreateRuntimeFunctionPtr (
6999
+ OMPRTL___kmpc_omp_target_task_alloc);
6995
7000
6996
7001
// Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
6997
7002
// call.
@@ -7032,10 +7037,18 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
7032
7037
// Emit the @__kmpc_omp_task_alloc runtime call
7033
7038
// The runtime call returns a pointer to an area where the task captured
7034
7039
// variables must be copied before the task is run (TaskData)
7035
- CallInst *TaskData = Builder.CreateCall (
7036
- TaskAllocFn, {/* loc_ref=*/ Ident, /* gtid=*/ ThreadID, /* flags=*/ Flags,
7037
- /* sizeof_task=*/ TaskSize, /* sizeof_shared=*/ SharedsSize,
7038
- /* task_func=*/ ProxyFn});
7040
+ CallInst *TaskData = nullptr ;
7041
+
7042
+ SmallVector<llvm::Value *> TaskAllocArgs = {
7043
+ /* loc_ref=*/ Ident, /* gtid=*/ ThreadID,
7044
+ /* flags=*/ Flags,
7045
+ /* sizeof_task=*/ TaskSize, /* sizeof_shared=*/ SharedsSize,
7046
+ /* task_func=*/ ProxyFn};
7047
+
7048
+ if (HasNoWait)
7049
+ TaskAllocArgs.push_back (DeviceID);
7050
+
7051
+ TaskData = Builder.CreateCall (TaskAllocFn, TaskAllocArgs);
7039
7052
7040
7053
if (HasShareds) {
7041
7054
Value *Shareds = StaleCI->getArgOperand (1 );
@@ -7118,13 +7131,14 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7118
7131
emitOffloadingArraysArgument (Builder, RTArgs, Info, ForEndCall);
7119
7132
}
7120
7133
7121
- static void emitTargetCall (
7122
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7123
- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7124
- Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7125
- ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7126
- OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7127
- SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
7134
+ static void
7135
+ emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7136
+ OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7137
+ Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7138
+ ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7139
+ OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7140
+ SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
7141
+ bool HasNoWait = false ) {
7128
7142
// Generate a function call to the host fallback implementation of the target
7129
7143
// region. This is called by the host when no offload entry was generated for
7130
7144
// the target region and when the offloading call fails at runtime.
@@ -7135,7 +7149,6 @@ static void emitTargetCall(
7135
7149
return Builder.saveIP ();
7136
7150
};
7137
7151
7138
- bool HasNoWait = false ;
7139
7152
bool HasDependencies = Dependencies.size () > 0 ;
7140
7153
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
7141
7154
@@ -7211,7 +7224,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
7211
7224
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7212
7225
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7213
7226
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7214
- SmallVector<DependData> Dependencies) {
7227
+ SmallVector<DependData> Dependencies, bool HasNowait ) {
7215
7228
7216
7229
if (!updateToLocation (Loc))
7217
7230
return InsertPointTy ();
@@ -7232,7 +7245,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
7232
7245
// that represents the target region. Do that now.
7233
7246
if (!Config.isTargetDevice ())
7234
7247
emitTargetCall (*this , Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7235
- NumThreads, Args, GenMapInfoCB, Dependencies);
7248
+ NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait );
7236
7249
return Builder.saveIP ();
7237
7250
}
7238
7251
0 commit comments