Skip to content

Commit 50d6b74

Browse files
Leporacanthicusd-smirnov
authored andcommitted
[Flang] Extracting internal constants from scalar literals (llvm#73829)
Constants actual arguments in function/subroutine calls are currently lowered as allocas + store. This can sometimes inhibit LTO and the constant will not be propagated to the called function. Particularly in cases where the function/subroutine call happens inside a condition. This patch changes the lowering of these constant actual arguments to a global constant + fir.address_of_op. This lowering makes it easier for LTO to propagate the constant. The optimization must be enabled explicitly to run. Use -mmlir --enable-constant-argument-globalisation to enable. --------- Co-authored-by: Dmitriy Smirnov <[email protected]>
1 parent 86e28c3 commit 50d6b74

File tree

9 files changed

+375
-4
lines changed

9 files changed

+375
-4
lines changed

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ namespace fir {
5757
#define GEN_PASS_DECL_OMPFUNCTIONFILTERING
5858
#define GEN_PASS_DECL_VSCALEATTR
5959
#define GEN_PASS_DECL_FUNCTIONATTR
60+
#define GEN_PASS_DECL_CONSTANTARGUMENTGLOBALISATIONOPT
61+
6062
#include "flang/Optimizer/Transforms/Passes.h.inc"
6163

6264
std::unique_ptr<mlir::Pass> createAffineDemotionPass();

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,15 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
251251
];
252252
}
253253

254+
// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
255+
def ConstantArgumentGlobalisationOpt : Pass<"constant-argument-globalisation-opt", "mlir::ModuleOp"> {
256+
let summary = "Convert constant function arguments to global constants.";
257+
let description = [{
258+
Convert scalar literals of function arguments to global constants.
259+
}];
260+
let dependentDialects = [ "fir::FIROpsDialect" ];
261+
}
262+
254263
def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
255264
let summary = "Move local array allocations from heap memory into stack memory";
256265
let description = [{

flang/include/flang/Tools/CLOptions.inc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
2626
llvm::cl::desc("disable " DODescription " pass"), llvm::cl::init(false), \
2727
llvm::cl::Hidden)
28+
#define EnableOption(EOName, EOOption, EODescription) \
29+
static llvm::cl::opt<bool> enable##EOName("enable-" EOOption, \
30+
llvm::cl::desc("enable " EODescription " pass"), llvm::cl::init(false), \
31+
llvm::cl::Hidden)
2832

2933
/// Shared option in tools to control whether dynamically sized array
3034
/// allocations should always be on the heap.
@@ -86,6 +90,8 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
8690

8791
DisableOption(ExternalNameConversion, "external-name-interop",
8892
"convert names with external convention");
93+
EnableOption(ConstantArgumentGlobalisation, "constant-argument-globalisation",
94+
"the local constant argument to global constant conversion");
8995

9096
using PassConstructor = std::unique_ptr<mlir::Pass>();
9197

@@ -270,6 +276,8 @@ inline void createDefaultFIROptimizerPassPipeline(
270276
// These passes may increase code size.
271277
pm.addPass(fir::createSimplifyIntrinsics());
272278
pm.addPass(fir::createAlgebraicSimplificationPass(config));
279+
if (enableConstantArgumentGlobalisation)
280+
pm.addPass(fir::createConstantArgumentGlobalisationOpt());
273281
}
274282

275283
if (pc.LoopVersioning)

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_flang_library(FIRTransforms
66
AnnotateConstant.cpp
77
AssumedRankOpConversion.cpp
88
CharacterConversion.cpp
9+
ConstantArgumentGlobalisation.cpp
910
ControlFlowConverter.cpp
1011
ArrayValueCopy.cpp
1112
ExternalNameConversion.cpp
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
//===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Optimizer/Builder/FIRBuilder.h"
10+
#include "flang/Optimizer/Dialect/FIRDialect.h"
11+
#include "flang/Optimizer/Dialect/FIROps.h"
12+
#include "flang/Optimizer/Dialect/FIRType.h"
13+
#include "flang/Optimizer/Transforms/Passes.h"
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/IR/Diagnostics.h"
16+
#include "mlir/IR/Dominance.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
20+
namespace fir {
21+
#define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
22+
#include "flang/Optimizer/Transforms/Passes.h.inc"
23+
} // namespace fir
24+
25+
#define DEBUG_TYPE "flang-constant-argument-globalisation-opt"
26+
27+
namespace {
28+
unsigned uniqueLitId = 1;
29+
30+
class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
31+
protected:
32+
const mlir::DominanceInfo &di;
33+
34+
public:
35+
using OpRewritePattern::OpRewritePattern;
36+
37+
CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
38+
: OpRewritePattern(ctx), di(_di) {}
39+
40+
mlir::LogicalResult
41+
matchAndRewrite(fir::CallOp callOp,
42+
mlir::PatternRewriter &rewriter) const override {
43+
LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
44+
auto module = callOp->getParentOfType<mlir::ModuleOp>();
45+
bool needUpdate = false;
46+
fir::FirOpBuilder builder(rewriter, module);
47+
llvm::SmallVector<mlir::Value> newOperands;
48+
llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas;
49+
for (const mlir::Value &a : callOp.getArgs()) {
50+
auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
51+
// We can convert arguments that are alloca, and that has
52+
// the value by reference attribute. All else is just added
53+
// to the argument list.
54+
if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
55+
newOperands.push_back(a);
56+
continue;
57+
}
58+
59+
mlir::Type varTy = alloca.getInType();
60+
assert(!fir::hasDynamicSize(varTy) &&
61+
"only expect statically sized scalars to be by value");
62+
63+
// Find immediate store with const argument
64+
mlir::Operation *store = nullptr;
65+
for (mlir::Operation *s : alloca->getUsers()) {
66+
if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
67+
// We can only deal with ONE store - if already found one,
68+
// set to nullptr and exit the loop.
69+
if (store) {
70+
store = nullptr;
71+
break;
72+
}
73+
store = s;
74+
}
75+
}
76+
77+
// If we didn't find any store, or multiple stores, add argument as is
78+
// and move on.
79+
if (!store) {
80+
newOperands.push_back(a);
81+
continue;
82+
}
83+
84+
LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
85+
86+
mlir::Operation *definingOp = store->getOperand(0).getDefiningOp();
87+
// If not a constant, add to operands and move on.
88+
if (!mlir::isa<mlir::arith::ConstantOp>(definingOp)) {
89+
// Unable to remove alloca arg
90+
newOperands.push_back(a);
91+
continue;
92+
}
93+
94+
LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n");
95+
96+
std::string globalName =
97+
"_global_const_." + std::to_string(uniqueLitId++);
98+
assert(!builder.getNamedGlobal(globalName) &&
99+
"We should have a unique name here");
100+
101+
if (std::find_if(allocas.begin(), allocas.end(), [alloca](auto x) {
102+
return x.first == alloca;
103+
}) == allocas.end()) {
104+
allocas.push_back(std::make_pair(alloca, store));
105+
}
106+
107+
auto loc = callOp.getLoc();
108+
fir::GlobalOp global = builder.createGlobalConstant(
109+
loc, varTy, globalName,
110+
[&](fir::FirOpBuilder &builder) {
111+
mlir::Operation *cln = definingOp->clone();
112+
builder.insert(cln);
113+
mlir::Value val =
114+
builder.createConvert(loc, varTy, cln->getResult(0));
115+
builder.create<fir::HasValueOp>(loc, val);
116+
},
117+
builder.createInternalLinkage());
118+
mlir::Value addr = builder.create<fir::AddrOfOp>(loc, global.resultType(),
119+
global.getSymbol());
120+
newOperands.push_back(addr);
121+
needUpdate = true;
122+
}
123+
124+
if (needUpdate) {
125+
auto loc = callOp.getLoc();
126+
llvm::SmallVector<mlir::Type> newResultTypes;
127+
newResultTypes.append(callOp.getResultTypes().begin(),
128+
callOp.getResultTypes().end());
129+
fir::CallOp newOp = builder.create<fir::CallOp>(
130+
loc, newResultTypes,
131+
callOp.getCallee().has_value() ? callOp.getCallee().value()
132+
: mlir::SymbolRefAttr{},
133+
newOperands);
134+
// Copy all the attributes from the old to new op.
135+
newOp->setAttrs(callOp->getAttrs());
136+
rewriter.replaceOp(callOp, newOp);
137+
138+
for (auto a : allocas) {
139+
if (a.first->hasOneUse()) {
140+
// If the alloca is only used for a store and the call operand, the
141+
// store is no longer required.
142+
rewriter.eraseOp(a.second);
143+
rewriter.eraseOp(a.first);
144+
}
145+
}
146+
LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
147+
<< newOp << '\n');
148+
return mlir::success();
149+
}
150+
151+
// Failure here just means "we couldn't do the conversion", which is
152+
// perfectly acceptable to the upper layers of this function.
153+
return mlir::failure();
154+
}
155+
};
156+
157+
// this pass attempts to convert immediate scalar literals in function calls
158+
// to global constants to allow transformations such as Dead Argument
159+
// Elimination
160+
class ConstantArgumentGlobalisationOpt
161+
: public fir::impl::ConstantArgumentGlobalisationOptBase<
162+
ConstantArgumentGlobalisationOpt> {
163+
public:
164+
ConstantArgumentGlobalisationOpt() = default;
165+
166+
void runOnOperation() override {
167+
mlir::ModuleOp mod = getOperation();
168+
mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
169+
auto *context = &getContext();
170+
mlir::RewritePatternSet patterns(context);
171+
mlir::GreedyRewriteConfig config;
172+
config.enableRegionSimplification =
173+
mlir::GreedySimplifyRegionLevel::Disabled;
174+
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
175+
176+
patterns.insert<CallOpRewriter>(context, *di);
177+
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
178+
mod, std::move(patterns), config))) {
179+
mlir::emitError(mod.getLoc(),
180+
"error in constant globalisation optimization\n");
181+
signalPassFailure();
182+
}
183+
}
184+
};
185+
} // namespace

flang/test/Fir/boxproc.fir

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616

1717
// CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
1818
// CHECK-SAME: %[[VAL_0:.*]])
19-
// CHECK: %[[VAL_1:.*]] = alloca i32, i64 1, align 4
20-
// CHECK: store i32 4, ptr %[[VAL_1]], align 4
21-
// CHECK: call void %[[VAL_0]](ptr %[[VAL_1]])
19+
// CHECK: call void %[[VAL_0]](ptr %{{.*}})
2220

2321
func.func @_QPtest_proc_dummy() {
2422
%c0_i32 = arith.constant 0 : i32

flang/test/Lower/character-local-variables.f90

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
! RUN: bbc -hlfir=false %s -o - | FileCheck %s
2+
! RUN: bbc -hlfir=false --enable-constant-argument-globalisation %s -o - \
3+
! RUN: | FileCheck %s --check-prefix=CHECK-CONST
24

35
! Test lowering of local character variables
46

@@ -118,7 +120,8 @@ subroutine assumed_length_param(n)
118120
integer :: n
119121
! CHECK: %[[c4:.*]] = arith.constant 4 : i64
120122
! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
121-
! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
123+
! CHECK-CONST: %[[tmp:.*]] = fir.address_of(@_global_const_.{{.*}}) : !fir.ref<i64>
124+
! CHECK-CONST: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
122125
call take_int(len(c(n), kind=8))
123126
end
124127

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// RUN: fir-opt --split-input-file --constant-argument-globalisation-opt < %s | FileCheck %s
2+
3+
module {
4+
// Test for "two conditional writes to the same alloca doesn't get replaced."
5+
func.func @func(%arg0: i32, %arg1: i1) {
6+
%c2_i32 = arith.constant 2 : i32
7+
%addr = fir.alloca i32 {adapt.valuebyref}
8+
fir.if %arg1 {
9+
fir.store %c2_i32 to %addr : !fir.ref<i32>
10+
} else {
11+
fir.store %arg0 to %addr : !fir.ref<i32>
12+
}
13+
fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
14+
return
15+
}
16+
func.func private @sub2(!fir.ref<i32>)
17+
18+
// CHECK-LABEL: func.func @func
19+
// CHECK-SAME: [[ARG0:%.*]]: i32
20+
// CHECK-SAME: [[ARG1:%.*]]: i1)
21+
// CHECK: [[CONST:%.*]] = arith.constant
22+
// CHECK: [[ADDR:%.*]] = fir.alloca i32
23+
// CHECK: fir.if [[ARG1]]
24+
// CHECK: fir.store [[CONST]] to [[ADDR]]
25+
// CHECK: } else {
26+
// CHECK: fir.store [[ARG0]] to [[ADDR]]
27+
// CHECK: fir.call @sub2([[ADDR]])
28+
// CHECK: return
29+
30+
}
31+
32+
// -----
33+
34+
module {
35+
// Test for "two writes to the same alloca doesn't get replaced."
36+
func.func @func() {
37+
%c1_i32 = arith.constant 1 : i32
38+
%c2_i32 = arith.constant 2 : i32
39+
%addr = fir.alloca i32 {adapt.valuebyref}
40+
fir.store %c1_i32 to %addr : !fir.ref<i32>
41+
fir.store %c2_i32 to %addr : !fir.ref<i32>
42+
fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
43+
return
44+
}
45+
func.func private @sub2(!fir.ref<i32>)
46+
47+
// CHECK-LABEL: func.func @func
48+
// CHECK: [[CONST1:%.*]] = arith.constant
49+
// CHECK: [[CONST2:%.*]] = arith.constant
50+
// CHECK: [[ADDR:%.*]] = fir.alloca i32
51+
// CHECK: fir.store [[CONST1]] to [[ADDR]]
52+
// CHECK: fir.store [[CONST2]] to [[ADDR]]
53+
// CHECK: fir.call @sub2([[ADDR]])
54+
// CHECK: return
55+
56+
}
57+
58+
// -----
59+
60+
module {
61+
// Test for "one write to the the alloca gets replaced."
62+
func.func @func() {
63+
%c1_i32 = arith.constant 1 : i32
64+
%addr = fir.alloca i32 {adapt.valuebyref}
65+
fir.store %c1_i32 to %addr : !fir.ref<i32>
66+
fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
67+
return
68+
}
69+
func.func private @sub2(!fir.ref<i32>)
70+
71+
// CHECK-LABEL: func.func @func
72+
// CHECK: [[ADDR:%.*]] = fir.address_of([[EXTR:@.*]]) : !fir.ref<i32>
73+
// CHECK: fir.call @sub2([[ADDR]])
74+
// CHECK: return
75+
// CHECK: fir.global internal [[EXTR]] constant : i32 {
76+
// CHECK: %{{.*}} = arith.constant 1 : i32
77+
// CHECK: fir.has_value %{{.*}} : i32
78+
// CHECK: }
79+
80+
}
81+
82+
// -----
83+
// Check that same argument used twice is converted.
84+
module {
85+
func.func @func(%arg0: !fir.ref<i32>, %arg1: i1) {
86+
%c2_i32 = arith.constant 2 : i32
87+
%addr1 = fir.alloca i32 {adapt.valuebyref}
88+
fir.store %c2_i32 to %addr1 : !fir.ref<i32>
89+
fir.call @sub1(%addr1, %addr1) : (!fir.ref<i32>, !fir.ref<i32>) -> ()
90+
return
91+
}
92+
}
93+
94+
// CHECK-LABEL: func.func @func
95+
// CHECK-NEXT: %[[ARG1:.*]] = fir.address_of([[CONST1:@.*]]) : !fir.ref<i32>
96+
// CHECK-NEXT: %[[ARG2:.*]] = fir.address_of([[CONST2:@.*]]) : !fir.ref<i32>
97+
// CHECK-NEXT: fir.call @sub1(%[[ARG1]], %[[ARG2]])
98+
// CHECK-NEXT: return

0 commit comments

Comments
 (0)