@@ -173,15 +173,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
173
173
if (op.getHint ())
174
174
op.emitWarning (" hint clause discarded" );
175
175
};
176
- auto checkHostEval = [](auto op, LogicalResult &result) {
177
- // Host evaluated clauses are supported, except for loop bounds.
178
- for (BlockArgument arg :
179
- cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
180
- for (Operation *user : arg.getUsers ())
181
- if (isa<omp::LoopNestOp>(user))
182
- result = op.emitError (" not yet implemented: host evaluation of loop "
183
- " bounds in omp.target operation" );
184
- };
185
176
auto checkInReduction = [&todo](auto op, LogicalResult &result) {
186
177
if (!op.getInReductionVars ().empty () || op.getInReductionByref () ||
187
178
op.getInReductionSyms ())
@@ -318,7 +309,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
318
309
checkBare (op, result);
319
310
checkDevice (op, result);
320
311
checkHasDeviceAddr (op, result);
321
- checkHostEval (op, result);
322
312
checkInReduction (op, result);
323
313
checkIsDevicePtr (op, result);
324
314
checkPrivate (op, result);
@@ -4058,9 +4048,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
4058
4048
// /
4059
4049
// / Loop bounds and steps are only optionally populated, if output vectors are
4060
4050
// / provided.
4061
- static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4062
- Value &numTeamsLower, Value &numTeamsUpper,
4063
- Value &threadLimit) {
4051
+ static void
4052
+ extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4053
+ Value &numTeamsLower, Value &numTeamsUpper,
4054
+ Value &threadLimit,
4055
+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr ,
4056
+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr ,
4057
+ llvm::SmallVectorImpl<Value> *steps = nullptr ) {
4064
4058
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
4065
4059
for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
4066
4060
blockArgIface.getHostEvalBlockArgs ())) {
@@ -4085,11 +4079,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
4085
4079
llvm_unreachable (" unsupported host_eval use" );
4086
4080
})
4087
4081
.Case ([&](omp::LoopNestOp loopOp) {
4088
- // TODO: Extract bounds and step values. Currently, this cannot be
4089
- // reached because translation would have been stopped earlier as a
4090
- // result of `checkImplementationStatus` detecting and reporting
4091
- // this situation.
4092
- llvm_unreachable (" unsupported host_eval use" );
4082
+ auto processBounds =
4083
+ [&](OperandRange opBounds,
4084
+ llvm::SmallVectorImpl<Value> *outBounds) -> bool {
4085
+ bool found = false ;
4086
+ for (auto [i, lb] : llvm::enumerate (opBounds)) {
4087
+ if (lb == blockArg) {
4088
+ found = true ;
4089
+ if (outBounds)
4090
+ (*outBounds)[i] = hostEvalVar;
4091
+ }
4092
+ }
4093
+ return found;
4094
+ };
4095
+ bool found =
4096
+ processBounds (loopOp.getLoopLowerBounds (), lowerBounds);
4097
+ found = processBounds (loopOp.getLoopUpperBounds (), upperBounds) ||
4098
+ found;
4099
+ found = processBounds (loopOp.getLoopSteps (), steps) || found;
4100
+ if (!found)
4101
+ llvm_unreachable (" unsupported host_eval use" );
4093
4102
})
4094
4103
.Default ([](Operation *) {
4095
4104
llvm_unreachable (" unsupported host_eval use" );
@@ -4226,6 +4235,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
4226
4235
combinedMaxThreadsVal = maxThreadsVal;
4227
4236
4228
4237
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4238
+ attrs.ExecFlags = targetOp.getKernelExecFlags ();
4229
4239
attrs.MinTeams = minTeamsVal;
4230
4240
attrs.MaxTeams .front () = maxTeamsVal;
4231
4241
attrs.MinThreads = 1 ;
@@ -4243,9 +4253,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4243
4253
LLVM::ModuleTranslation &moduleTranslation,
4244
4254
omp::TargetOp targetOp,
4245
4255
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4256
+ omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4257
+ targetOp.getInnermostCapturedOmpOp ());
4258
+ unsigned numLoops = loopOp ? loopOp.getNumLoops () : 0 ;
4259
+
4246
4260
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4261
+ llvm::SmallVector<Value> lowerBounds (numLoops), upperBounds (numLoops),
4262
+ steps (numLoops);
4247
4263
extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
4248
- teamsThreadLimit);
4264
+ teamsThreadLimit, &lowerBounds, &upperBounds, &steps );
4249
4265
4250
4266
// TODO: Handle constant 'if' clauses.
4251
4267
if (Value targetThreadLimit = targetOp.getThreadLimit ())
@@ -4265,7 +4281,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4265
4281
if (numThreads)
4266
4282
attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
4267
4283
4268
- // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4284
+ if (targetOp.getKernelExecFlags () != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4285
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
4286
+ attrs.LoopTripCount = nullptr ;
4287
+
4288
+ // To calculate the trip count, we multiply together the trip counts of
4289
+ // every collapsed canonical loop. We don't need to create the loop nests
4290
+ // here, since we're only interested in the trip count.
4291
+ for (auto [loopLower, loopUpper, loopStep] :
4292
+ llvm::zip_equal (lowerBounds, upperBounds, steps)) {
4293
+ llvm::Value *lowerBound = moduleTranslation.lookupValue (loopLower);
4294
+ llvm::Value *upperBound = moduleTranslation.lookupValue (loopUpper);
4295
+ llvm::Value *step = moduleTranslation.lookupValue (loopStep);
4296
+
4297
+ llvm::OpenMPIRBuilder::LocationDescription loc (builder);
4298
+ llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount (
4299
+ loc, lowerBound, upperBound, step, /* IsSigned=*/ true ,
4300
+ loopOp.getLoopInclusive ());
4301
+
4302
+ if (!attrs.LoopTripCount ) {
4303
+ attrs.LoopTripCount = tripCount;
4304
+ continue ;
4305
+ }
4306
+
4307
+ // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
4308
+ attrs.LoopTripCount = builder.CreateMul (attrs.LoopTripCount , tripCount,
4309
+ {}, /* HasNUW=*/ true );
4310
+ }
4311
+ }
4269
4312
}
4270
4313
4271
4314
static LogicalResult
0 commit comments