@@ -287,7 +287,7 @@ mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
287
287
288
288
void LatticePoint::print (llvm::raw_ostream &os) const {
289
289
for (const auto &[value, state] : stateMap) {
290
- os << value << " : " ;
290
+ os << " \n * " << value << " : " ;
291
291
::print (os, state);
292
292
}
293
293
}
@@ -361,6 +361,13 @@ void AllocationAnalysis::visitOperation(mlir::Operation *op,
361
361
} else if (mlir::isa<fir::FreeMemOp>(op)) {
362
362
assert (op->getNumOperands () == 1 && " fir.freemem has one operand" );
363
363
mlir::Value operand = op->getOperand (0 );
364
+
365
+ // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
366
+ // to fir. Therefore, we only need to handle `fir::DeclareOp`s.
367
+ if (auto declareOp =
368
+ llvm::dyn_cast_if_present<fir::DeclareOp>(operand.getDefiningOp ()))
369
+ operand = declareOp.getMemref ();
370
+
364
371
std::optional<AllocationState> operandState = before.get (operand);
365
372
if (operandState && *operandState == AllocationState::Allocated) {
366
373
// don't tag things not allocated in this function as freed, so that we
@@ -452,6 +459,9 @@ StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
452
459
};
453
460
func->walk ([&](mlir::func::ReturnOp child) { joinOperationLattice (child); });
454
461
func->walk ([&](fir::UnreachableOp child) { joinOperationLattice (child); });
462
+ func->walk (
463
+ [&](mlir::omp::TerminatorOp child) { joinOperationLattice (child); });
464
+
455
465
llvm::DenseSet<mlir::Value> freedValues;
456
466
point.appendFreedValues (freedValues);
457
467
@@ -518,9 +528,18 @@ AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
518
528
519
529
// remove freemem operations
520
530
llvm::SmallVector<mlir::Operation *> erases;
521
- for (mlir::Operation *user : allocmem.getOperation ()->getUsers ())
531
+ for (mlir::Operation *user : allocmem.getOperation ()->getUsers ()) {
532
+ if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
533
+ for (mlir::Operation *user : declareOp->getUsers ()) {
534
+ if (mlir::isa<fir::FreeMemOp>(user))
535
+ erases.push_back (user);
536
+ }
537
+ }
538
+
522
539
if (mlir::isa<fir::FreeMemOp>(user))
523
540
erases.push_back (user);
541
+ }
542
+
524
543
// now we are done iterating the users, it is safe to mutate them
525
544
for (mlir::Operation *erase : erases)
526
545
rewriter.eraseOp (erase);
@@ -633,9 +652,19 @@ AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
633
652
634
653
// find freemem ops
635
654
llvm::SmallVector<mlir::Operation *, 1 > freeOps;
636
- for (mlir::Operation *user : oldAllocOp->getUsers ())
655
+
656
+ for (mlir::Operation *user : oldAllocOp->getUsers ()) {
657
+ if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
658
+ for (mlir::Operation *user : declareOp->getUsers ()) {
659
+ if (mlir::isa<fir::FreeMemOp>(user))
660
+ freeOps.push_back (user);
661
+ }
662
+ }
663
+
637
664
if (mlir::isa<fir::FreeMemOp>(user))
638
665
freeOps.push_back (user);
666
+ }
667
+
639
668
assert (freeOps.size () && " DFA should only return freed memory" );
640
669
641
670
// Don't attempt to reason about a stacksave/stackrestore between different
@@ -717,12 +746,23 @@ void AllocMemConversion::insertStackSaveRestore(
717
746
mlir::SymbolRefAttr stackRestoreSym =
718
747
builder.getSymbolRefAttr (stackRestoreFn.getName ());
719
748
749
+ auto createStackRestoreCall = [&](mlir::Operation *user) {
750
+ builder.setInsertionPoint (user);
751
+ builder.create <fir::CallOp>(user->getLoc (),
752
+ stackRestoreFn.getFunctionType ().getResults (),
753
+ stackRestoreSym, mlir::ValueRange{sp});
754
+ };
755
+
720
756
for (mlir::Operation *user : oldAlloc->getUsers ()) {
757
+ if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
758
+ for (mlir::Operation *user : declareOp->getUsers ()) {
759
+ if (mlir::isa<fir::FreeMemOp>(user))
760
+ createStackRestoreCall (user);
761
+ }
762
+ }
763
+
721
764
if (mlir::isa<fir::FreeMemOp>(user)) {
722
- builder.setInsertionPoint (user);
723
- builder.create <fir::CallOp>(user->getLoc (),
724
- stackRestoreFn.getFunctionType ().getResults (),
725
- stackRestoreSym, mlir::ValueRange{sp});
765
+ createStackRestoreCall (user);
726
766
}
727
767
}
728
768
0 commit comments