@@ -542,6 +542,18 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
542
542
llvm_unreachable (" Unknown ClauseProcBindKind kind" );
543
543
}
544
544
545
+ // / Maps elements of \p blockArgs (which are MLIR values) to the corresponding
546
+ // / LLVM values of \p operands' elements. This is useful when an OpenMP region
547
+ // / with entry block arguments is converted to LLVM. In this case \p blockArgs
548
+ // / are (part of) of the OpenMP region's entry arguments and \p operands are
549
+ // / (part of) of the operands to the OpenMP op containing the region.
550
+ static void forwardArgs (LLVM::ModuleTranslation &moduleTranslation,
551
+ llvm::ArrayRef<BlockArgument> blockArgs,
552
+ OperandRange operands) {
553
+ for (auto [arg, var] : llvm::zip_equal (blockArgs, operands))
554
+ moduleTranslation.mapValue (arg, moduleTranslation.lookupValue (var));
555
+ }
556
+
545
557
// / Helper function to map block arguments defined by ignored loop wrappers to
546
558
// / LLVM values and prevent any uses of those from triggering null pointer
547
559
// / dereferences.
@@ -554,18 +566,12 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
554
566
// Map block arguments directly to the LLVM value associated to the
555
567
// corresponding operand. This is semantically equivalent to this wrapper not
556
568
// being present.
557
- auto forwardArgs =
558
- [&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
559
- OperandRange operands) {
560
- for (auto [arg, var] : llvm::zip_equal (blockArgs, operands))
561
- moduleTranslation.mapValue (arg, moduleTranslation.lookupValue (var));
562
- };
563
-
564
569
return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
565
570
.Case ([&](omp::SimdOp op) {
566
571
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
567
- forwardArgs (blockArgIface.getPrivateBlockArgs (), op.getPrivateVars ());
568
- forwardArgs (blockArgIface.getReductionBlockArgs (),
572
+ forwardArgs (moduleTranslation, blockArgIface.getPrivateBlockArgs (),
573
+ op.getPrivateVars ());
574
+ forwardArgs (moduleTranslation, blockArgIface.getReductionBlockArgs (),
569
575
op.getReductionVars ());
570
576
op.emitWarning () << " simd information on composite construct discarded" ;
571
577
return success ();
@@ -5236,6 +5242,28 @@ convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
5236
5242
return convertHostOrTargetOperation (op, builder, moduleTranslation);
5237
5243
}
5238
5244
5245
+ // / Forwards private entry block arguments, \see forwardArgs for more details.
5246
+ template <typename OMPOp>
5247
+ static void forwardPrivateArgs (OMPOp ompOp,
5248
+ LLVM::ModuleTranslation &moduleTranslation) {
5249
+ auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*ompOp);
5250
+ if (blockArgIface) {
5251
+ forwardArgs (moduleTranslation, blockArgIface.getPrivateBlockArgs (),
5252
+ ompOp.getPrivateVars ());
5253
+ }
5254
+ }
5255
+
5256
+ // / Forwards reduction entry block arguments, \see forwardArgs for more details.
5257
+ template <typename OMPOp>
5258
+ static void forwardReductionArgs (OMPOp ompOp,
5259
+ LLVM::ModuleTranslation &moduleTranslation) {
5260
+ auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*ompOp);
5261
+ if (blockArgIface) {
5262
+ forwardArgs (moduleTranslation, blockArgIface.getReductionBlockArgs (),
5263
+ ompOp.getReductionVars ());
5264
+ }
5265
+ }
5266
+
5239
5267
static LogicalResult
5240
5268
convertTargetOpsInNest (Operation *op, llvm::IRBuilderBase &builder,
5241
5269
LLVM::ModuleTranslation &moduleTranslation) {
@@ -5255,6 +5283,51 @@ convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
5255
5283
return WalkResult::interrupt ();
5256
5284
return WalkResult::skip ();
5257
5285
}
5286
+
5287
+ // Non-target ops might nest target-related ops, therefore, we
5288
+ // translate them as non-OpenMP scopes. Translating them is needed by
5289
+ // nested target-related ops since they might LLVM values defined in
5290
+ // their parent non-target ops.
5291
+ if (isa<omp::OpenMPDialect>(oper->getDialect ()) &&
5292
+ oper->getParentOfType <LLVM::LLVMFuncOp>() &&
5293
+ !oper->getRegions ().empty ()) {
5294
+
5295
+ // TODO Handle other ops with entry block args.
5296
+ llvm::TypeSwitch<Operation &>(*oper)
5297
+ .Case ([&](omp::WsloopOp wsloopOp) {
5298
+ forwardPrivateArgs (wsloopOp, moduleTranslation);
5299
+ forwardReductionArgs (wsloopOp, moduleTranslation);
5300
+ })
5301
+ .Case ([&](omp::ParallelOp parallelOp) {
5302
+ forwardPrivateArgs (parallelOp, moduleTranslation);
5303
+ forwardReductionArgs (parallelOp, moduleTranslation);
5304
+ })
5305
+ .Case ([&](omp::TaskOp taskOp) {
5306
+ forwardPrivateArgs (taskOp, moduleTranslation);
5307
+ });
5308
+
5309
+ if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
5310
+ for (auto iv : loopNest.getIVs ()) {
5311
+ // Create fake allocas just to maintain IR validity.
5312
+ moduleTranslation.mapValue (
5313
+ iv, builder.CreateAlloca (
5314
+ moduleTranslation.convertType (iv.getType ())));
5315
+ }
5316
+ }
5317
+
5318
+ for (Region ®ion : oper->getRegions ()) {
5319
+ auto result = convertOmpOpRegions (
5320
+ region, oper->getName ().getStringRef ().str () + " .fake.region" ,
5321
+ builder, moduleTranslation);
5322
+ if (failed (handleError (result, *oper)))
5323
+ return WalkResult::interrupt ();
5324
+
5325
+ builder.SetInsertPoint (result.get (), result.get ()->end ());
5326
+ }
5327
+
5328
+ return WalkResult::skip ();
5329
+ }
5330
+
5258
5331
return WalkResult::advance ();
5259
5332
}).wasInterrupted ();
5260
5333
return failure (interrupted);
0 commit comments