Skip to content

Commit ff7fca7

Browse files
authored
[flang][cuda] Support memory cleanup at a return statement (#116304)
We generate `cuf.free` and `func.return` twice if a return statement exists at the end of a program.

```f90
program test
  integer, device :: a(10)
  return
end
```

```
% flang -x cuda test.cuf -mmlir --mlir-print-ir-after-all
error: loc("/path/to/test.cuf":3:3): 'func.return' op must be the last operation in the parent block
// -----// IR Dump After Fortran::lower::VerifierPass Failed () //----- //
```

Dumped IR:

```mlir
"func.func"() <{function_type = () -> (), sym_name = "_QQmain"}> ({
  ...
  "cuf.free"(%5#1) <{data_attr = #cuf.cuda<device>}> : (!fir.ref<!fir.array<10xi32>>) -> ()
  "func.return"() : () -> ()
  "cuf.free"(%5#1) <{data_attr = #cuf.cuda<device>}> : (!fir.ref<!fir.array<10xi32>>) -> ()
  "func.return"() : () -> ()
}
...
```

The routine `genExitRoutine` in `Bridge.cpp` is guarded by `blockIsUnterminated()` to make sure that `func.return` is generated only at the end of a block. However, we redundantly run `bridge.fctCtx().finalizeAndKeep()` before `genExitRoutine` in this case, resulting in two pairs of `cuf.free` and `func.return`.

This PR fixes `Bridge.cpp` by using `blockIsUnterminated()` to guard `finalizeAndKeep` as well.
1 parent 798a894 commit ff7fca7

File tree

4 files changed

+83
-21
lines changed

4 files changed

+83
-21
lines changed

flang/include/flang/Lower/StatementContext.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,13 @@ class StatementContext {
9292
cufs.back().reset();
9393
}
9494

95+
/// Pop the stack top list.
96+
void pop() { cufs.pop_back(); }
97+
9598
/// Make cleanup calls. Pop the stack top list.
9699
void finalizeAndPop() {
97100
finalizeAndKeep();
98-
cufs.pop_back();
101+
pop();
99102
}
100103

101104
bool hasCode() const {

flang/lib/Lower/Bridge.cpp

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,13 +1621,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
16211621
// Termination of symbolically referenced execution units
16221622
//===--------------------------------------------------------------------===//
16231623

1624-
/// END of program
1624+
/// Exit of a routine
16251625
///
1626-
/// Generate the cleanup block before the program exits
1627-
void genExitRoutine() {
1628-
1629-
if (blockIsUnterminated())
1630-
builder->create<mlir::func::ReturnOp>(toLocation());
1626+
/// Generate the cleanup block before the routine exits
1627+
// Emits the routine's exit sequence: pending cleanups for the OpenACC and
// function-level statement contexts followed by a `func.return` of `retval`.
// `earlyReturn` is passed as true from the main-program RETURN path and as
// false from the end-of-routine paths; only the latter pop the cleanup lists.
void genExitRoutine(bool earlyReturn, mlir::ValueRange retval = {}) {
1628+
// Cleanups and the return are emitted together, and only while the current
// block has no terminator yet. Running the cleanups unconditionally produced
// a second `cuf.free`/`func.return` pair after a trailing RETURN statement.
if (blockIsUnterminated()) {
1629+
bridge.openAccCtx().finalizeAndKeep();
1630+
bridge.fctCtx().finalizeAndKeep();
1631+
builder->create<mlir::func::ReturnOp>(toLocation(), retval);
1632+
}
1633+
// On an early return the cleanup lists stay on the stack -- presumably so
// that later exits of the same routine can emit the same cleanups (confirm
// against callers). The end-of-routine call discards them.
if (!earlyReturn) {
1634+
bridge.openAccCtx().pop();
1635+
bridge.fctCtx().pop();
1636+
}
16311637
}
16321638

16331639
/// END of procedure-like constructs
@@ -1684,9 +1690,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
16841690
resultRef = builder->createConvert(loc, resultRefType, resultRef);
16851691
return builder->create<fir::LoadOp>(loc, resultRef);
16861692
});
1687-
bridge.openAccCtx().finalizeAndPop();
1688-
bridge.fctCtx().finalizeAndPop();
1689-
builder->create<mlir::func::ReturnOp>(loc, resultVal);
1693+
genExitRoutine(false, resultVal);
16901694
}
16911695

16921696
/// Get the return value of a call to \p symbol, which is a subroutine entry
@@ -1712,13 +1716,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
17121716
} else if (Fortran::semantics::HasAlternateReturns(symbol)) {
17131717
mlir::Value retval = builder->create<fir::LoadOp>(
17141718
toLocation(), getAltReturnResult(symbol));
1715-
bridge.openAccCtx().finalizeAndPop();
1716-
bridge.fctCtx().finalizeAndPop();
1717-
builder->create<mlir::func::ReturnOp>(toLocation(), retval);
1719+
genExitRoutine(false, retval);
17181720
} else {
1719-
bridge.openAccCtx().finalizeAndPop();
1720-
bridge.fctCtx().finalizeAndPop();
1721-
genExitRoutine();
1721+
genExitRoutine(false);
17221722
}
17231723
}
17241724

@@ -5018,8 +5018,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
50185018
it->stmtCtx.finalizeAndKeep();
50195019
}
50205020
if (funit->isMainProgram()) {
5021-
bridge.fctCtx().finalizeAndKeep();
5022-
genExitRoutine();
5021+
genExitRoutine(true);
50235022
return;
50245023
}
50255024
mlir::Location loc = toLocation();
@@ -5478,9 +5477,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
54785477
void endNewFunction(Fortran::lower::pft::FunctionLikeUnit &funit) {
54795478
setCurrentPosition(Fortran::lower::pft::stmtSourceLoc(funit.endStmt));
54805479
if (funit.isMainProgram()) {
5481-
bridge.openAccCtx().finalizeAndPop();
5482-
bridge.fctCtx().finalizeAndPop();
5483-
genExitRoutine();
5480+
genExitRoutine(false);
54845481
} else {
54855482
genFIRProcedureExit(funit, funit.getSubprogramSymbol());
54865483
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
! Check if finalization works with a return statement
4+
5+
program main
6+
integer, device :: a(10)
7+
return
8+
end
9+
10+
! CHECK: func.func @_QQmain() attributes {fir.bindc_name = "main"} {
11+
! CHECK: %[[DECL:.*]]:2 = hlfir.declare
12+
! CHECK-NEXT: cuf.free %[[DECL]]#1 : !fir.ref<!fir.array<10xi32>>
13+
! CHECK-NEXT: return
14+
! CHECK-NEXT: }
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
! Check if finalization works with multiple return statements
4+
5+
program test
6+
integer, device :: a(10)
7+
logical :: l
8+
9+
if (l) then
10+
return
11+
end if
12+
13+
return
14+
end
15+
16+
! CHECK: func.func @_QQmain() attributes {fir.bindc_name = "test"} {
17+
! CHECK: %[[DECL:.*]]:2 = hlfir.declare
18+
! CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
19+
! CHECK-NEXT: ^bb1:
20+
! CHECK-NEXT: cuf.free %[[DECL]]#1 : !fir.ref<!fir.array<10xi32>>
21+
! CHECK-NEXT: return
22+
! CHECK-NEXT: ^bb2:
23+
! CHECK-NEXT: cuf.free %[[DECL]]#1 : !fir.ref<!fir.array<10xi32>>
24+
! CHECK-NEXT: return
25+
! CHECK-NEXT: }
26+
27+
subroutine sub(l)
28+
integer, device :: a(10)
29+
logical :: l
30+
31+
if (l) then
32+
l = .false.
33+
return
34+
end if
35+
36+
return
37+
end
38+
39+
! CHECK: func.func @_QPsub(%arg0: !fir.ref<!fir.logical<4>> {fir.bindc_name = "l"}) {
40+
! CHECK: %[[DECL:.*]]:2 = hlfir.declare
41+
! CHECK: cf.cond_br %6, ^bb1, ^bb2
42+
! CHECK: ^bb1:
43+
! CHECK: cf.br ^bb3
44+
! CHECK: ^bb2:
45+
! CHECK: cf.br ^bb3
46+
! CHECK: ^bb3:
47+
! CHECK: cuf.free %[[DECL]]#1 : !fir.ref<!fir.array<10xi32>>
48+
! CHECK: }

0 commit comments

Comments
 (0)