@@ -173,15 +173,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
173
173
if (op.getHint ())
174
174
op.emitWarning (" hint clause discarded" );
175
175
};
176
- auto checkHostEval = [](auto op, LogicalResult &result) {
177
- // Host evaluated clauses are supported, except for loop bounds.
178
- for (BlockArgument arg :
179
- cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
180
- for (Operation *user : arg.getUsers ())
181
- if (isa<omp::LoopNestOp>(user))
182
- result = op.emitError (" not yet implemented: host evaluation of loop "
183
- " bounds in omp.target operation" );
184
- };
185
176
auto checkInReduction = [&todo](auto op, LogicalResult &result) {
186
177
if (!op.getInReductionVars ().empty () || op.getInReductionByref () ||
187
178
op.getInReductionSyms ())
@@ -318,7 +309,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
318
309
checkBare (op, result);
319
310
checkDevice (op, result);
320
311
checkHasDeviceAddr (op, result);
321
- checkHostEval (op, result);
322
312
checkInReduction (op, result);
323
313
checkIsDevicePtr (op, result);
324
314
checkPrivate (op, result);
@@ -4158,9 +4148,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
4158
4148
///
4159
4149
/// Loop bounds and steps are only optionally populated, if output vectors are
4160
4150
/// provided.
4161
- static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4162
- Value &numTeamsLower, Value &numTeamsUpper,
4163
- Value &threadLimit) {
4151
+ static void
4152
+ extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
4153
+ Value &numTeamsLower, Value &numTeamsUpper,
4154
+ Value &threadLimit,
4155
+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr ,
4156
+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr ,
4157
+ llvm::SmallVectorImpl<Value> *steps = nullptr ) {
4164
4158
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
4165
4159
for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
4166
4160
blockArgIface.getHostEvalBlockArgs ())) {
@@ -4185,11 +4179,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
4185
4179
llvm_unreachable (" unsupported host_eval use" );
4186
4180
})
4187
4181
.Case ([&](omp::LoopNestOp loopOp) {
4188
- // TODO: Extract bounds and step values. Currently, this cannot be
4189
- // reached because translation would have been stopped earlier as a
4190
- // result of `checkImplementationStatus` detecting and reporting
4191
- // this situation.
4192
- llvm_unreachable (" unsupported host_eval use" );
4182
+ auto processBounds =
4183
+ [&](OperandRange opBounds,
4184
+ llvm::SmallVectorImpl<Value> *outBounds) -> bool {
4185
+ bool found = false ;
4186
+ for (auto [i, lb] : llvm::enumerate (opBounds)) {
4187
+ if (lb == blockArg) {
4188
+ found = true ;
4189
+ if (outBounds)
4190
+ (*outBounds)[i] = hostEvalVar;
4191
+ }
4192
+ }
4193
+ return found;
4194
+ };
4195
+ bool found =
4196
+ processBounds (loopOp.getLoopLowerBounds (), lowerBounds);
4197
+ found = processBounds (loopOp.getLoopUpperBounds (), upperBounds) ||
4198
+ found;
4199
+ found = processBounds (loopOp.getLoopSteps (), steps) || found;
4200
+ (void )found;
4201
+ assert (found && " unsupported host_eval use" );
4193
4202
})
4194
4203
.Default ([](Operation *) {
4195
4204
llvm_unreachable (" unsupported host_eval use" );
@@ -4326,6 +4335,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
4326
4335
combinedMaxThreadsVal = maxThreadsVal;
4327
4336
4328
4337
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4338
+ attrs.ExecFlags = targetOp.getKernelExecFlags ();
4329
4339
attrs.MinTeams = minTeamsVal;
4330
4340
attrs.MaxTeams .front () = maxTeamsVal;
4331
4341
attrs.MinThreads = 1 ;
@@ -4343,9 +4353,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4343
4353
LLVM::ModuleTranslation &moduleTranslation,
4344
4354
omp::TargetOp targetOp,
4345
4355
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4356
+ omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4357
+ targetOp.getInnermostCapturedOmpOp ());
4358
+ unsigned numLoops = loopOp ? loopOp.getNumLoops () : 0 ;
4359
+
4346
4360
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4361
+ llvm::SmallVector<Value> lowerBounds (numLoops), upperBounds (numLoops),
4362
+ steps (numLoops);
4347
4363
extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
4348
- teamsThreadLimit);
4364
+ teamsThreadLimit, &lowerBounds, &upperBounds, &steps );
4349
4365
4350
4366
// TODO: Handle constant 'if' clauses.
4351
4367
if (Value targetThreadLimit = targetOp.getThreadLimit ())
@@ -4365,7 +4381,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4365
4381
if (numThreads)
4366
4382
attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
4367
4383
4368
- // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4384
+ if (targetOp.getKernelExecFlags () != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4385
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
4386
+ attrs.LoopTripCount = nullptr ;
4387
+
4388
+ // To calculate the trip count, we multiply together the trip counts of
4389
+ // every collapsed canonical loop. We don't need to create the loop nests
4390
+ // here, since we're only interested in the trip count.
4391
+ for (auto [loopLower, loopUpper, loopStep] :
4392
+ llvm::zip_equal (lowerBounds, upperBounds, steps)) {
4393
+ llvm::Value *lowerBound = moduleTranslation.lookupValue (loopLower);
4394
+ llvm::Value *upperBound = moduleTranslation.lookupValue (loopUpper);
4395
+ llvm::Value *step = moduleTranslation.lookupValue (loopStep);
4396
+
4397
+ llvm::OpenMPIRBuilder::LocationDescription loc (builder);
4398
+ llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount (
4399
+ loc, lowerBound, upperBound, step, /* IsSigned=*/ true ,
4400
+ loopOp.getLoopInclusive ());
4401
+
4402
+ if (!attrs.LoopTripCount ) {
4403
+ attrs.LoopTripCount = tripCount;
4404
+ continue ;
4405
+ }
4406
+
4407
+ // TODO: Enable UndefinedBehaviorSanitizer to diagnose an overflow here.
4408
+ attrs.LoopTripCount = builder.CreateMul (attrs.LoopTripCount , tripCount,
4409
+ {}, /* HasNUW=*/ true );
4410
+ }
4411
+ }
4369
4412
}
4370
4413
4371
4414
static LogicalResult
0 commit comments