@@ -174,10 +174,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
174
174
if (op.getHint ())
175
175
op.emitWarning (" hint clause discarded" );
176
176
};
177
- auto checkHostEval = [&todo](auto op, LogicalResult &result) {
178
- if (!op.getHostEvalVars ().empty ())
179
- result = todo (" host_eval" );
180
- };
181
177
auto checkIf = [&todo](auto op, LogicalResult &result) {
182
178
if (op.getIfExpr ())
183
179
result = todo (" if" );
@@ -228,10 +224,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
228
224
op.getReductionSyms ())
229
225
result = todo (" reduction" );
230
226
};
231
- auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
232
- if (op.getThreadLimit ())
233
- result = todo (" thread_limit" );
234
- };
235
227
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
236
228
if (!op.getTaskReductionVars ().empty () || op.getTaskReductionByref () ||
237
229
op.getTaskReductionSyms ())
@@ -295,7 +287,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
295
287
checkAllocate (op, result);
296
288
checkDevice (op, result);
297
289
checkHasDeviceAddr (op, result);
298
- checkHostEval (op, result);
290
+
291
+ // Host evaluated clauses are supported, except for target SPMD loop
292
+ // bounds.
293
+ for (BlockArgument arg :
294
+ cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
295
+ for (Operation *user : arg.getUsers ())
296
+ if (isa<omp::LoopNestOp>(user))
297
+ result = op.emitError (" not yet implemented: host evaluation of "
298
+ " loop bounds in omp.target operation" );
299
+
299
300
checkIf (op, result);
300
301
checkInReduction (op, result);
301
302
checkIsDevicePtr (op, result);
@@ -316,7 +317,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
316
317
" structures in omp.target operation" );
317
318
}
318
319
}
319
- checkThreadLimit (op, result);
320
320
})
321
321
.Default ([](Operation &) {
322
322
// Assume all clauses for an operation can be translated unless they are
@@ -3800,6 +3800,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
3800
3800
return builder.saveIP ();
3801
3801
}
3802
3802
3803
+ // / Follow uses of `host_eval`-defined block arguments of the given `omp.target`
3804
+ // / operation and populate output variables with their corresponding host value
3805
+ // / (i.e. operand evaluated outside of the target region), based on their uses
3806
+ // / inside of the target region.
3807
+ // /
3808
+ // / Loop bounds and steps are only optionally populated, if output vectors are
3809
+ // / provided.
3810
+ static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
3811
+ Value &numTeamsLower, Value &numTeamsUpper,
3812
+ Value &threadLimit) {
3813
+ auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3814
+ for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
3815
+ blockArgIface.getHostEvalBlockArgs ())) {
3816
+ Value hostEvalVar = std::get<0 >(item), blockArg = std::get<1 >(item);
3817
+
3818
+ for (Operation *user : blockArg.getUsers ()) {
3819
+ llvm::TypeSwitch<Operation *>(user)
3820
+ .Case ([&](omp::TeamsOp teamsOp) {
3821
+ if (teamsOp.getNumTeamsLower () == blockArg)
3822
+ numTeamsLower = hostEvalVar;
3823
+ else if (teamsOp.getNumTeamsUpper () == blockArg)
3824
+ numTeamsUpper = hostEvalVar;
3825
+ else if (teamsOp.getThreadLimit () == blockArg)
3826
+ threadLimit = hostEvalVar;
3827
+ else
3828
+ llvm_unreachable (" unsupported host_eval use" );
3829
+ })
3830
+ .Case ([&](omp::ParallelOp parallelOp) {
3831
+ if (parallelOp.getNumThreads () == blockArg)
3832
+ numThreads = hostEvalVar;
3833
+ else
3834
+ llvm_unreachable (" unsupported host_eval use" );
3835
+ })
3836
+ .Case ([&](omp::LoopNestOp loopOp) {
3837
+ // TODO: Extract bounds and step values.
3838
+ })
3839
+ .Default ([](Operation *) {
3840
+ llvm_unreachable (" unsupported host_eval use" );
3841
+ });
3842
+ }
3843
+ }
3844
+ }
3845
+
3846
+ // / If \p op is of the given type parameter, return it casted to that type.
3847
+ // / Otherwise, if its immediate parent operation (or some other higher-level
3848
+ // / parent, if \p immediateParent is false) is of that type, return that parent
3849
+ // / casted to the given type.
3850
+ // /
3851
+ // / If \p op is \c null or neither it or its parent(s) are of the specified
3852
+ // / type, return a \c null operation.
3853
+ template <typename OpTy>
3854
+ static OpTy castOrGetParentOfType (Operation *op, bool immediateParent = false ) {
3855
+ if (!op)
3856
+ return OpTy ();
3857
+
3858
+ if (OpTy casted = dyn_cast<OpTy>(op))
3859
+ return casted;
3860
+
3861
+ if (immediateParent)
3862
+ return dyn_cast_if_present<OpTy>(op->getParentOp ());
3863
+
3864
+ return op->getParentOfType <OpTy>();
3865
+ }
3866
+
3867
+ // / Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
3868
+ // / values as stated by the corresponding clauses, if constant.
3869
+ // /
3870
+ // / These default values must be set before the creation of the outlined LLVM
3871
+ // / function for the target region, so that they can be used to initialize the
3872
+ // / corresponding global `ConfigurationEnvironmentTy` structure.
3873
+ static void
3874
+ initTargetDefaultAttrs (omp::TargetOp targetOp,
3875
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
3876
+ bool isTargetDevice) {
3877
+ // TODO: Handle constant 'if' clauses.
3878
+ Operation *capturedOp = targetOp.getInnermostCapturedOmpOp ();
3879
+
3880
+ Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
3881
+ if (!isTargetDevice) {
3882
+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
3883
+ threadLimit);
3884
+ } else {
3885
+ // In the target device, values for these clauses are not passed as
3886
+ // host_eval, but instead evaluated prior to entry to the region. This
3887
+ // ensures values are mapped and available inside of the target region.
3888
+ if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
3889
+ numTeamsLower = teamsOp.getNumTeamsLower ();
3890
+ numTeamsUpper = teamsOp.getNumTeamsUpper ();
3891
+ threadLimit = teamsOp.getThreadLimit ();
3892
+ }
3893
+
3894
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
3895
+ numThreads = parallelOp.getNumThreads ();
3896
+ }
3897
+
3898
+ auto extractConstInteger = [](Value value) -> std::optional<int64_t > {
3899
+ if (auto constOp =
3900
+ dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp ()))
3901
+ if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue ()))
3902
+ return constAttr.getInt ();
3903
+
3904
+ return std::nullopt;
3905
+ };
3906
+
3907
+ // Handle clauses impacting the number of teams.
3908
+
3909
+ int32_t minTeamsVal = 1 , maxTeamsVal = -1 ;
3910
+ if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
3911
+ // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
3912
+ // clang and set min and max to the same value.
3913
+ if (numTeamsUpper) {
3914
+ if (auto val = extractConstInteger (numTeamsUpper))
3915
+ minTeamsVal = maxTeamsVal = *val;
3916
+ } else {
3917
+ minTeamsVal = maxTeamsVal = 0 ;
3918
+ }
3919
+ } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
3920
+ /* immediateParent=*/ true ) ||
3921
+ castOrGetParentOfType<omp::SimdOp>(capturedOp,
3922
+ /* immediateParent=*/ true )) {
3923
+ minTeamsVal = maxTeamsVal = 1 ;
3924
+ } else {
3925
+ minTeamsVal = maxTeamsVal = -1 ;
3926
+ }
3927
+
3928
+ // Handle clauses impacting the number of threads.
3929
+
3930
+ auto setMaxValueFromClause = [&extractConstInteger](Value clauseValue,
3931
+ int32_t &result) {
3932
+ if (!clauseValue)
3933
+ return ;
3934
+
3935
+ if (auto val = extractConstInteger (clauseValue))
3936
+ result = *val;
3937
+
3938
+ // Found an applicable clause, so it's not undefined. Mark as unknown
3939
+ // because it's not constant.
3940
+ if (result < 0 )
3941
+ result = 0 ;
3942
+ };
3943
+
3944
+ // Extract 'thread_limit' clause from 'target' and 'teams' directives.
3945
+ int32_t targetThreadLimitVal = -1 , teamsThreadLimitVal = -1 ;
3946
+ setMaxValueFromClause (targetOp.getThreadLimit (), targetThreadLimitVal);
3947
+ setMaxValueFromClause (threadLimit, teamsThreadLimitVal);
3948
+
3949
+ // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
3950
+ int32_t maxThreadsVal = -1 ;
3951
+ if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
3952
+ setMaxValueFromClause (numThreads, maxThreadsVal);
3953
+ else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
3954
+ /* immediateParent=*/ true ))
3955
+ maxThreadsVal = 1 ;
3956
+
3957
+ // For max values, < 0 means unset, == 0 means set but unknown. Select the
3958
+ // minimum value between 'max_threads' and 'thread_limit' clauses that were
3959
+ // set.
3960
+ int32_t combinedMaxThreadsVal = targetThreadLimitVal;
3961
+ if (combinedMaxThreadsVal < 0 ||
3962
+ (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
3963
+ combinedMaxThreadsVal = teamsThreadLimitVal;
3964
+
3965
+ if (combinedMaxThreadsVal < 0 ||
3966
+ (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
3967
+ combinedMaxThreadsVal = maxThreadsVal;
3968
+
3969
+ // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
3970
+ attrs.MinTeams = minTeamsVal;
3971
+ attrs.MaxTeams .front () = maxTeamsVal;
3972
+ attrs.MinThreads = 1 ;
3973
+ attrs.MaxThreads .front () = combinedMaxThreadsVal;
3974
+ }
3975
+
3976
+ // / Gather LLVM runtime values for all clauses evaluated in the host that are
3977
+ // / passed to the kernel invocation.
3978
+ // /
3979
+ // / This function must be called only when compiling for the host. Also, it will
3980
+ // / only provide correct results if it's called after the body of \c targetOp
3981
+ // / has been fully generated.
3982
+ static void
3983
+ initTargetRuntimeAttrs (llvm::IRBuilderBase &builder,
3984
+ LLVM::ModuleTranslation &moduleTranslation,
3985
+ omp::TargetOp targetOp,
3986
+ llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
3987
+ Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
3988
+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
3989
+ teamsThreadLimit);
3990
+
3991
+ // TODO: Handle constant 'if' clauses.
3992
+ if (Value targetThreadLimit = targetOp.getThreadLimit ())
3993
+ attrs.TargetThreadLimit .front () =
3994
+ moduleTranslation.lookupValue (targetThreadLimit);
3995
+
3996
+ if (numTeamsLower)
3997
+ attrs.MinTeams = moduleTranslation.lookupValue (numTeamsLower);
3998
+
3999
+ if (numTeamsUpper)
4000
+ attrs.MaxTeams .front () = moduleTranslation.lookupValue (numTeamsUpper);
4001
+
4002
+ if (teamsThreadLimit)
4003
+ attrs.TeamsThreadLimit .front () =
4004
+ moduleTranslation.lookupValue (teamsThreadLimit);
4005
+
4006
+ if (numThreads)
4007
+ attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
4008
+
4009
+ // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4010
+ }
4011
+
3803
4012
static LogicalResult
3804
4013
convertOmpTarget (Operation &opInst, llvm::IRBuilderBase &builder,
3805
4014
LLVM::ModuleTranslation &moduleTranslation) {
@@ -3809,12 +4018,13 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3809
4018
3810
4019
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3811
4020
bool isTargetDevice = ompBuilder->Config .isTargetDevice ();
4021
+
3812
4022
auto parentFn = opInst.getParentOfType <LLVM::LLVMFuncOp>();
4023
+ auto blockIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
3813
4024
auto &targetRegion = targetOp.getRegion ();
3814
4025
DataLayout dl = DataLayout (opInst.getParentOfType <ModuleOp>());
3815
4026
SmallVector<Value> mapVars = targetOp.getMapVars ();
3816
- ArrayRef<BlockArgument> mapBlockArgs =
3817
- cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs ();
4027
+ ArrayRef<BlockArgument> mapBlockArgs = blockIface.getMapBlockArgs ();
3818
4028
llvm::Function *llvmOutlinedFn = nullptr ;
3819
4029
3820
4030
// TODO: It can also be false if a compile-time constant `false` IF clause is
@@ -3857,7 +4067,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3857
4067
OperandRange privateVars = targetOp.getPrivateVars ();
3858
4068
std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms ();
3859
4069
MutableArrayRef<BlockArgument> privateBlockArgs =
3860
- cast<omp::BlockArgOpenMPOpInterface>(opInst) .getPrivateBlockArgs ();
4070
+ blockIface .getPrivateBlockArgs ();
3861
4071
3862
4072
for (auto [privVar, privatizerNameAttr, privBlockArg] :
3863
4073
llvm::zip_equal (privateVars, *privateSyms, privateBlockArgs)) {
@@ -3936,13 +4146,29 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
3936
4146
allocaIP, codeGenIP);
3937
4147
};
3938
4148
3939
- // TODO: Populate default and runtime attributes based on the construct and
3940
- // clauses.
3941
- llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
3942
- /* MaxTeams=*/ {-1 }, /* MinTeams=*/ 0 , /* MaxThreads=*/ {0 }, /* MinThreads=*/ 0 };
4149
+ llvm::SmallVector<llvm::Value *, 4 > kernelInput;
4150
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
4151
+ initTargetDefaultAttrs (targetOp, defaultAttrs, isTargetDevice);
4152
+
4153
+ // Collect host-evaluated values needed to properly launch the kernel from the
4154
+ // host.
3943
4155
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
4156
+ if (!isTargetDevice)
4157
+ initTargetRuntimeAttrs (builder, moduleTranslation, targetOp, runtimeAttrs);
4158
+
4159
+ // Pass host-evaluated values as parameters to the kernel / host fallback,
4160
+ // except if they are constants. In any case, map the MLIR block argument to
4161
+ // the corresponding LLVM values.
4162
+ SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars ();
4163
+ ArrayRef<BlockArgument> hostEvalBlockArgs = blockIface.getHostEvalBlockArgs ();
4164
+ for (auto [arg, var] : llvm::zip_equal (hostEvalBlockArgs, hostEvalVars)) {
4165
+ llvm::Value *value = moduleTranslation.lookupValue (var);
4166
+ moduleTranslation.mapValue (arg, value);
4167
+
4168
+ if (!llvm::isa<llvm::Constant>(value))
4169
+ kernelInput.push_back (value);
4170
+ }
3944
4171
3945
- llvm::SmallVector<llvm::Value *, 4 > kernelInput;
3946
4172
for (size_t i = 0 ; i < mapVars.size (); ++i) {
3947
4173
// declare target arguments are not passed to kernels as arguments
3948
4174
// TODO: We currently do not handle cases where a member is explicitly
0 commit comments