Skip to content

Commit 8d05dfa

Browse files
Fix review comments
1 parent 037b6f5 commit 8d05dfa

File tree

3 files changed

+111
-131
lines changed

3 files changed

+111
-131
lines changed

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
244244

245245
// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
246246
def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
247-
let summary = "Convert scalar literals of function arguments to global constants.";
247+
let summary = "Convert constant function arguments to global constants.";
248248
let description = [{
249249
Convert scalar literals of function arguments to global constants.
250250
}];

flang/lib/Optimizer/Transforms/ConstExtruder.cpp

Lines changed: 109 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
//===- ConstExtruder.cpp -----------------------------------------------===//
1+
//===- ConstExtruder.cpp --------------------------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "flang/Optimizer/Builder/BoxValue.h"
109
#include "flang/Optimizer/Builder/FIRBuilder.h"
1110
#include "flang/Optimizer/Dialect/FIRDialect.h"
1211
#include "flang/Optimizer/Dialect/FIROps.h"
@@ -17,9 +16,6 @@
1716
#include "mlir/IR/Dominance.h"
1817
#include "mlir/Pass/Pass.h"
1918
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20-
#include "mlir/Transforms/Passes.h"
21-
#include "llvm/ADT/TypeSwitch.h"
22-
#include <atomic>
2319

2420
namespace fir {
2521
#define GEN_PASS_DEF_CONSTEXTRUDEROPT
@@ -29,170 +25,154 @@ namespace fir {
2925
#define DEBUG_TYPE "flang-const-extruder-opt"
3026

3127
namespace {
32-
std::atomic<int> uniqueLitId = 1;
33-
34-
static bool needsExtrusion(const mlir::Value *a) {
35-
if (!a || !a->getDefiningOp())
36-
return false;
37-
38-
// is alloca
39-
if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
40-
// alloca has annotation
41-
if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
42-
for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
43-
if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
44-
auto constant_def = store->getOperand(0).getDefiningOp();
45-
// Expect constant definition operation
46-
if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
47-
return true;
48-
}
49-
}
50-
}
51-
}
52-
}
53-
return false;
54-
}
28+
unsigned uniqueLitId = 1;
5529

5630
class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
5731
protected:
58-
mlir::DominanceInfo &di;
32+
const mlir::DominanceInfo &di;
5933

6034
public:
6135
using OpRewritePattern::OpRewritePattern;
6236

63-
CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
37+
CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
6438
: OpRewritePattern(ctx), di(_di) {}
6539

6640
mlir::LogicalResult
6741
matchAndRewrite(fir::CallOp callOp,
6842
mlir::PatternRewriter &rewriter) const override {
6943
LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
7044
auto module = callOp->getParentOfType<mlir::ModuleOp>();
45+
bool needUpdate = false;
7146
fir::FirOpBuilder builder(rewriter, module);
7247
llvm::SmallVector<mlir::Value> newOperands;
7348
llvm::SmallVector<mlir::Operation *> toErase;
74-
for (const auto &a : callOp.getArgs()) {
75-
if (auto alloca =
76-
mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
77-
if (needsExtrusion(&a)) {
78-
79-
mlir::Type varTy = alloca.getInType();
80-
assert(!fir::hasDynamicSize(varTy) &&
81-
"only expect statically sized scalars to be by value");
82-
83-
// find immediate store with const argument
84-
llvm::SmallVector<mlir::Operation *> stores;
85-
for (mlir::Operation *s : alloca.getOperation()->getUsers())
86-
if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
87-
stores.push_back(s);
88-
assert(stores.size() == 1 && "expected exactly one store");
89-
LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
90-
91-
auto constant_def = stores[0]->getOperand(0).getDefiningOp();
92-
// Expect constant definition operation or force legalisation of the
93-
// callOp and continue with its next argument
94-
if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
95-
// unable to remove alloca arg
96-
newOperands.push_back(a);
97-
continue;
98-
}
49+
for (const mlir::Value &a : callOp.getArgs()) {
50+
auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
51+
// We can convert arguments that are alloca, and that has
52+
// the value by reference attribute. All else is just added
53+
// to the argument list.
54+
if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
55+
newOperands.push_back(a);
56+
continue;
57+
}
9958

100-
LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
101-
102-
auto loc = callOp.getLoc();
103-
llvm::StringRef globalPrefix = "_extruded_";
104-
105-
std::string globalName;
106-
while (!globalName.length() || builder.getNamedGlobal(globalName))
107-
globalName =
108-
globalPrefix.str() + "." + std::to_string(uniqueLitId++);
109-
110-
if (alloca->hasOneUse()) {
111-
toErase.push_back(alloca);
112-
toErase.push_back(stores[0]);
113-
} else {
114-
int count = -2;
115-
for (mlir::Operation *s : alloca.getOperation()->getUsers())
116-
if (di.dominates(stores[0], s))
117-
++count;
118-
119-
// delete if dominates itself and one more operation (which should
120-
// be callOp)
121-
if (!count)
122-
toErase.push_back(stores[0]);
59+
mlir::Type varTy = alloca.getInType();
60+
assert(!fir::hasDynamicSize(varTy) &&
61+
"only expect statically sized scalars to be by value");
62+
63+
// Find immediate store with const argument
64+
mlir::Operation *store = nullptr;
65+
for (mlir::Operation *s : alloca->getUsers()) {
66+
if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
67+
// We can only deal with ONE store - if already found one,
68+
// set to nullptr and exit the loop.
69+
if (store) {
70+
store = nullptr;
71+
break;
12372
}
124-
auto global = builder.createGlobalConstant(
125-
loc, varTy, globalName,
126-
[&](fir::FirOpBuilder &builder) {
127-
mlir::Operation *cln = constant_def->clone();
128-
builder.insert(cln);
129-
fir::ExtendedValue exv{cln->getResult(0)};
130-
mlir::Value valBase = fir::getBase(exv);
131-
mlir::Value val = builder.createConvert(loc, varTy, valBase);
132-
builder.create<fir::HasValueOp>(loc, val);
133-
},
134-
builder.createInternalLinkage());
135-
mlir::Value ope = {builder.create<fir::AddrOfOp>(
136-
loc, global.resultType(), global.getSymbol())};
137-
newOperands.push_back(ope);
138-
} else {
139-
// alloca but without attr, add it
140-
newOperands.push_back(a);
73+
store = s;
14174
}
142-
} else {
143-
// non-alloca operand, add it
75+
}
76+
77+
// If we didn't find one signle store, add argument as is, and move on.
78+
if (!store) {
79+
newOperands.push_back(a);
80+
continue;
81+
}
82+
83+
LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
84+
85+
mlir::Operation *constant_def = store->getOperand(0).getDefiningOp();
86+
// Expect constant definition operation or force legalisation of the
87+
// callOp and continue with its next argument
88+
if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
89+
// Unable to remove alloca arg
14490
newOperands.push_back(a);
91+
continue;
14592
}
93+
94+
LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
95+
96+
std::string globalName = "_extruded_." + std::to_string(uniqueLitId++);
97+
assert(!builder.getNamedGlobal(globalName) &&
98+
"We should have a unique name here");
99+
100+
unsigned count = 0;
101+
for (mlir::Operation *s : alloca->getUsers())
102+
if (di.dominates(store, s))
103+
++count;
104+
105+
// Delete if dominates itself and one more operation (which should
106+
// be callOp)
107+
if (count == 2)
108+
toErase.push_back(store);
109+
110+
auto loc = callOp.getLoc();
111+
fir::GlobalOp global = builder.createGlobalConstant(
112+
loc, varTy, globalName,
113+
[&](fir::FirOpBuilder &builder) {
114+
mlir::Operation *cln = constant_def->clone();
115+
builder.insert(cln);
116+
mlir::Value val =
117+
builder.createConvert(loc, varTy, cln->getResult(0));
118+
builder.create<fir::HasValueOp>(loc, val);
119+
},
120+
builder.createInternalLinkage());
121+
mlir::Value addr = {builder.create<fir::AddrOfOp>(
122+
loc, global.resultType(), global.getSymbol())};
123+
newOperands.push_back(addr);
124+
needUpdate = true;
146125
}
147126

148-
auto loc = callOp.getLoc();
149-
llvm::SmallVector<mlir::Type> newResultTypes;
150-
newResultTypes.append(callOp.getResultTypes().begin(),
151-
callOp.getResultTypes().end());
152-
fir::CallOp newOp = builder.create<fir::CallOp>(
153-
loc, newResultTypes,
154-
callOp.getCallee().has_value() ? callOp.getCallee().value()
155-
: mlir::SymbolRefAttr{},
156-
newOperands, callOp.getFastmathAttr());
157-
rewriter.replaceOp(callOp, newOp);
158-
159-
for (auto e : toErase)
160-
rewriter.eraseOp(e);
161-
162-
LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
163-
<< newOp << '\n');
164-
return mlir::success();
127+
if (needUpdate) {
128+
auto loc = callOp.getLoc();
129+
llvm::SmallVector<mlir::Type> newResultTypes;
130+
newResultTypes.append(callOp.getResultTypes().begin(),
131+
callOp.getResultTypes().end());
132+
fir::CallOp newOp = builder.create<fir::CallOp>(
133+
loc, newResultTypes,
134+
callOp.getCallee().has_value() ? callOp.getCallee().value()
135+
: mlir::SymbolRefAttr{},
136+
newOperands, callOp.getFastmathAttr());
137+
rewriter.replaceOp(callOp, newOp);
138+
139+
for (auto e : toErase)
140+
rewriter.eraseOp(e);
141+
LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
142+
<< newOp << '\n');
143+
return mlir::success();
144+
}
145+
146+
// Failure here just means "we couldn't do the conversion", which is
147+
// perfectly acceptable to the upper layers of this function.
148+
return mlir::failure();
165149
}
166150
};
167151

168-
// This pass attempts to convert immediate scalar literals in function calls
152+
// this pass attempts to convert immediate scalar literals in function calls
169153
// to global constants to allow transformations as Dead Argument Elimination
170154
class ConstExtruderOpt
171155
: public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
172-
protected:
173-
mlir::DominanceInfo *di;
174-
mlir::GreedyRewriteConfig config;
175-
176156
public:
177-
ConstExtruderOpt() {
178-
config.enableRegionSimplification = false;
179-
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
180-
}
157+
ConstExtruderOpt() = default;
181158

182159
void runOnOperation() override {
183160
mlir::ModuleOp mod = getOperation();
184-
di = &getAnalysis<mlir::DominanceInfo>();
185-
mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
161+
mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
162+
mod.walk([di, this](mlir::func::FuncOp func) { runOnFunc(func, di); });
186163
}
187164

188-
void runOnFunc(mlir::func::FuncOp &func) {
189-
auto *context = &getContext();
190-
mlir::RewritePatternSet patterns(context);
191-
165+
void runOnFunc(mlir::func::FuncOp &func, const mlir::DominanceInfo *di) {
192166
// If func is a declaration, skip it.
193167
if (func.empty())
194168
return;
195169

170+
auto *context = &getContext();
171+
mlir::RewritePatternSet patterns(context);
172+
mlir::GreedyRewriteConfig config;
173+
config.enableRegionSimplification = false;
174+
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
175+
196176
patterns.insert<CallOpRewriter>(context, *di);
197177
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
198178
func, std::move(patterns), config))) {

flang/test/Driver/bbc-mlir-pass-pipeline.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131

3232
! CHECK-NEXT: 'func.func' Pipeline
3333
! CHECK-NEXT: MemoryAllocationOpt
34-
! CHECK-NEXT: ConstExtruderOpt
3534

35+
! CHECK-NEXT: ConstExtruderOpt
3636
! CHECK-NEXT: Inliner
3737
! CHECK-NEXT: SimplifyRegionLite
3838
! CHECK-NEXT: CSE

0 commit comments

Comments
 (0)