1
- // ===- ConstExtruder.cpp -----------------------------------------------===//
1
+ // ===- ConstExtruder.cpp -------------------------------------------------- ===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
- #include " flang/Optimizer/Builder/BoxValue.h"
10
9
#include " flang/Optimizer/Builder/FIRBuilder.h"
11
10
#include " flang/Optimizer/Dialect/FIRDialect.h"
12
11
#include " flang/Optimizer/Dialect/FIROps.h"
17
16
#include " mlir/IR/Dominance.h"
18
17
#include " mlir/Pass/Pass.h"
19
18
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
20
- #include " mlir/Transforms/Passes.h"
21
- #include " llvm/ADT/TypeSwitch.h"
22
- #include < atomic>
23
19
24
20
namespace fir {
25
21
#define GEN_PASS_DEF_CONSTEXTRUDEROPT
@@ -29,170 +25,154 @@ namespace fir {
29
25
#define DEBUG_TYPE " flang-const-extruder-opt"
30
26
31
27
namespace {
32
- std::atomic<int > uniqueLitId = 1 ;
33
-
34
- static bool needsExtrusion (const mlir::Value *a) {
35
- if (!a || !a->getDefiningOp ())
36
- return false ;
37
-
38
- // is alloca
39
- if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp ())) {
40
- // alloca has annotation
41
- if (alloca->hasAttr (fir::getAdaptToByRefAttrName ())) {
42
- for (mlir::Operation *s : alloca.getOperation ()->getUsers ()) {
43
- if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
44
- auto constant_def = store->getOperand (0 ).getDefiningOp ();
45
- // Expect constant definition operation
46
- if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
47
- return true ;
48
- }
49
- }
50
- }
51
- }
52
- }
53
- return false ;
54
- }
28
+ unsigned uniqueLitId = 1 ;
55
29
56
30
class CallOpRewriter : public mlir ::OpRewritePattern<fir::CallOp> {
57
31
protected:
58
- mlir::DominanceInfo &di;
32
+ const mlir::DominanceInfo &di;
59
33
60
34
public:
61
35
using OpRewritePattern::OpRewritePattern;
62
36
63
- CallOpRewriter (mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
37
+ CallOpRewriter (mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
64
38
: OpRewritePattern(ctx), di(_di) {}
65
39
66
40
mlir::LogicalResult
67
41
matchAndRewrite (fir::CallOp callOp,
68
42
mlir::PatternRewriter &rewriter) const override {
69
43
LLVM_DEBUG (llvm::dbgs () << " Processing call op: " << callOp << " \n " );
70
44
auto module = callOp->getParentOfType <mlir::ModuleOp>();
45
+ bool needUpdate = false ;
71
46
fir::FirOpBuilder builder (rewriter, module );
72
47
llvm::SmallVector<mlir::Value> newOperands;
73
48
llvm::SmallVector<mlir::Operation *> toErase;
74
- for (const auto &a : callOp.getArgs ()) {
75
- if (auto alloca =
76
- mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp ())) {
77
- if (needsExtrusion (&a)) {
78
-
79
- mlir::Type varTy = alloca.getInType ();
80
- assert (!fir::hasDynamicSize (varTy) &&
81
- " only expect statically sized scalars to be by value" );
82
-
83
- // find immediate store with const argument
84
- llvm::SmallVector<mlir::Operation *> stores;
85
- for (mlir::Operation *s : alloca.getOperation ()->getUsers ())
86
- if (mlir::isa<fir::StoreOp>(s) && di.dominates (s, callOp))
87
- stores.push_back (s);
88
- assert (stores.size () == 1 && " expected exactly one store" );
89
- LLVM_DEBUG (llvm::dbgs () << " found store " << *stores[0 ] << " \n " );
90
-
91
- auto constant_def = stores[0 ]->getOperand (0 ).getDefiningOp ();
92
- // Expect constant definition operation or force legalisation of the
93
- // callOp and continue with its next argument
94
- if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
95
- // unable to remove alloca arg
96
- newOperands.push_back (a);
97
- continue ;
98
- }
49
+ for (const mlir::Value &a : callOp.getArgs ()) {
50
+ auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp ());
51
+ // We can convert arguments that are alloca, and that has
52
+ // the value by reference attribute. All else is just added
53
+ // to the argument list.
54
+ if (!alloca || !alloca->hasAttr (fir::getAdaptToByRefAttrName ())) {
55
+ newOperands.push_back (a);
56
+ continue ;
57
+ }
99
58
100
- LLVM_DEBUG (llvm::dbgs () << " found define " << *constant_def << " \n " );
101
-
102
- auto loc = callOp.getLoc ();
103
- llvm::StringRef globalPrefix = " _extruded_" ;
104
-
105
- std::string globalName;
106
- while (!globalName.length () || builder.getNamedGlobal (globalName))
107
- globalName =
108
- globalPrefix.str () + " ." + std::to_string (uniqueLitId++);
109
-
110
- if (alloca->hasOneUse ()) {
111
- toErase.push_back (alloca);
112
- toErase.push_back (stores[0 ]);
113
- } else {
114
- int count = -2 ;
115
- for (mlir::Operation *s : alloca.getOperation ()->getUsers ())
116
- if (di.dominates (stores[0 ], s))
117
- ++count;
118
-
119
- // delete if dominates itself and one more operation (which should
120
- // be callOp)
121
- if (!count)
122
- toErase.push_back (stores[0 ]);
59
+ mlir::Type varTy = alloca.getInType ();
60
+ assert (!fir::hasDynamicSize (varTy) &&
61
+ " only expect statically sized scalars to be by value" );
62
+
63
+ // Find immediate store with const argument
64
+ mlir::Operation *store = nullptr ;
65
+ for (mlir::Operation *s : alloca->getUsers ()) {
66
+ if (mlir::isa<fir::StoreOp>(s) && di.dominates (s, callOp)) {
67
+ // We can only deal with ONE store - if already found one,
68
+ // set to nullptr and exit the loop.
69
+ if (store) {
70
+ store = nullptr ;
71
+ break ;
123
72
}
124
- auto global = builder.createGlobalConstant (
125
- loc, varTy, globalName,
126
- [&](fir::FirOpBuilder &builder) {
127
- mlir::Operation *cln = constant_def->clone ();
128
- builder.insert (cln);
129
- fir::ExtendedValue exv{cln->getResult (0 )};
130
- mlir::Value valBase = fir::getBase (exv);
131
- mlir::Value val = builder.createConvert (loc, varTy, valBase);
132
- builder.create <fir::HasValueOp>(loc, val);
133
- },
134
- builder.createInternalLinkage ());
135
- mlir::Value ope = {builder.create <fir::AddrOfOp>(
136
- loc, global.resultType (), global.getSymbol ())};
137
- newOperands.push_back (ope);
138
- } else {
139
- // alloca but without attr, add it
140
- newOperands.push_back (a);
73
+ store = s;
141
74
}
142
- } else {
143
- // non-alloca operand, add it
75
+ }
76
+
77
+ // If we didn't find one signle store, add argument as is, and move on.
78
+ if (!store) {
79
+ newOperands.push_back (a);
80
+ continue ;
81
+ }
82
+
83
+ LLVM_DEBUG (llvm::dbgs () << " found store " << *store << " \n " );
84
+
85
+ mlir::Operation *constant_def = store->getOperand (0 ).getDefiningOp ();
86
+ // Expect constant definition operation or force legalisation of the
87
+ // callOp and continue with its next argument
88
+ if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
89
+ // Unable to remove alloca arg
144
90
newOperands.push_back (a);
91
+ continue ;
145
92
}
93
+
94
+ LLVM_DEBUG (llvm::dbgs () << " found define " << *constant_def << " \n " );
95
+
96
+ std::string globalName = " _extruded_." + std::to_string (uniqueLitId++);
97
+ assert (!builder.getNamedGlobal (globalName) &&
98
+ " We should have a unique name here" );
99
+
100
+ unsigned count = 0 ;
101
+ for (mlir::Operation *s : alloca->getUsers ())
102
+ if (di.dominates (store, s))
103
+ ++count;
104
+
105
+ // Delete if dominates itself and one more operation (which should
106
+ // be callOp)
107
+ if (count == 2 )
108
+ toErase.push_back (store);
109
+
110
+ auto loc = callOp.getLoc ();
111
+ fir::GlobalOp global = builder.createGlobalConstant (
112
+ loc, varTy, globalName,
113
+ [&](fir::FirOpBuilder &builder) {
114
+ mlir::Operation *cln = constant_def->clone ();
115
+ builder.insert (cln);
116
+ mlir::Value val =
117
+ builder.createConvert (loc, varTy, cln->getResult (0 ));
118
+ builder.create <fir::HasValueOp>(loc, val);
119
+ },
120
+ builder.createInternalLinkage ());
121
+ mlir::Value addr = {builder.create <fir::AddrOfOp>(
122
+ loc, global.resultType (), global.getSymbol ())};
123
+ newOperands.push_back (addr);
124
+ needUpdate = true ;
146
125
}
147
126
148
- auto loc = callOp.getLoc ();
149
- llvm::SmallVector<mlir::Type> newResultTypes;
150
- newResultTypes.append (callOp.getResultTypes ().begin (),
151
- callOp.getResultTypes ().end ());
152
- fir::CallOp newOp = builder.create <fir::CallOp>(
153
- loc, newResultTypes,
154
- callOp.getCallee ().has_value () ? callOp.getCallee ().value ()
155
- : mlir::SymbolRefAttr{},
156
- newOperands, callOp.getFastmathAttr ());
157
- rewriter.replaceOp (callOp, newOp);
158
-
159
- for (auto e : toErase)
160
- rewriter.eraseOp (e);
161
-
162
- LLVM_DEBUG (llvm::dbgs () << " extruded constant for " << callOp << " as "
163
- << newOp << ' \n ' );
164
- return mlir::success ();
127
+ if (needUpdate) {
128
+ auto loc = callOp.getLoc ();
129
+ llvm::SmallVector<mlir::Type> newResultTypes;
130
+ newResultTypes.append (callOp.getResultTypes ().begin (),
131
+ callOp.getResultTypes ().end ());
132
+ fir::CallOp newOp = builder.create <fir::CallOp>(
133
+ loc, newResultTypes,
134
+ callOp.getCallee ().has_value () ? callOp.getCallee ().value ()
135
+ : mlir::SymbolRefAttr{},
136
+ newOperands, callOp.getFastmathAttr ());
137
+ rewriter.replaceOp (callOp, newOp);
138
+
139
+ for (auto e : toErase)
140
+ rewriter.eraseOp (e);
141
+ LLVM_DEBUG (llvm::dbgs () << " extruded constant for " << callOp << " as "
142
+ << newOp << ' \n ' );
143
+ return mlir::success ();
144
+ }
145
+
146
+ // Failure here just means "we couldn't do the conversion", which is
147
+ // perfectly acceptable to the upper layers of this function.
148
+ return mlir::failure ();
165
149
}
166
150
};
167
151
168
- // This pass attempts to convert immediate scalar literals in function calls
152
+ // this pass attempts to convert immediate scalar literals in function calls
169
153
// to global constants to allow transformations as Dead Argument Elimination
170
154
class ConstExtruderOpt
171
155
: public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
172
- protected:
173
- mlir::DominanceInfo *di;
174
- mlir::GreedyRewriteConfig config;
175
-
176
156
public:
177
- ConstExtruderOpt () {
178
- config.enableRegionSimplification = false ;
179
- config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
180
- }
157
+ ConstExtruderOpt () = default ;
181
158
182
159
void runOnOperation () override {
183
160
mlir::ModuleOp mod = getOperation ();
184
- di = &getAnalysis<mlir::DominanceInfo>();
185
- mod.walk ([this ](mlir::func::FuncOp func) { runOnFunc (func); });
161
+ mlir::DominanceInfo * di = &getAnalysis<mlir::DominanceInfo>();
162
+ mod.walk ([di, this ](mlir::func::FuncOp func) { runOnFunc (func, di ); });
186
163
}
187
164
188
- void runOnFunc (mlir::func::FuncOp &func) {
189
- auto *context = &getContext ();
190
- mlir::RewritePatternSet patterns (context);
191
-
165
+ void runOnFunc (mlir::func::FuncOp &func, const mlir::DominanceInfo *di) {
192
166
// If func is a declaration, skip it.
193
167
if (func.empty ())
194
168
return ;
195
169
170
+ auto *context = &getContext ();
171
+ mlir::RewritePatternSet patterns (context);
172
+ mlir::GreedyRewriteConfig config;
173
+ config.enableRegionSimplification = false ;
174
+ config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
175
+
196
176
patterns.insert <CallOpRewriter>(context, *di);
197
177
if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
198
178
func, std::move (patterns), config))) {
0 commit comments