Skip to content

Commit 99cd0c5

Browse files
[Flang] Extracting internal constants from scalar literals
Constants actual arguments in function/subroutine calls are currently lowered as allocas + store. This can sometimes inhibit LTO and the constant will not be propagated to the called function. Particularly in cases where the function/subroutine call happens inside a condition. This patch changes the lowering of these constant actual arguments to a global constant + fir.address_of_op. This lowering makes it easier for LTO to propagate the constant. Co-authored-by: Dmitriy Smirnov <[email protected]>
1 parent 60aeea2 commit 99cd0c5

File tree

14 files changed

+272
-12
lines changed

14 files changed

+272
-12
lines changed

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ createExternalNameConversionPass(bool appendUnderscore);
6060
std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
6161
std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
6262
std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
63+
std::unique_ptr<mlir::Pass> createConstExtruderPass();
6364
std::unique_ptr<mlir::Pass> createStackArraysPass();
6465
std::unique_ptr<mlir::Pass> createAliasTagsPass();
6566
std::unique_ptr<mlir::Pass> createSimplifyIntrinsicsPass();

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,16 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
242242
let constructor = "::fir::createMemoryAllocationPass()";
243243
}
244244

245+
// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
246+
def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
247+
let summary = "Convert scalar literals of function arguments to global constants.";
248+
let description = [{
249+
Convert scalar literals of function arguments to global constants.
250+
}];
251+
let dependentDialects = [ "fir::FIROpsDialect" ];
252+
let constructor = "::fir::createConstExtruderPass()";
253+
}
254+
245255
def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
246256
let summary = "Move local array allocations from heap memory into stack memory";
247257
let description = [{

flang/include/flang/Tools/CLOptions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ inline void createDefaultFIROptimizerPassPipeline(
216216
else
217217
fir::addMemoryAllocationOpt(pm);
218218

219+
pm.addPass(fir::createConstExtruderPass());
220+
219221
// The default inliner pass adds the canonicalizer pass with the default
220222
// configuration. Create the inliner pass with tco config.
221223
llvm::StringMap<mlir::OpPassManager> pipelines;

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_flang_library(FIRTransforms
55
AffineDemotion.cpp
66
AnnotateConstant.cpp
77
CharacterConversion.cpp
8+
ConstExtruder.cpp
89
ControlFlowConverter.cpp
910
ArrayValueCopy.cpp
1011
ExternalNameConversion.cpp
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
//===- ConstExtruder.cpp -----------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Optimizer/Builder/BoxValue.h"
10+
#include "flang/Optimizer/Builder/FIRBuilder.h"
11+
#include "flang/Optimizer/Dialect/FIRDialect.h"
12+
#include "flang/Optimizer/Dialect/FIROps.h"
13+
#include "flang/Optimizer/Dialect/FIRType.h"
14+
#include "flang/Optimizer/Transforms/Passes.h"
15+
#include "mlir/Dialect/Func/IR/FuncOps.h"
16+
#include "mlir/IR/Diagnostics.h"
17+
#include "mlir/IR/Dominance.h"
18+
#include "mlir/Pass/Pass.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
#include "mlir/Transforms/Passes.h"
21+
#include "llvm/ADT/TypeSwitch.h"
22+
#include <atomic>
23+
24+
namespace fir {
25+
#define GEN_PASS_DEF_CONSTEXTRUDEROPT
26+
#include "flang/Optimizer/Transforms/Passes.h.inc"
27+
} // namespace fir
28+
29+
#define DEBUG_TYPE "flang-const-extruder-opt"
30+
31+
namespace {
32+
std::atomic<int> uniqueLitId = 1;
33+
34+
static bool needsExtrusion(const mlir::Value *a) {
35+
if (!a || !a->getDefiningOp())
36+
return false;
37+
38+
// is alloca
39+
if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
40+
// alloca has annotation
41+
if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
42+
for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
43+
if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
44+
auto constant_def = store->getOperand(0).getDefiningOp();
45+
// Expect constant definition operation
46+
if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
47+
return true;
48+
}
49+
}
50+
}
51+
}
52+
}
53+
return false;
54+
}
55+
56+
class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
57+
protected:
58+
mlir::DominanceInfo &di;
59+
60+
public:
61+
using OpRewritePattern::OpRewritePattern;
62+
63+
CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
64+
: OpRewritePattern(ctx), di(_di) {}
65+
66+
mlir::LogicalResult
67+
matchAndRewrite(fir::CallOp callOp,
68+
mlir::PatternRewriter &rewriter) const override {
69+
LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
70+
auto module = callOp->getParentOfType<mlir::ModuleOp>();
71+
fir::FirOpBuilder builder(rewriter, module);
72+
llvm::SmallVector<mlir::Value> newOperands;
73+
llvm::SmallVector<mlir::Operation *> toErase;
74+
for (const auto &a : callOp.getArgs()) {
75+
if (auto alloca =
76+
mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
77+
if (needsExtrusion(&a)) {
78+
79+
mlir::Type varTy = alloca.getInType();
80+
assert(!fir::hasDynamicSize(varTy) &&
81+
"only expect statically sized scalars to be by value");
82+
83+
// find immediate store with const argument
84+
llvm::SmallVector<mlir::Operation *> stores;
85+
for (mlir::Operation *s : alloca.getOperation()->getUsers())
86+
if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
87+
stores.push_back(s);
88+
assert(stores.size() == 1 && "expected exactly one store");
89+
LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
90+
91+
auto constant_def = stores[0]->getOperand(0).getDefiningOp();
92+
// Expect constant definition operation or force legalisation of the
93+
// callOp and continue with its next argument
94+
if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
95+
// unable to remove alloca arg
96+
newOperands.push_back(a);
97+
continue;
98+
}
99+
100+
LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
101+
102+
auto loc = callOp.getLoc();
103+
llvm::StringRef globalPrefix = "_extruded_";
104+
105+
std::string globalName;
106+
while (!globalName.length() || builder.getNamedGlobal(globalName))
107+
globalName =
108+
globalPrefix.str() + "." + std::to_string(uniqueLitId++);
109+
110+
if (alloca->hasOneUse()) {
111+
toErase.push_back(alloca);
112+
toErase.push_back(stores[0]);
113+
} else {
114+
int count = -2;
115+
for (mlir::Operation *s : alloca.getOperation()->getUsers())
116+
if (di.dominates(stores[0], s))
117+
++count;
118+
119+
// delete if dominates itself and one more operation (which should
120+
// be callOp)
121+
if (!count)
122+
toErase.push_back(stores[0]);
123+
}
124+
auto global = builder.createGlobalConstant(
125+
loc, varTy, globalName,
126+
[&](fir::FirOpBuilder &builder) {
127+
mlir::Operation *cln = constant_def->clone();
128+
builder.insert(cln);
129+
fir::ExtendedValue exv{cln->getResult(0)};
130+
mlir::Value valBase = fir::getBase(exv);
131+
mlir::Value val = builder.createConvert(loc, varTy, valBase);
132+
builder.create<fir::HasValueOp>(loc, val);
133+
},
134+
builder.createInternalLinkage());
135+
mlir::Value ope = {builder.create<fir::AddrOfOp>(
136+
loc, global.resultType(), global.getSymbol())};
137+
newOperands.push_back(ope);
138+
} else {
139+
// alloca but without attr, add it
140+
newOperands.push_back(a);
141+
}
142+
} else {
143+
// non-alloca operand, add it
144+
newOperands.push_back(a);
145+
}
146+
}
147+
148+
auto loc = callOp.getLoc();
149+
llvm::SmallVector<mlir::Type> newResultTypes;
150+
newResultTypes.append(callOp.getResultTypes().begin(),
151+
callOp.getResultTypes().end());
152+
fir::CallOp newOp = builder.create<fir::CallOp>(
153+
loc, newResultTypes,
154+
callOp.getCallee().has_value() ? callOp.getCallee().value()
155+
: mlir::SymbolRefAttr{},
156+
newOperands, callOp.getFastmathAttr());
157+
rewriter.replaceOp(callOp, newOp);
158+
159+
for (auto e : toErase)
160+
rewriter.eraseOp(e);
161+
162+
LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
163+
<< newOp << '\n');
164+
return mlir::success();
165+
}
166+
};
167+
168+
// This pass attempts to convert immediate scalar literals in function calls
169+
// to global constants to allow transformations as Dead Argument Elimination
170+
class ConstExtruderOpt
171+
: public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
172+
protected:
173+
mlir::DominanceInfo *di;
174+
175+
public:
176+
ConstExtruderOpt() {}
177+
178+
void runOnOperation() override {
179+
mlir::ModuleOp mod = getOperation();
180+
di = &getAnalysis<mlir::DominanceInfo>();
181+
mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
182+
}
183+
184+
void runOnFunc(mlir::func::FuncOp &func) {
185+
auto *context = &getContext();
186+
mlir::RewritePatternSet patterns(context);
187+
mlir::ConversionTarget target(*context);
188+
189+
// If func is a declaration, skip it.
190+
if (func.empty())
191+
return;
192+
193+
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
194+
mlir::func::FuncDialect>();
195+
target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
196+
for (auto a : op.getArgs()) {
197+
if (needsExtrusion(&a))
198+
return false;
199+
}
200+
return true;
201+
});
202+
203+
patterns.insert<CallOpRewriter>(context, *di);
204+
if (mlir::failed(
205+
mlir::applyPartialConversion(func, target, std::move(patterns)))) {
206+
mlir::emitError(func.getLoc(),
207+
"error in constant extrusion optimization\n");
208+
signalPassFailure();
209+
}
210+
}
211+
};
212+
} // namespace
213+
214+
std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
215+
return std::make_unique<ConstExtruderOpt>();
216+
}

flang/test/Driver/bbc-mlir-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
! CHECK-NEXT: 'func.func' Pipeline
3333
! CHECK-NEXT: MemoryAllocationOpt
34+
! CHECK-NEXT: ConstExtruderOpt
3435

3536
! CHECK-NEXT: Inliner
3637
! CHECK-NEXT: SimplifyRegionLite

flang/test/Driver/mlir-debug-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
! ALL-NEXT: 'func.func' Pipeline
5353
! ALL-NEXT: MemoryAllocationOpt
54+
! ALL-NEXT: ConstExtruderOpt
5455

5556
! ALL-NEXT: Inliner
5657
! ALL-NEXT: SimplifyRegionLite

flang/test/Driver/mlir-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
! ALL-NEXT: 'func.func' Pipeline
4444
! ALL-NEXT: MemoryAllocationOpt
45+
! ALL-NEXT: ConstExtruderOpt
4546

4647
! ALL-NEXT: Inliner
4748
! ALL-NEXT: SimplifyRegionLite

flang/test/Fir/basic-program.fir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func.func @_QQmain() {
4848

4949
// PASSES-NEXT: 'func.func' Pipeline
5050
// PASSES-NEXT: MemoryAllocationOpt
51+
// PASSES-NEXT: ConstExtruderOpt
5152

5253
// PASSES-NEXT: Inliner
5354
// PASSES-NEXT: SimplifyRegionLite

flang/test/Fir/boxproc.fir

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616

1717
// CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
1818
// CHECK-SAME: %[[VAL_0:.*]])
19-
// CHECK: %[[VAL_1:.*]] = alloca i32, i64 1, align 4
20-
// CHECK: store i32 4, ptr %[[VAL_1]], align 4
21-
// CHECK: call void %[[VAL_0]](ptr %[[VAL_1]])
19+
// CHECK: call void %[[VAL_0]](ptr @{{.*}})
2220

2321
func.func @_QPtest_proc_dummy() {
2422
%c0_i32 = arith.constant 0 : i32

flang/test/Lower/character-local-variables.f90

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,7 @@ subroutine dyn_array_dyn_len_lb(l, n)
116116
subroutine assumed_length_param(n)
117117
character(*), parameter :: c(1)=(/"abcd"/)
118118
integer :: n
119-
! CHECK: %[[c4:.*]] = arith.constant 4 : i64
120-
! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
119+
! CHECK: %[[tmp:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i64>
121120
! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
122121
call take_int(len(c(n), kind=8))
123122
end

flang/test/Lower/dummy-arguments.f90

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22

33
! CHECK-LABEL: _QQmain
44
program test1
5-
! CHECK-DAG: %[[TMP:.*]] = fir.alloca
6-
! CHECK-DAG: %[[TEN:.*]] = arith.constant
7-
! CHECK: fir.store %[[TEN]] to %[[TMP]]
5+
! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
86
! CHECK-NEXT: fir.call @_QFPfoo
97
call foo(10)
108
contains

flang/test/Lower/host-associated.f90

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -448,11 +448,10 @@ subroutine bar()
448448

449449
! CHECK-LABEL: func @_QPtest_proc_dummy_other(
450450
! CHECK-SAME: %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
451-
! CHECK: %[[VAL_1:.*]] = arith.constant 4 : i32
452-
! CHECK: %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref}
453-
! CHECK: fir.store %[[VAL_1]] to %[[VAL_2]] : !fir.ref<i32>
454451
! CHECK: %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
455-
! CHECK: fir.call %[[VAL_3]](%[[VAL_2]]) {{.*}}: (!fir.ref<i32>) -> ()
452+
! CHECK: %[[VAL_1:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
453+
! CHECK: fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
454+
456455
! CHECK: return
457456
! CHECK: }
458457

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
2+
3+
subroutine sub1(x,y)
4+
implicit none
5+
integer x, y
6+
7+
call sub2(0.0d0, 1.0d0, x, y, 1)
8+
end subroutine sub1
9+
10+
!CHECK-LABEL: func.func @_QPsub1
11+
!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
12+
!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
13+
!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
14+
!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
15+
!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
16+
!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
17+
!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
18+
!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
19+
!CHECK: return
20+
21+
!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
22+
!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
23+
!CHECK: fir.has_value %{{.*}} : f64
24+
!CHECK: }
25+
!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
26+
!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
27+
!CHECK: fir.has_value %{{.*}} : f64
28+
!CHECK: }
29+
!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
30+
!CHECK: %{{.*}} = arith.constant 1 : i32
31+
!CHECK: fir.has_value %{{.*}} : i32
32+
!CHECK: }

0 commit comments

Comments
 (0)