@@ -6968,8 +6968,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
6968
6968
}
6969
6969
6970
6970
OI.ExitBB = Builder.saveIP ().getBlock ();
6971
- OI.PostOutlineCB = [this , ToBeDeleted, Dependencies,
6972
- HasNoWait ](Function &OutlinedFn) mutable {
6971
+ OI.PostOutlineCB = [this , ToBeDeleted, Dependencies, HasNoWait,
6972
+ DeviceID ](Function &OutlinedFn) mutable {
6973
6973
assert (OutlinedFn.getNumUses () == 1 &&
6974
6974
" there must be a single user for the outlined function" );
6975
6975
@@ -6989,9 +6989,15 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
6989
6989
getOrCreateSrcLocStr (LocationDescription (Builder), SrcLocStrSize);
6990
6990
Value *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
6991
6991
6992
- // @__kmpc_omp_task_alloc
6992
+ // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
6993
+ //
6994
+ // If `HasNoWait == true`, we call @__kmpc_omp_target_task_alloc, both to
6995
+ // pass the DeviceID to the deferred task and because
6996
+ // @__kmpc_omp_target_task_alloc creates an untied/async task.
6993
6997
Function *TaskAllocFn =
6994
- getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_omp_task_alloc);
6998
+ !HasNoWait ? getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_omp_task_alloc)
6999
+ : getOrCreateRuntimeFunctionPtr (
7000
+ OMPRTL___kmpc_omp_target_task_alloc);
6995
7001
6996
7002
// Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
6997
7003
// call.
@@ -7032,10 +7038,18 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
7032
7038
// Emit the @__kmpc_omp_task_alloc runtime call
7033
7039
// The runtime call returns a pointer to an area where the task captured
7034
7040
// variables must be copied before the task is run (TaskData)
7035
- CallInst *TaskData = Builder.CreateCall (
7036
- TaskAllocFn, {/* loc_ref=*/ Ident, /* gtid=*/ ThreadID, /* flags=*/ Flags,
7037
- /* sizeof_task=*/ TaskSize, /* sizeof_shared=*/ SharedsSize,
7038
- /* task_func=*/ ProxyFn});
7041
+ CallInst *TaskData = nullptr ;
7042
+
7043
+ SmallVector<llvm::Value *> TaskAllocArgs = {
7044
+ /* loc_ref=*/ Ident, /* gtid=*/ ThreadID,
7045
+ /* flags=*/ Flags,
7046
+ /* sizeof_task=*/ TaskSize, /* sizeof_shared=*/ SharedsSize,
7047
+ /* task_func=*/ ProxyFn};
7048
+
7049
+ if (HasNoWait)
7050
+ TaskAllocArgs.push_back (DeviceID);
7051
+
7052
+ TaskData = Builder.CreateCall (TaskAllocFn, TaskAllocArgs);
7039
7053
7040
7054
if (HasShareds) {
7041
7055
Value *Shareds = StaleCI->getArgOperand (1 );
@@ -7118,13 +7132,14 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7118
7132
emitOffloadingArraysArgument (Builder, RTArgs, Info, ForEndCall);
7119
7133
}
7120
7134
7121
- static void emitTargetCall (
7122
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7123
- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7124
- Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7125
- ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7126
- OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7127
- SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
7135
+ static void
7136
+ emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7137
+ OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7138
+ Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7139
+ ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7140
+ OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7141
+ SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
7142
+ bool HasNoWait = false ) {
7128
7143
// Generate a function call to the host fallback implementation of the target
7129
7144
// region. This is called by the host when no offload entry was generated for
7130
7145
// the target region and when the offloading call fails at runtime.
@@ -7135,7 +7150,6 @@ static void emitTargetCall(
7135
7150
return Builder.saveIP ();
7136
7151
};
7137
7152
7138
- bool HasNoWait = false ;
7139
7153
bool HasDependencies = Dependencies.size () > 0 ;
7140
7154
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
7141
7155
@@ -7211,7 +7225,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
7211
7225
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7212
7226
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7213
7227
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7214
- SmallVector<DependData> Dependencies) {
7228
+ SmallVector<DependData> Dependencies, bool HasNowait ) {
7215
7229
7216
7230
if (!updateToLocation (Loc))
7217
7231
return InsertPointTy ();
@@ -7232,7 +7246,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
7232
7246
// that represents the target region. Do that now.
7233
7247
if (!Config.isTargetDevice ())
7234
7248
emitTargetCall (*this , Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7235
- NumThreads, Args, GenMapInfoCB, Dependencies);
7249
+ NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait );
7236
7250
return Builder.saveIP ();
7237
7251
}
7238
7252
0 commit comments