@@ -361,6 +361,64 @@ void sinkLoopIVArgs(mlir::ConversionPatternRewriter &rewriter,
361
361
++idx;
362
362
}
363
363
}
364
+
365
+ // / Collects values that are local to a loop: "loop-local values". A loop-local
366
+ // / value is one that is used exclusively inside the loop but allocated outside
367
+ // / of it. This usually corresponds to temporary values that are used inside the
368
+ // / loop body for initialzing other variables for example.
369
+ // /
370
+ // / See `flang/test/Transforms/DoConcurrent/locally_destroyed_temp.f90` for an
371
+ // / example of why we need this.
372
+ // /
373
+ // / \param [in] doLoop - the loop within which the function searches for values
374
+ // / used exclusively inside.
375
+ // /
376
+ // / \param [out] locals - the list of loop-local values detected for \p doLoop.
377
+ void collectLoopLocalValues (fir::DoLoopOp doLoop,
378
+ llvm::SetVector<mlir::Value> &locals) {
379
+ doLoop.walk ([&](mlir::Operation *op) {
380
+ for (mlir::Value operand : op->getOperands ()) {
381
+ if (locals.contains (operand))
382
+ continue ;
383
+
384
+ bool isLocal = true ;
385
+
386
+ if (!mlir::isa_and_present<fir::AllocaOp>(operand.getDefiningOp ()))
387
+ continue ;
388
+
389
+ // Values defined inside the loop are not interesting since they do not
390
+ // need to be localized.
391
+ if (doLoop->isAncestor (operand.getDefiningOp ()))
392
+ continue ;
393
+
394
+ for (auto *user : operand.getUsers ()) {
395
+ if (!doLoop->isAncestor (user)) {
396
+ isLocal = false ;
397
+ break ;
398
+ }
399
+ }
400
+
401
+ if (isLocal)
402
+ locals.insert (operand);
403
+ }
404
+ });
405
+ }
406
+
407
+ // / For a "loop-local" value \p local within a loop's scope, localizes that
408
+ // / value within the scope of the parallel region the loop maps to. Towards that
409
+ // / end, this function moves the allocation of \p local within \p allocRegion.
410
+ // /
411
+ // / \param local - the value used exclusively within a loop's scope (see
412
+ // / collectLoopLocalValues).
413
+ // /
414
+ // / \param allocRegion - the parallel region where \p local's allocation will be
415
+ // / privatized.
416
+ // /
417
+ // / \param rewriter - builder used for updating \p allocRegion.
418
+ static void localizeLoopLocalValue (mlir::Value local, mlir::Region &allocRegion,
419
+ mlir::ConversionPatternRewriter &rewriter) {
420
+ rewriter.moveOpBefore (local.getDefiningOp (), &allocRegion.front ().front ());
421
+ }
364
422
} // namespace looputils
365
423
366
424
class DoConcurrentConversion : public mlir ::OpConversionPattern<fir::DoLoopOp> {
@@ -383,13 +441,21 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
383
441
" Some `do concurent` loops are not perfectly-nested. "
384
442
" These will be serialized." );
385
443
444
+ llvm::SetVector<mlir::Value> locals;
445
+ looputils::collectLoopLocalValues (loopNest.back ().first , locals);
386
446
looputils::sinkLoopIVArgs (rewriter, loopNest);
447
+
387
448
mlir::IRMapping mapper;
388
- genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
449
+ mlir::omp::ParallelOp parallelOp =
450
+ genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
389
451
mlir::omp::LoopNestOperands loopNestClauseOps;
390
452
genLoopNestClauseOps (doLoop.getLoc (), rewriter, loopNest, mapper,
391
453
loopNestClauseOps);
392
454
455
+ for (mlir::Value local : locals)
456
+ looputils::localizeLoopLocalValue (local, parallelOp.getRegion (),
457
+ rewriter);
458
+
393
459
mlir::omp::LoopNestOp ompLoopNest =
394
460
genWsLoopOp (rewriter, loopNest.back ().first , mapper, loopNestClauseOps,
395
461
/* isComposite=*/ mapToDevice);
0 commit comments