@@ -343,6 +343,115 @@ class GenericLoopConversionPattern
343
343
}
344
344
};
345
345
346
+ // / According to the spec (v5.2, p340, 36):
347
+ // /
348
+ // / ```
349
+ // / The effect of the reduction clause is as if it is applied to all leaf
350
+ // / constructs that permit the clause, except for the following constructs:
351
+ // / * ....
352
+ // / * The teams construct, when combined with the loop construct.
353
+ // / ```
354
+ // /
355
+ // / Therefore, for a combined directive similar to: `!$omp teams loop
356
+ // / reduction(...)`, the earlier stages of the compiler assign the `reduction`
357
+ // / clauses only to the `loop` leaf and not to the `teams` leaf.
358
+ // /
359
+ // / On the other hand, if we have a combined construct similar to: `!$omp teams
360
+ // / distribute parallel do`, the `reduction` clauses are assigned both to the
361
+ // / `teams` and the `do` leaves. We need to match this behavior when we convert
362
+ // / `teams` op with a nested `loop` op since the target set of constructs/ops
363
+ // / will be incorrect without moving the reductions up to the `teams` op as
364
+ // / well.
365
+ // /
366
+ // / This pattern does exactly this. Given the following input:
367
+ // / ```
368
+ // / omp.teams {
369
+ // / omp.loop reduction(@red_sym %red_op -> %red_arg : !fir.ref<i32>) {
370
+ // / omp.loop_nest ... {
371
+ // / ...
372
+ // / }
373
+ // / }
374
+ // / }
375
+ // / ```
376
+ // / this pattern updates the `omp.teams` op in-place to:
377
+ // / ```
378
+ // / omp.teams reduction(@red_sym %red_op -> %teams_red_arg : !fir.ref<i32>) {
379
+ // / omp.loop reduction(@red_sym %teams_red_arg -> %red_arg : !fir.ref<i32>) {
380
+ // / omp.loop_nest ... {
381
+ // / ...
382
+ // / }
383
+ // / }
384
+ // / }
385
+ // / ```
386
+ // /
387
+ // / Note the following:
388
+ // / * The nested `omp.loop` is not rewritten by this pattern, this happens
389
+ // / through `GenericLoopConversionPattern`.
390
+ // / * The reduction info are cloned from the nested `omp.loop` op to the parent
391
+ // / `omp.teams` op.
392
+ // / * The reduction operand of the `omp.loop` op is updated to be the **new**
393
+ // / reduction block argument of the `omp.teams` op.
394
+ class ReductionsHoistingPattern
395
+ : public mlir::OpConversionPattern<mlir::omp::TeamsOp> {
396
+ public:
397
+ using mlir::OpConversionPattern<mlir::omp::TeamsOp>::OpConversionPattern;
398
+
399
+ static mlir::omp::LoopOp
400
+ tryToFindNestedLoopWithReduction (mlir::omp::TeamsOp teamsOp) {
401
+ if (teamsOp.getRegion ().getBlocks ().size () != 1 )
402
+ return nullptr ;
403
+
404
+ mlir::Block &teamsBlock = *teamsOp.getRegion ().begin ();
405
+ auto loopOpIter = llvm::find_if (teamsBlock, [](mlir::Operation &op) {
406
+ auto nestedLoopOp = llvm::dyn_cast<mlir::omp::LoopOp>(&op);
407
+
408
+ if (!nestedLoopOp)
409
+ return false ;
410
+
411
+ return !nestedLoopOp.getReductionVars ().empty ();
412
+ });
413
+
414
+ if (loopOpIter == teamsBlock.end ())
415
+ return nullptr ;
416
+
417
+ // TODO return error if more than one loop op is nested. We need to
418
+ // coalesce reductions in this case.
419
+ return llvm::cast<mlir::omp::LoopOp>(loopOpIter);
420
+ }
421
+
422
+ mlir::LogicalResult
423
+ matchAndRewrite (mlir::omp::TeamsOp teamsOp, OpAdaptor adaptor,
424
+ mlir::ConversionPatternRewriter &rewriter) const override {
425
+ mlir::omp::LoopOp nestedLoopOp = tryToFindNestedLoopWithReduction (teamsOp);
426
+
427
+ rewriter.modifyOpInPlace (teamsOp, [&]() {
428
+ teamsOp.setReductionMod (nestedLoopOp.getReductionMod ());
429
+ teamsOp.getReductionVarsMutable ().assign (nestedLoopOp.getReductionVars ());
430
+ teamsOp.setReductionByref (nestedLoopOp.getReductionByref ());
431
+ teamsOp.setReductionSymsAttr (nestedLoopOp.getReductionSymsAttr ());
432
+
433
+ auto blockArgIface =
434
+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*teamsOp);
435
+ unsigned reductionArgsStart = blockArgIface.getPrivateBlockArgsStart () +
436
+ blockArgIface.numPrivateBlockArgs ();
437
+ llvm::SmallVector<mlir::Value> newLoopOpReductionOperands;
438
+
439
+ for (auto [idx, reductionVar] :
440
+ llvm::enumerate (nestedLoopOp.getReductionVars ())) {
441
+ mlir::BlockArgument newTeamsOpReductionBlockArg =
442
+ teamsOp.getRegion ().insertArgument (reductionArgsStart + idx,
443
+ reductionVar.getType (),
444
+ reductionVar.getLoc ());
445
+ newLoopOpReductionOperands.push_back (newTeamsOpReductionBlockArg);
446
+ }
447
+
448
+ nestedLoopOp.getReductionVarsMutable ().assign (newLoopOpReductionOperands);
449
+ });
450
+
451
+ return mlir::success ();
452
+ }
453
+ };
454
+
346
455
class GenericLoopConversionPass
347
456
: public flangomp::impl::GenericLoopConversionPassBase<
348
457
GenericLoopConversionPass> {
@@ -357,11 +466,23 @@ class GenericLoopConversionPass
357
466
358
467
mlir::MLIRContext *context = &getContext ();
359
468
mlir::RewritePatternSet patterns (context);
360
- patterns.insert <GenericLoopConversionPattern>(context);
469
+ patterns.insert <ReductionsHoistingPattern, GenericLoopConversionPattern>(
470
+ context);
361
471
mlir::ConversionTarget target (*context);
362
472
363
473
target.markUnknownOpDynamicallyLegal (
364
474
[](mlir::Operation *) { return true ; });
475
+
476
+ target.addDynamicallyLegalOp <mlir::omp::TeamsOp>(
477
+ [](mlir::omp::TeamsOp teamsOp) {
478
+ // If teamsOp's reductions are already populated, then the op is
479
+ // legal. Additionally, the op is legal if it does not nest a LoopOp
480
+ // with reductions.
481
+ return !teamsOp.getReductionVars ().empty () ||
482
+ ReductionsHoistingPattern::tryToFindNestedLoopWithReduction (
483
+ teamsOp) == nullptr ;
484
+ });
485
+
365
486
target.addDynamicallyLegalOp <mlir::omp::LoopOp>(
366
487
[](mlir::omp::LoopOp loopOp) {
367
488
return mlir::failed (
0 commit comments