@@ -176,15 +176,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
176
176
if (op.getHint ())
177
177
op.emitWarning (" hint clause discarded" );
178
178
};
179
- auto checkHostEval = [](auto op, LogicalResult &result) {
180
- // Host evaluated clauses are supported, except for loop bounds.
181
- for (BlockArgument arg :
182
- cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
183
- for (Operation *user : arg.getUsers ())
184
- if (isa<omp::LoopNestOp>(user))
185
- result = op.emitError (" not yet implemented: host evaluation of loop "
186
- " bounds in omp.target operation" );
187
- };
188
179
auto checkInReduction = [&todo](auto op, LogicalResult &result) {
189
180
if (!op.getInReductionVars ().empty () || op.getInReductionByref () ||
190
181
op.getInReductionSyms ())
@@ -321,7 +312,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
321
312
checkBare (op, result);
322
313
checkDevice (op, result);
323
314
checkHasDeviceAddr (op, result);
324
- checkHostEval (op, result);
325
315
checkInReduction (op, result);
326
316
checkIsDevicePtr (op, result);
327
317
checkPrivate (op, result);
@@ -4054,9 +4044,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
4054
4044
///
4055
4045
/// Loop bounds and steps are populated only when the corresponding output
4056
4046
/// vectors are provided (non-null).
4057
- static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4058
- Value &numTeamsLower, Value &numTeamsUpper,
4059
- Value &threadLimit) {
4047
+ static void
4048
+ extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4049
+ Value &numTeamsLower, Value &numTeamsUpper,
4050
+ Value &threadLimit,
4051
+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr ,
4052
+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr ,
4053
+ llvm::SmallVectorImpl<Value> *steps = nullptr ) {
4060
4054
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
4061
4055
for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
4062
4056
blockArgIface.getHostEvalBlockArgs ())) {
@@ -4081,11 +4075,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
4081
4075
llvm_unreachable (" unsupported host_eval use" );
4082
4076
})
4083
4077
.Case ([&](omp::LoopNestOp loopOp) {
4084
- // TODO: Extract bounds and step values. Currently, this cannot be
4085
- // reached because translation would have been stopped earlier as a
4086
- // result of `checkImplementationStatus` detecting and reporting
4087
- // this situation.
4088
- llvm_unreachable (" unsupported host_eval use" );
4078
+ auto processBounds = // Checks one operand list of the loop nest for uses of `blockArg`.
4079
+ [&](OperandRange opBounds,
4080
+ llvm::SmallVectorImpl<Value> *outBounds) -> bool { // `outBounds` may be null; then only the match result is reported.
4081
+ bool found = false ;
4082
+ for (auto [i, lb] : llvm::enumerate (opBounds)) { // No early exit: every matching position is visited.
4083
+ if (lb == blockArg) { // NOTE(review): `blockArg` is captured from the enclosing scope (not visible here); presumably the current host_eval block argument - confirm.
4084
+ found = true ;
4085
+ if (outBounds)
4086
+ (*outBounds)[i] = hostEvalVar; // Indexed store, not push_back: assumes the caller pre-sized `*outBounds` to at least the operand count - TODO confirm.
4087
+ }
4088
+ }
4089
+ return found; // True if `blockArg` matched at least one operand in `opBounds`.
4090
+ };
4091
+ bool found =
4092
+ processBounds (loopOp.getLoopLowerBounds (), lowerBounds);
4093
+ found = processBounds (loopOp.getLoopUpperBounds (), upperBounds) ||
4094
+ found;
4095
+ found = processBounds (loopOp.getLoopSteps (), steps) || found;
4096
+ if (!found)
4097
+ llvm_unreachable (" unsupported host_eval use" );
4089
4098
})
4090
4099
.Default ([](Operation *) {
4091
4100
llvm_unreachable (" unsupported host_eval use" );
@@ -4222,6 +4231,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
4222
4231
combinedMaxThreadsVal = maxThreadsVal;
4223
4232
4224
4233
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4234
+ attrs.ExecFlags = targetOp.getKernelExecFlags ();
4225
4235
attrs.MinTeams = minTeamsVal;
4226
4236
attrs.MaxTeams .front () = maxTeamsVal;
4227
4237
attrs.MinThreads = 1 ;
@@ -4239,9 +4249,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4239
4249
LLVM::ModuleTranslation &moduleTranslation,
4240
4250
omp::TargetOp targetOp,
4241
4251
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4252
+ omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4253
+ targetOp.getInnermostCapturedOmpOp ());
4254
+ unsigned numLoops = loopOp ? loopOp.getNumLoops () : 0 ;
4255
+
4242
4256
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4257
+ llvm::SmallVector<Value> lowerBounds (numLoops), upperBounds (numLoops),
4258
+ steps (numLoops);
4243
4259
extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
4244
- teamsThreadLimit);
4260
+ teamsThreadLimit, &lowerBounds, &upperBounds, &steps );
4245
4261
4246
4262
// TODO: Handle constant 'if' clauses.
4247
4263
if (Value targetThreadLimit = targetOp.getThreadLimit ())
@@ -4261,7 +4277,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4261
4277
if (numThreads)
4262
4278
attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
4263
4279
4264
- // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4280
+ if (targetOp.getKernelExecFlags () != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4281
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
4282
+ attrs.LoopTripCount = nullptr ;
4283
+
4284
+ // To calculate the trip count, we multiply together the trip counts of
4285
+ // every collapsed canonical loop. We don't need to create the loop nests
4286
+ // here, since we're only interested in the trip count.
4287
+ for (auto [loopLower, loopUpper, loopStep] :
4288
+ llvm::zip_equal (lowerBounds, upperBounds, steps)) {
4289
+ llvm::Value *lowerBound = moduleTranslation.lookupValue (loopLower);
4290
+ llvm::Value *upperBound = moduleTranslation.lookupValue (loopUpper);
4291
+ llvm::Value *step = moduleTranslation.lookupValue (loopStep);
4292
+
4293
+ llvm::OpenMPIRBuilder::LocationDescription loc (builder);
4294
+ llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount (
4295
+ loc, lowerBound, upperBound, step, /* IsSigned=*/ true ,
4296
+ loopOp.getLoopInclusive ());
4297
+
4298
+ if (!attrs.LoopTripCount ) {
4299
+ attrs.LoopTripCount = tripCount;
4300
+ continue ;
4301
+ }
4302
+
4303
+ // TODO: Enable UndefinedBehaviorSanitizer to diagnose an overflow here.
4304
+ attrs.LoopTripCount = builder.CreateMul (attrs.LoopTripCount , tripCount,
4305
+ {}, /* HasNUW=*/ true );
4306
+ }
4307
+ }
4265
4308
}
4266
4309
4267
4310
static LogicalResult
0 commit comments