Skip to content

Commit e3cd88a

Browse files
authored
[flang] Fixed StackArrays assertion after #121919. (#122550)
`findAllocaLoopInsertionPoint()` hit assertion not being able to find the `fir.freemem` because of the `fir.convert`. I think it is better to look for `fir.freemem` same way with the look-through walk.
1 parent 01ee66e commit e3cd88a

File tree

2 files changed

+73
-23
lines changed

2 files changed

+73
-23
lines changed

flang/lib/Optimizer/Transforms/StackArrays.cpp

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,19 @@ class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
198198

199199
/// Determine where to insert the alloca operation. The returned value should
200200
/// be checked to see if it is inside a loop
201-
static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);
201+
static InsertionPoint
202+
findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc,
203+
const llvm::SmallVector<mlir::Operation *> &freeOps);
202204

203205
private:
204206
/// Handle to the DFA (already run)
205207
const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;
206208

207209
/// If we failed to find an insertion point not inside a loop, see if it would
208210
/// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
209-
static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);
211+
static InsertionPoint findAllocaLoopInsertionPoint(
212+
fir::AllocMemOp &oldAlloc,
213+
const llvm::SmallVector<mlir::Operation *> &freeOps);
210214

211215
/// Returns the alloca if it was successfully inserted, otherwise {}
212216
std::optional<fir::AllocaOp>
@@ -484,14 +488,31 @@ StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
484488
llvm::DenseSet<mlir::Value> freedValues;
485489
point.appendFreedValues(freedValues);
486490

491+
// Find all fir.freemem operations corresponding to fir.allocmem
492+
// in freedValues. It is best to find the association going back
493+
// from fir.freemem to fir.allocmem through the def-use chains,
494+
// so that we can use lookThroughDeclaresAndConverts same way
495+
// the AllocationAnalysis is handling them.
496+
llvm::DenseMap<mlir::Operation *, llvm::SmallVector<mlir::Operation *>>
497+
allocToFreeMemMap;
498+
func->walk([&](fir::FreeMemOp freeOp) {
499+
mlir::Value memref = lookThroughDeclaresAndConverts(freeOp.getHeapref());
500+
if (!freedValues.count(memref))
501+
return;
502+
503+
auto allocMem = memref.getDefiningOp<fir::AllocMemOp>();
504+
allocToFreeMemMap[allocMem].push_back(freeOp);
505+
});
506+
487507
// We only replace allocations which are definately freed on all routes
488508
// through the function because otherwise the allocation may have an intende
489509
// lifetime longer than the current stack frame (e.g. a heap allocation which
490510
// is then freed by another function).
491511
for (mlir::Value freedValue : freedValues) {
492512
fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
493513
InsertionPoint insertionPoint =
494-
AllocMemConversion::findAllocaInsertionPoint(allocmem);
514+
AllocMemConversion::findAllocaInsertionPoint(
515+
allocmem, allocToFreeMemMap[allocmem]);
495516
if (insertionPoint)
496517
candidateOps.insert({allocmem, insertionPoint});
497518
}
@@ -578,8 +599,9 @@ static bool isInLoop(mlir::Operation *op) {
578599
op->getParentOfType<mlir::LoopLikeOpInterface>();
579600
}
580601

581-
InsertionPoint
582-
AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
602+
InsertionPoint AllocMemConversion::findAllocaInsertionPoint(
603+
fir::AllocMemOp &oldAlloc,
604+
const llvm::SmallVector<mlir::Operation *> &freeOps) {
583605
// Ideally the alloca should be inserted at the end of the function entry
584606
// block so that we do not allocate stack space in a loop. However,
585607
// the operands to the alloca may not be available that early, so insert it
@@ -596,7 +618,7 @@ AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
596618
if (isInLoop(oldAllocOp)) {
597619
// where we want to put it is in a loop, and even the old location is in
598620
// a loop. Give up.
599-
return findAllocaLoopInsertionPoint(oldAlloc);
621+
return findAllocaLoopInsertionPoint(oldAlloc, freeOps);
600622
}
601623
return {oldAllocOp};
602624
}
@@ -657,28 +679,14 @@ AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
657679
return checkReturn(&entryBlock);
658680
}
659681

660-
InsertionPoint
661-
AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
682+
InsertionPoint AllocMemConversion::findAllocaLoopInsertionPoint(
683+
fir::AllocMemOp &oldAlloc,
684+
const llvm::SmallVector<mlir::Operation *> &freeOps) {
662685
mlir::Operation *oldAllocOp = oldAlloc;
663686
// This is only called as a last resort. We should try to insert at the
664687
// location of the old allocation, which is inside of a loop, using
665688
// llvm.stacksave/llvm.stackrestore
666689

667-
// find freemem ops
668-
llvm::SmallVector<mlir::Operation *, 1> freeOps;
669-
670-
for (mlir::Operation *user : oldAllocOp->getUsers()) {
671-
if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
672-
for (mlir::Operation *user : declareOp->getUsers()) {
673-
if (mlir::isa<fir::FreeMemOp>(user))
674-
freeOps.push_back(user);
675-
}
676-
}
677-
678-
if (mlir::isa<fir::FreeMemOp>(user))
679-
freeOps.push_back(user);
680-
}
681-
682690
assert(freeOps.size() && "DFA should only return freed memory");
683691

684692
// Don't attempt to reason about a stacksave/stackrestore between different

flang/test/Transforms/stack-arrays.fir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,45 @@ func.func @lookthrough() {
418418
// CHECK: func.func @lookthrough() {
419419
// CHECK: fir.alloca !fir.array<42xi32>
420420
// CHECK-NOT: fir.freemem
421+
422+
// StackArrays is better to find fir.freemem ops corresponding to fir.allocmem
423+
// using the same look through mechanism as during the allocation analysis,
424+
// looking through fir.convert and fir.declare.
425+
func.func @finding_freemem_in_block() {
426+
%c0 = arith.constant 0 : index
427+
%c10_i32 = arith.constant 10 : i32
428+
%c1_i32 = arith.constant 1 : i32
429+
%0 = fir.alloca i32 {bindc_name = "k", uniq_name = "k"}
430+
%1 = fir.declare %0 {uniq_name = "k"} : (!fir.ref<i32>) -> !fir.ref<i32>
431+
fir.store %c1_i32 to %1 : !fir.ref<i32>
432+
cf.br ^bb1
433+
^bb1: // 2 preds: ^bb0, ^bb2
434+
%2 = fir.load %1 : !fir.ref<i32>
435+
%3 = arith.cmpi sle, %2, %c10_i32 : i32
436+
cf.cond_br %3, ^bb2, ^bb3
437+
^bb2: // pred: ^bb1
438+
%4 = fir.declare %1 {fortran_attrs = #fir.var_attrs<intent_in>, uniq_name = "x"} : (!fir.ref<i32>) -> !fir.ref<i32>
439+
%5 = fir.load %4 : !fir.ref<i32>
440+
%6 = fir.convert %5 : (i32) -> index
441+
%7 = arith.cmpi sgt, %6, %c0 : index
442+
%8 = arith.select %7, %6, %c0 : index
443+
%9 = fir.shape %8 : (index) -> !fir.shape<1>
444+
%10 = fir.allocmem !fir.array<?xi32>, %8 {bindc_name = ".tmp.expr_result", uniq_name = ""}
445+
%11 = fir.convert %10 : (!fir.heap<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
446+
%12 = fir.declare %11(%9) {uniq_name = ".tmp.expr_result"} : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<?xi32>>
447+
%13 = fir.embox %12(%9) : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xi32>>
448+
%14 = fir.call @_QPfunc(%1) fastmath<fast> : (!fir.ref<i32>) -> !fir.array<?xi32>
449+
fir.save_result %14 to %12(%9) : !fir.array<?xi32>, !fir.ref<!fir.array<?xi32>>, !fir.shape<1>
450+
fir.call @_QPsub(%13) fastmath<fast> : (!fir.box<!fir.array<?xi32>>) -> ()
451+
%15 = fir.convert %12 : (!fir.ref<!fir.array<?xi32>>) -> !fir.heap<!fir.array<?xi32>>
452+
fir.freemem %15 : !fir.heap<!fir.array<?xi32>>
453+
%16 = fir.load %1 : !fir.ref<i32>
454+
%17 = arith.addi %16, %c1_i32 : i32
455+
fir.store %17 to %1 : !fir.ref<i32>
456+
cf.br ^bb1
457+
^bb3: // pred: ^bb1
458+
return
459+
}
460+
// CHECK: func.func @finding_freemem_in_block() {
461+
// CHECK: fir.alloca !fir.array<?xi32>
462+
// CHECK-NOT: fir.freemem

0 commit comments

Comments
 (0)