Skip to content

Commit 464d321

Browse files
authored
[flang][stack-arrays] Extend pass to work on declare ops and within omp regions (#98810)
Extends the stack-arrays pass to support `fir.declare` ops. Before that, we did not recognize malloc-free pairs for which `fir.declare` is used to declare the allocated entity. This is because the `free` op was invoked on the result of the `fir.declare` op and did not directly use the allocated memory SSA value. This also extends the pass to collect the analysis results within OpenMP regions.
1 parent c184b94 commit 464d321

File tree

3 files changed

+138
-12
lines changed

3 files changed

+138
-12
lines changed

flang/lib/Optimizer/Transforms/StackArrays.cpp

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
287287

288288
void LatticePoint::print(llvm::raw_ostream &os) const {
289289
for (const auto &[value, state] : stateMap) {
290-
os << value << ": ";
290+
os << "\n * " << value << ": ";
291291
::print(os, state);
292292
}
293293
}
@@ -361,6 +361,13 @@ void AllocationAnalysis::visitOperation(mlir::Operation *op,
361361
} else if (mlir::isa<fir::FreeMemOp>(op)) {
362362
assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
363363
mlir::Value operand = op->getOperand(0);
364+
365+
// Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
366+
// to fir. Therefore, we only need to handle `fir::DeclareOp`s.
367+
if (auto declareOp =
368+
llvm::dyn_cast_if_present<fir::DeclareOp>(operand.getDefiningOp()))
369+
operand = declareOp.getMemref();
370+
364371
std::optional<AllocationState> operandState = before.get(operand);
365372
if (operandState && *operandState == AllocationState::Allocated) {
366373
// don't tag things not allocated in this function as freed, so that we
@@ -452,6 +459,9 @@ StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
452459
};
453460
func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
454461
func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
462+
func->walk(
463+
[&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); });
464+
455465
llvm::DenseSet<mlir::Value> freedValues;
456466
point.appendFreedValues(freedValues);
457467

@@ -518,9 +528,18 @@ AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
518528

519529
// remove freemem operations
520530
llvm::SmallVector<mlir::Operation *> erases;
521-
for (mlir::Operation *user : allocmem.getOperation()->getUsers())
531+
for (mlir::Operation *user : allocmem.getOperation()->getUsers()) {
532+
if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
533+
for (mlir::Operation *user : declareOp->getUsers()) {
534+
if (mlir::isa<fir::FreeMemOp>(user))
535+
erases.push_back(user);
536+
}
537+
}
538+
522539
if (mlir::isa<fir::FreeMemOp>(user))
523540
erases.push_back(user);
541+
}
542+
524543
// now we are done iterating the users, it is safe to mutate them
525544
for (mlir::Operation *erase : erases)
526545
rewriter.eraseOp(erase);
@@ -633,9 +652,19 @@ AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
633652

634653
// find freemem ops
635654
llvm::SmallVector<mlir::Operation *, 1> freeOps;
636-
for (mlir::Operation *user : oldAllocOp->getUsers())
655+
656+
for (mlir::Operation *user : oldAllocOp->getUsers()) {
657+
if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
658+
for (mlir::Operation *user : declareOp->getUsers()) {
659+
if (mlir::isa<fir::FreeMemOp>(user))
660+
freeOps.push_back(user);
661+
}
662+
}
663+
637664
if (mlir::isa<fir::FreeMemOp>(user))
638665
freeOps.push_back(user);
666+
}
667+
639668
assert(freeOps.size() && "DFA should only return freed memory");
640669

641670
// Don't attempt to reason about a stacksave/stackrestore between different
@@ -717,12 +746,23 @@ void AllocMemConversion::insertStackSaveRestore(
717746
mlir::SymbolRefAttr stackRestoreSym =
718747
builder.getSymbolRefAttr(stackRestoreFn.getName());
719748

749+
auto createStackRestoreCall = [&](mlir::Operation *user) {
750+
builder.setInsertionPoint(user);
751+
builder.create<fir::CallOp>(user->getLoc(),
752+
stackRestoreFn.getFunctionType().getResults(),
753+
stackRestoreSym, mlir::ValueRange{sp});
754+
};
755+
720756
for (mlir::Operation *user : oldAlloc->getUsers()) {
757+
if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
758+
for (mlir::Operation *user : declareOp->getUsers()) {
759+
if (mlir::isa<fir::FreeMemOp>(user))
760+
createStackRestoreCall(user);
761+
}
762+
}
763+
721764
if (mlir::isa<fir::FreeMemOp>(user)) {
722-
builder.setInsertionPoint(user);
723-
builder.create<fir::CallOp>(user->getLoc(),
724-
stackRestoreFn.getFunctionType().getResults(),
725-
stackRestoreSym, mlir::ValueRange{sp});
765+
createStackRestoreCall(user);
726766
}
727767
}
728768

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
! Similar to stack-arrays.f90; i.e. both test the stack-arrays pass for different
2+
! kinds of supported inputs. This one differs in that it takes the hlfir lowering
3+
! path in flag rather than the fir one. For example, temp arrays are lowered
4+
! differently in hlfir vs. fir and the IR that reaches the stack arrays pass looks
5+
! quite different.
6+
7+
8+
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - \
9+
! RUN: | fir-opt --lower-hlfir-ordered-assignments \
10+
! RUN: --bufferize-hlfir \
11+
! RUN: --convert-hlfir-to-fir \
12+
! RUN: --array-value-copy \
13+
! RUN: --stack-arrays \
14+
! RUN: | FileCheck %s
15+
16+
subroutine temp_array
17+
implicit none
18+
integer (8) :: lV
19+
integer (8), dimension (2) :: iaVS
20+
21+
lV = 202
22+
23+
iaVS = [lV, lV]
24+
end subroutine temp_array
25+
! CHECK-LABEL: func.func @_QPtemp_array{{.*}} {
26+
! CHECK-NOT: fir.allocmem
27+
! CHECK-NOT: fir.freemem
28+
! CHECK: fir.alloca !fir.array<2xi64>
29+
! CHECK-NOT: fir.allocmem
30+
! CHECK-NOT: fir.freemem
31+
! CHECK: return
32+
! CHECK-NEXT: }
33+
34+
subroutine omp_temp_array
35+
implicit none
36+
integer (8) :: lV
37+
integer (8), dimension (2) :: iaVS
38+
39+
lV = 202
40+
41+
!$omp target
42+
iaVS = [lV, lV]
43+
!$omp end target
44+
end subroutine omp_temp_array
45+
! CHECK-LABEL: func.func @_QPomp_temp_array{{.*}} {
46+
! CHECK: omp.target {{.*}} {
47+
! CHECK-NOT: fir.allocmem
48+
! CHECK-NOT: fir.freemem
49+
! CHECK: fir.alloca !fir.array<2xi64>
50+
! CHECK-NOT: fir.allocmem
51+
! CHECK-NOT: fir.freemem
52+
! CHECK: omp.terminator
53+
! CHECK-NEXT: }
54+
! CHECK: return
55+
! CHECK-NEXT: }

flang/test/Transforms/stack-arrays.fir

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -339,13 +339,10 @@ func.func @omp_placement1() {
339339
return
340340
}
341341
// CHECK: func.func @omp_placement1() {
342+
// CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array<42xi32>
343+
// CHECK-NEXT: %[[MEM_CONV:.*]] = fir.convert %[[MEM]] : (!fir.ref<!fir.array<42xi32>>) -> !fir.heap<!fir.array<42xi32>>
342344
// CHECK-NEXT: omp.sections {
343345
// CHECK-NEXT: omp.section {
344-
// CHECK-NEXT: %[[MEM:.*]] = fir.allocmem !fir.array<42xi32>
345-
// TODO: this allocation should be moved to the stack. Unfortunately, the data
346-
// flow analysis fails to propogate the lattice out of the omp region to the
347-
// return satement.
348-
// CHECK-NEXT: fir.freemem %[[MEM]] : !fir.heap<!fir.array<42xi32>>
349346
// CHECK-NEXT: omp.terminator
350347
// CHECK-NEXT: }
351348
// CHECK-NEXT: omp.terminator
@@ -369,3 +366,37 @@ func.func @stop_terminator() {
369366
// CHECK-NEXT: %[[NONE:.*]] = fir.call @_FortranAStopStatement(%[[ZERO]], %[[FALSE]], %[[FALSE]]) : (i32, i1, i1) -> none
370367
// CHECK-NEXT: fir.unreachable
371368
// CHECK-NEXT: }
369+
370+
371+
// check that stack allocations that use fir.declare which must be placed in loops
372+
// use stacksave
373+
func.func @placement_loop_declare() {
374+
%c1 = arith.constant 1 : index
375+
%c1_i32 = fir.convert %c1 : (index) -> i32
376+
%c2 = arith.constant 2 : index
377+
%c10 = arith.constant 10 : index
378+
%0:2 = fir.do_loop %arg0 = %c1 to %c10 step %c1 iter_args(%arg1 = %c1_i32) -> (index, i32) {
379+
%3 = arith.addi %c1, %c2 : index
380+
// operand is now available
381+
%4 = fir.allocmem !fir.array<?xi32>, %3
382+
%5 = fir.declare %4 {uniq_name = "temp"} : (!fir.heap<!fir.array<?xi32>>) -> !fir.heap<!fir.array<?xi32>>
383+
// ...
384+
fir.freemem %5 : !fir.heap<!fir.array<?xi32>>
385+
fir.result %3, %c1_i32 : index, i32
386+
}
387+
return
388+
}
389+
// CHECK: func.func @placement_loop_declare() {
390+
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
391+
// CHECK-NEXT: %[[C1_I32:.*]] = fir.convert %[[C1]] : (index) -> i32
392+
// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : index
393+
// CHECK-NEXT: %[[C10:.*]] = arith.constant 10 : index
394+
// CHECK-NEXT: fir.do_loop
395+
// CHECK-NEXT: %[[SUM:.*]] = arith.addi %[[C1]], %[[C2]] : index
396+
// CHECK-NEXT: %[[SP:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
397+
// CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array<?xi32>, %[[SUM]]
398+
// CHECK: fir.call @llvm.stackrestore.p0(%[[SP]])
399+
// CHECK-NEXT: fir.result
400+
// CHECK-NEXT: }
401+
// CHECK-NEXT: return
402+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)