@@ -537,6 +537,18 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
537
537
llvm_unreachable (" Unknown ClauseProcBindKind kind" );
538
538
}
539
539
540
+ // / Maps elements of \p blockArgs (which are MLIR values) to the corresponding
541
+ // / LLVM values of \p operands' elements. This is useful when an OpenMP region
542
+ // / with entry block arguments is converted to LLVM. In this case \p blockArgs
543
+ // / are (part of) of the OpenMP region's entry arguments and \p operands are
544
+ // / (part of) of the operands to the OpenMP op containing the region.
545
+ static void forwardArgs (LLVM::ModuleTranslation &moduleTranslation,
546
+ llvm::ArrayRef<BlockArgument> blockArgs,
547
+ OperandRange operands) {
548
+ for (auto [arg, var] : llvm::zip_equal (blockArgs, operands))
549
+ moduleTranslation.mapValue (arg, moduleTranslation.lookupValue (var));
550
+ }
551
+
540
552
// / Helper function to map block arguments defined by ignored loop wrappers to
541
553
// / LLVM values and prevent any uses of those from triggering null pointer
542
554
// / dereferences.
@@ -549,18 +561,12 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
549
561
// Map block arguments directly to the LLVM value associated to the
550
562
// corresponding operand. This is semantically equivalent to this wrapper not
551
563
// being present.
552
- auto forwardArgs =
553
- [&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
554
- OperandRange operands) {
555
- for (auto [arg, var] : llvm::zip_equal (blockArgs, operands))
556
- moduleTranslation.mapValue (arg, moduleTranslation.lookupValue (var));
557
- };
558
-
559
564
return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
560
565
.Case ([&](omp::SimdOp op) {
561
566
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
562
- forwardArgs (blockArgIface.getPrivateBlockArgs (), op.getPrivateVars ());
563
- forwardArgs (blockArgIface.getReductionBlockArgs (),
567
+ forwardArgs (moduleTranslation, blockArgIface.getPrivateBlockArgs (),
568
+ op.getPrivateVars ());
569
+ forwardArgs (moduleTranslation, blockArgIface.getReductionBlockArgs (),
564
570
op.getReductionVars ());
565
571
op.emitWarning () << " simd information on composite construct discarded" ;
566
572
return success ();
@@ -5296,6 +5302,28 @@ convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
5296
5302
return convertHostOrTargetOperation (op, builder, moduleTranslation);
5297
5303
}
5298
5304
5305
+ // / Forwards private entry block arguments, \see forwardArgs for more details.
5306
+ template <typename OMPOp>
5307
+ static void forwardPrivateArgs (OMPOp ompOp,
5308
+ LLVM::ModuleTranslation &moduleTranslation) {
5309
+ auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*ompOp);
5310
+ if (blockArgIface) {
5311
+ forwardArgs (moduleTranslation, blockArgIface.getPrivateBlockArgs (),
5312
+ ompOp.getPrivateVars ());
5313
+ }
5314
+ }
5315
+
5316
+ // / Forwards reduction entry block arguments, \see forwardArgs for more details.
5317
+ template <typename OMPOp>
5318
+ static void forwardReductionArgs (OMPOp ompOp,
5319
+ LLVM::ModuleTranslation &moduleTranslation) {
5320
+ auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*ompOp);
5321
+ if (blockArgIface) {
5322
+ forwardArgs (moduleTranslation, blockArgIface.getReductionBlockArgs (),
5323
+ ompOp.getReductionVars ());
5324
+ }
5325
+ }
5326
+
5299
5327
static LogicalResult
5300
5328
convertTargetOpsInNest (Operation *op, llvm::IRBuilderBase &builder,
5301
5329
LLVM::ModuleTranslation &moduleTranslation) {
@@ -5315,6 +5343,51 @@ convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
5315
5343
return WalkResult::interrupt ();
5316
5344
return WalkResult::skip ();
5317
5345
}
5346
+
5347
+ // Non-target ops might nest target-related ops, therefore, we
5348
+ // translate them as non-OpenMP scopes. Translating them is needed by
5349
+ // nested target-related ops since they might LLVM values defined in
5350
+ // their parent non-target ops.
5351
+ if (isa<omp::OpenMPDialect>(oper->getDialect ()) &&
5352
+ oper->getParentOfType <LLVM::LLVMFuncOp>() &&
5353
+ !oper->getRegions ().empty ()) {
5354
+
5355
+ // TODO Handle other ops with entry block args.
5356
+ llvm::TypeSwitch<Operation &>(*oper)
5357
+ .Case ([&](omp::WsloopOp wsloopOp) {
5358
+ forwardPrivateArgs (wsloopOp, moduleTranslation);
5359
+ forwardReductionArgs (wsloopOp, moduleTranslation);
5360
+ })
5361
+ .Case ([&](omp::ParallelOp parallelOp) {
5362
+ forwardPrivateArgs (parallelOp, moduleTranslation);
5363
+ forwardReductionArgs (parallelOp, moduleTranslation);
5364
+ })
5365
+ .Case ([&](omp::TaskOp taskOp) {
5366
+ forwardPrivateArgs (taskOp, moduleTranslation);
5367
+ });
5368
+
5369
+ if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
5370
+ for (auto iv : loopNest.getIVs ()) {
5371
+ // Create fake allocas just to maintain IR validity.
5372
+ moduleTranslation.mapValue (
5373
+ iv, builder.CreateAlloca (
5374
+ moduleTranslation.convertType (iv.getType ())));
5375
+ }
5376
+ }
5377
+
5378
+ for (Region ®ion : oper->getRegions ()) {
5379
+ auto result = convertOmpOpRegions (
5380
+ region, oper->getName ().getStringRef ().str () + " .fake.region" ,
5381
+ builder, moduleTranslation);
5382
+ if (failed (handleError (result, *oper)))
5383
+ return WalkResult::interrupt ();
5384
+
5385
+ builder.SetInsertPoint (result.get (), result.get ()->end ());
5386
+ }
5387
+
5388
+ return WalkResult::skip ();
5389
+ }
5390
+
5318
5391
return WalkResult::advance ();
5319
5392
}).wasInterrupted ();
5320
5393
return failure (interrupted);
0 commit comments