@@ -174,10 +174,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
174
174
if (op.getHint ())
175
175
op.emitWarning (" hint clause discarded" );
176
176
};
177
- auto checkHostEval = [&todo](auto op, LogicalResult &result) {
178
- if (!op.getHostEvalVars ().empty ())
179
- result = todo (" host_eval" );
180
- };
181
177
auto checkIf = [&todo](auto op, LogicalResult &result) {
182
178
if (op.getIfExpr ())
183
179
result = todo (" if" );
@@ -224,10 +220,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
224
220
op.getReductionSyms ())
225
221
result = todo (" reduction" );
226
222
};
227
- auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
228
- if (op.getThreadLimit ())
229
- result = todo (" thread_limit" );
230
- };
231
223
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
232
224
if (!op.getTaskReductionVars ().empty () || op.getTaskReductionByref () ||
233
225
op.getTaskReductionSyms ())
@@ -290,7 +282,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
290
282
checkAllocate (op, result);
291
283
checkDevice (op, result);
292
284
checkHasDeviceAddr (op, result);
293
- checkHostEval (op, result);
285
+
286
+ // Host evaluated clauses are supported, except for target SPMD loop
287
+ // bounds.
288
+ for (BlockArgument arg :
289
+ cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
290
+ for (Operation *user : arg.getUsers ())
291
+ if (isa<omp::LoopNestOp>(user))
292
+ result = op.emitError (" not yet implemented: host evaluation of "
293
+ " loop bounds in omp.target operation" );
294
+
294
295
checkIf (op, result);
295
296
checkInReduction (op, result);
296
297
checkIsDevicePtr (op, result);
@@ -311,7 +312,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
311
312
" structures in omp.target operation" );
312
313
}
313
314
}
314
- checkThreadLimit (op, result);
315
315
})
316
316
.Default ([](Operation &) {
317
317
// Assume all clauses for an operation can be translated unless they are
@@ -3815,6 +3815,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
3815
3815
return builder.saveIP ();
3816
3816
}
3817
3817
3818
+ // / Follow uses of `host_eval`-defined block arguments of the given `omp.target`
3819
+ // / operation and populate output variables with their corresponding host value
3820
+ // / (i.e. operand evaluated outside of the target region), based on their uses
3821
+ // / inside of the target region.
3822
+ // /
3823
+ // / Loop bounds and steps are only optionally populated, if output vectors are
3824
+ // / provided.
3825
+ static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
3826
+ Value &numTeamsLower, Value &numTeamsUpper,
3827
+ Value &threadLimit) {
3828
+ auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3829
+ for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
3830
+ blockArgIface.getHostEvalBlockArgs ())) {
3831
+ Value hostEvalVar = std::get<0 >(item), blockArg = std::get<1 >(item);
3832
+
3833
+ for (Operation *user : blockArg.getUsers ()) {
3834
+ llvm::TypeSwitch<Operation *>(user)
3835
+ .Case ([&](omp::TeamsOp teamsOp) {
3836
+ if (teamsOp.getNumTeamsLower () == blockArg)
3837
+ numTeamsLower = hostEvalVar;
3838
+ else if (teamsOp.getNumTeamsUpper () == blockArg)
3839
+ numTeamsUpper = hostEvalVar;
3840
+ else if (teamsOp.getThreadLimit () == blockArg)
3841
+ threadLimit = hostEvalVar;
3842
+ else
3843
+ llvm_unreachable (" unsupported host_eval use" );
3844
+ })
3845
+ .Case ([&](omp::ParallelOp parallelOp) {
3846
+ if (parallelOp.getNumThreads () == blockArg)
3847
+ numThreads = hostEvalVar;
3848
+ else
3849
+ llvm_unreachable (" unsupported host_eval use" );
3850
+ })
3851
+ .Case ([&](omp::LoopNestOp loopOp) {
3852
+ // TODO: Extract bounds and step values.
3853
+ })
3854
+ .Default ([](Operation *) {
3855
+ llvm_unreachable (" unsupported host_eval use" );
3856
+ });
3857
+ }
3858
+ }
3859
+ }
3860
+
3861
+ // / If \p op is of the given type parameter, return it casted to that type.
3862
+ // / Otherwise, if its immediate parent operation (or some other higher-level
3863
+ // / parent, if \p immediateParent is false) is of that type, return that parent
3864
+ // / casted to the given type.
3865
+ // /
3866
+ // / If \p op is \c null or neither it or its parent(s) are of the specified
3867
+ // / type, return a \c null operation.
3868
+ template <typename OpTy>
3869
+ static OpTy castOrGetParentOfType (Operation *op, bool immediateParent = false ) {
3870
+ if (!op)
3871
+ return OpTy ();
3872
+
3873
+ if (OpTy casted = dyn_cast<OpTy>(op))
3874
+ return casted;
3875
+
3876
+ if (immediateParent)
3877
+ return dyn_cast_if_present<OpTy>(op->getParentOp ());
3878
+
3879
+ return op->getParentOfType <OpTy>();
3880
+ }
3881
+
3882
+ // / Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
3883
+ // / values as stated by the corresponding clauses, if constant.
3884
+ // /
3885
+ // / These default values must be set before the creation of the outlined LLVM
3886
+ // / function for the target region, so that they can be used to initialize the
3887
+ // / corresponding global `ConfigurationEnvironmentTy` structure.
3888
+ static void
3889
+ initTargetDefaultAttrs (omp::TargetOp targetOp,
3890
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
3891
+ bool isTargetDevice) {
3892
+ // TODO: Handle constant 'if' clauses.
3893
+ Operation *capturedOp = targetOp.getInnermostCapturedOmpOp ();
3894
+
3895
+ Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
3896
+ if (!isTargetDevice) {
3897
+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
3898
+ threadLimit);
3899
+ } else {
3900
+ // In the target device, values for these clauses are not passed as
3901
+ // host_eval, but instead evaluated prior to entry to the region. This
3902
+ // ensures values are mapped and available inside of the target region.
3903
+ if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
3904
+ numTeamsLower = teamsOp.getNumTeamsLower ();
3905
+ numTeamsUpper = teamsOp.getNumTeamsUpper ();
3906
+ threadLimit = teamsOp.getThreadLimit ();
3907
+ }
3908
+
3909
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
3910
+ numThreads = parallelOp.getNumThreads ();
3911
+ }
3912
+
3913
+ auto extractConstInteger = [](Value value) -> std::optional<int64_t > {
3914
+ if (auto constOp =
3915
+ dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp ()))
3916
+ if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue ()))
3917
+ return constAttr.getInt ();
3918
+
3919
+ return std::nullopt;
3920
+ };
3921
+
3922
+ // Handle clauses impacting the number of teams.
3923
+
3924
+ int32_t minTeamsVal = 1 , maxTeamsVal = -1 ;
3925
+ if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
3926
+ // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
3927
+ // clang and set min and max to the same value.
3928
+ if (numTeamsUpper) {
3929
+ if (auto val = extractConstInteger (numTeamsUpper))
3930
+ minTeamsVal = maxTeamsVal = *val;
3931
+ } else {
3932
+ minTeamsVal = maxTeamsVal = 0 ;
3933
+ }
3934
+ } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
3935
+ /* immediateParent=*/ true ) ||
3936
+ castOrGetParentOfType<omp::SimdOp>(capturedOp,
3937
+ /* immediateParent=*/ true )) {
3938
+ minTeamsVal = maxTeamsVal = 1 ;
3939
+ } else {
3940
+ minTeamsVal = maxTeamsVal = -1 ;
3941
+ }
3942
+
3943
+ // Handle clauses impacting the number of threads.
3944
+
3945
+ auto setMaxValueFromClause = [&extractConstInteger](Value clauseValue,
3946
+ int32_t &result) {
3947
+ if (!clauseValue)
3948
+ return ;
3949
+
3950
+ if (auto val = extractConstInteger (clauseValue))
3951
+ result = *val;
3952
+
3953
+ // Found an applicable clause, so it's not undefined. Mark as unknown
3954
+ // because it's not constant.
3955
+ if (result < 0 )
3956
+ result = 0 ;
3957
+ };
3958
+
3959
+ // Extract 'thread_limit' clause from 'target' and 'teams' directives.
3960
+ int32_t targetThreadLimitVal = -1 , teamsThreadLimitVal = -1 ;
3961
+ setMaxValueFromClause (targetOp.getThreadLimit (), targetThreadLimitVal);
3962
+ setMaxValueFromClause (threadLimit, teamsThreadLimitVal);
3963
+
3964
+ // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
3965
+ int32_t maxThreadsVal = -1 ;
3966
+ if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
3967
+ setMaxValueFromClause (numThreads, maxThreadsVal);
3968
+ else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
3969
+ /* immediateParent=*/ true ))
3970
+ maxThreadsVal = 1 ;
3971
+
3972
+ // For max values, < 0 means unset, == 0 means set but unknown. Select the
3973
+ // minimum value between 'max_threads' and 'thread_limit' clauses that were
3974
+ // set.
3975
+ int32_t combinedMaxThreadsVal = targetThreadLimitVal;
3976
+ if (combinedMaxThreadsVal < 0 ||
3977
+ (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
3978
+ combinedMaxThreadsVal = teamsThreadLimitVal;
3979
+
3980
+ if (combinedMaxThreadsVal < 0 ||
3981
+ (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
3982
+ combinedMaxThreadsVal = maxThreadsVal;
3983
+
3984
+ // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
3985
+ attrs.MinTeams = minTeamsVal;
3986
+ attrs.MaxTeams .front () = maxTeamsVal;
3987
+ attrs.MinThreads = 1 ;
3988
+ attrs.MaxThreads .front () = combinedMaxThreadsVal;
3989
+ }
3990
+
3991
+ // / Gather LLVM runtime values for all clauses evaluated in the host that are
3992
+ // / passed to the kernel invocation.
3993
+ // /
3994
+ // / This function must be called only when compiling for the host. Also, it will
3995
+ // / only provide correct results if it's called after the body of \c targetOp
3996
+ // / has been fully generated.
3997
+ static void
3998
+ initTargetRuntimeAttrs (llvm::IRBuilderBase &builder,
3999
+ LLVM::ModuleTranslation &moduleTranslation,
4000
+ omp::TargetOp targetOp,
4001
+ llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4002
+ Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4003
+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
4004
+ teamsThreadLimit);
4005
+
4006
+ // TODO: Handle constant 'if' clauses.
4007
+ if (Value targetThreadLimit = targetOp.getThreadLimit ())
4008
+ attrs.TargetThreadLimit .front () =
4009
+ moduleTranslation.lookupValue (targetThreadLimit);
4010
+
4011
+ if (numTeamsLower)
4012
+ attrs.MinTeams = moduleTranslation.lookupValue (numTeamsLower);
4013
+
4014
+ if (numTeamsUpper)
4015
+ attrs.MaxTeams .front () = moduleTranslation.lookupValue (numTeamsUpper);
4016
+
4017
+ if (teamsThreadLimit)
4018
+ attrs.TeamsThreadLimit .front () =
4019
+ moduleTranslation.lookupValue (teamsThreadLimit);
4020
+
4021
+ if (numThreads)
4022
+ attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
4023
+
4024
+ // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4025
+ }
4026
+
3818
4027
static LogicalResult
3819
4028
convertOmpTarget (Operation &opInst, llvm::IRBuilderBase &builder,
3820
4029
LLVM::ModuleTranslation &moduleTranslation) {
@@ -3824,12 +4033,13 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3824
4033
3825
4034
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3826
4035
bool isTargetDevice = ompBuilder->Config .isTargetDevice ();
4036
+
3827
4037
auto parentFn = opInst.getParentOfType <LLVM::LLVMFuncOp>();
4038
+ auto blockIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
3828
4039
auto &targetRegion = targetOp.getRegion ();
3829
4040
DataLayout dl = DataLayout (opInst.getParentOfType <ModuleOp>());
3830
4041
SmallVector<Value> mapVars = targetOp.getMapVars ();
3831
- ArrayRef<BlockArgument> mapBlockArgs =
3832
- cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs ();
4042
+ ArrayRef<BlockArgument> mapBlockArgs = blockIface.getMapBlockArgs ();
3833
4043
llvm::Function *llvmOutlinedFn = nullptr ;
3834
4044
3835
4045
// TODO: It can also be false if a compile-time constant `false` IF clause is
@@ -3872,7 +4082,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3872
4082
OperandRange privateVars = targetOp.getPrivateVars ();
3873
4083
std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms ();
3874
4084
MutableArrayRef<BlockArgument> privateBlockArgs =
3875
- cast<omp::BlockArgOpenMPOpInterface>(opInst) .getPrivateBlockArgs ();
4085
+ blockIface .getPrivateBlockArgs ();
3876
4086
3877
4087
for (auto [privVar, privatizerNameAttr, privBlockArg] :
3878
4088
llvm::zip_equal (privateVars, *privateSyms, privateBlockArgs)) {
@@ -3951,13 +4161,29 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3951
4161
allocaIP, codeGenIP);
3952
4162
};
3953
4163
3954
- // TODO: Populate default and runtime attributes based on the construct and
3955
- // clauses.
3956
- llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
3957
- /* MaxTeams=*/ {-1 }, /* MinTeams=*/ 0 , /* MaxThreads=*/ {0 }, /* MinThreads=*/ 0 };
4164
+ llvm::SmallVector<llvm::Value *, 4 > kernelInput;
4165
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
4166
+ initTargetDefaultAttrs (targetOp, defaultAttrs, isTargetDevice);
4167
+
4168
+ // Collect host-evaluated values needed to properly launch the kernel from the
4169
+ // host.
3958
4170
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
4171
+ if (!isTargetDevice)
4172
+ initTargetRuntimeAttrs (builder, moduleTranslation, targetOp, runtimeAttrs);
4173
+
4174
+ // Pass host-evaluated values as parameters to the kernel / host fallback,
4175
+ // except if they are constants. In any case, map the MLIR block argument to
4176
+ // the corresponding LLVM values.
4177
+ SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars ();
4178
+ ArrayRef<BlockArgument> hostEvalBlockArgs = blockIface.getHostEvalBlockArgs ();
4179
+ for (auto [arg, var] : llvm::zip_equal (hostEvalBlockArgs, hostEvalVars)) {
4180
+ llvm::Value *value = moduleTranslation.lookupValue (var);
4181
+ moduleTranslation.mapValue (arg, value);
4182
+
4183
+ if (!llvm::isa<llvm::Constant>(value))
4184
+ kernelInput.push_back (value);
4185
+ }
3959
4186
3960
- llvm::SmallVector<llvm::Value *, 4 > kernelInput;
3961
4187
for (size_t i = 0 ; i < mapVars.size (); ++i) {
3962
4188
// declare target arguments are not passed to kernels as arguments
3963
4189
// TODO: We currently do not handle cases where a member is explicitly
0 commit comments