28
28
#include < optional>
29
29
#include < utility>
30
30
31
+ namespace mlir {
32
+ class SymbolTable ;
33
+ }
34
+
31
35
namespace fir {
32
36
class AbstractArrayBox ;
33
37
class ExtendedValue ;
@@ -42,8 +46,10 @@ class BoxValue;
42
46
// / patterns.
43
47
class FirOpBuilder : public mlir ::OpBuilder, public mlir::OpBuilder::Listener {
44
48
public:
45
- explicit FirOpBuilder (mlir::Operation *op, fir::KindMapping kindMap)
46
- : OpBuilder{op, /* listener=*/ this }, kindMap{std::move (kindMap)} {}
49
+ explicit FirOpBuilder (mlir::Operation *op, fir::KindMapping kindMap,
50
+ mlir::SymbolTable *symbolTable = nullptr )
51
+ : OpBuilder{op, /* listener=*/ this }, kindMap{std::move (kindMap)},
52
+ symbolTable{symbolTable} {}
47
53
explicit FirOpBuilder (mlir::OpBuilder &builder, fir::KindMapping kindMap)
48
54
: OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move (kindMap)} {
49
55
setListener (this );
@@ -69,13 +75,14 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
69
75
// The listener self-reference has to be updated in case of copy-construction.
70
76
FirOpBuilder (const FirOpBuilder &other)
71
77
: OpBuilder(other), OpBuilder::Listener(), kindMap{other.kindMap },
72
- fastMathFlags{other.fastMathFlags } {
78
+ fastMathFlags{other.fastMathFlags }, symbolTable{other. symbolTable } {
73
79
setListener (this );
74
80
}
75
81
76
82
FirOpBuilder (FirOpBuilder &&other)
77
83
: OpBuilder(other), OpBuilder::Listener(),
78
- kindMap{std::move (other.kindMap )}, fastMathFlags{other.fastMathFlags } {
84
+ kindMap{std::move (other.kindMap )}, fastMathFlags{other.fastMathFlags },
85
+ symbolTable{other.symbolTable } {
79
86
setListener (this );
80
87
}
81
88
@@ -95,6 +102,9 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
95
102
// / Get a reference to the kind map.
96
103
const fir::KindMapping &getKindMap () { return kindMap; }
97
104
105
+ // / Get func.func/fir.global symbol table attached to this builder if any.
106
+ mlir::SymbolTable *getMLIRSymbolTable () { return symbolTable; }
107
+
98
108
// / Get the default integer type
99
109
[[maybe_unused]] mlir::IntegerType getDefaultIntegerType () {
100
110
return getIntegerType (
@@ -280,24 +290,27 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
280
290
// / Get a function by name. If the function exists in the current module, it
281
291
// / is returned. Otherwise, a null FuncOp is returned.
282
292
mlir::func::FuncOp getNamedFunction (llvm::StringRef name) {
283
- return getNamedFunction (getModule (), name);
293
+ return getNamedFunction (getModule (), getMLIRSymbolTable (), name);
284
294
}
285
- static mlir::func::FuncOp getNamedFunction (mlir::ModuleOp module ,
286
- llvm::StringRef name);
295
+ static mlir::func::FuncOp
296
+ getNamedFunction (mlir::ModuleOp module , const mlir::SymbolTable *symbolTable,
297
+ llvm::StringRef name);
287
298
288
299
// / Get a function by symbol name. The result will be null if there is no
289
300
// / function with the given symbol in the module.
290
301
mlir::func::FuncOp getNamedFunction (mlir::SymbolRefAttr symbol) {
291
- return getNamedFunction (getModule (), symbol);
302
+ return getNamedFunction (getModule (), getMLIRSymbolTable (), symbol);
292
303
}
293
- static mlir::func::FuncOp getNamedFunction (mlir::ModuleOp module ,
294
- mlir::SymbolRefAttr symbol);
304
+ static mlir::func::FuncOp
305
+ getNamedFunction (mlir::ModuleOp module , const mlir::SymbolTable *symbolTable,
306
+ mlir::SymbolRefAttr symbol);
295
307
296
308
fir::GlobalOp getNamedGlobal (llvm::StringRef name) {
297
- return getNamedGlobal (getModule (), name);
309
+ return getNamedGlobal (getModule (), getMLIRSymbolTable (), name);
298
310
}
299
311
300
312
static fir::GlobalOp getNamedGlobal (mlir::ModuleOp module ,
313
+ const mlir::SymbolTable *symbolTable,
301
314
llvm::StringRef name);
302
315
303
316
// / Lazy creation of fir.convert op.
@@ -313,35 +326,18 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
313
326
// / result of the load if it was created, otherwise return \p val
314
327
mlir::Value loadIfRef (mlir::Location loc, mlir::Value val);
315
328
316
- // / Create a new FuncOp. If the function may have already been created, use
317
- // / `addNamedFunction` instead .
329
+ // / Determine if the named function is already in the module. Return the
330
+ // / instance if found, otherwise add a new named function to the module .
318
331
mlir::func::FuncOp createFunction (mlir::Location loc, llvm::StringRef name,
319
332
mlir::FunctionType ty) {
320
- return createFunction (loc, getModule (), name, ty);
333
+ return createFunction (loc, getModule (), name, ty, getMLIRSymbolTable () );
321
334
}
322
335
323
336
static mlir::func::FuncOp createFunction (mlir::Location loc,
324
337
mlir::ModuleOp module ,
325
338
llvm::StringRef name,
326
- mlir::FunctionType ty);
327
-
328
- // / Determine if the named function is already in the module. Return the
329
- // / instance if found, otherwise add a new named function to the module.
330
- mlir::func::FuncOp addNamedFunction (mlir::Location loc, llvm::StringRef name,
331
- mlir::FunctionType ty) {
332
- if (auto func = getNamedFunction (name))
333
- return func;
334
- return createFunction (loc, name, ty);
335
- }
336
-
337
- static mlir::func::FuncOp addNamedFunction (mlir::Location loc,
338
- mlir::ModuleOp module ,
339
- llvm::StringRef name,
340
- mlir::FunctionType ty) {
341
- if (auto func = getNamedFunction (module , name))
342
- return func;
343
- return createFunction (loc, module , name, ty);
344
- }
339
+ mlir::FunctionType ty,
340
+ mlir::SymbolTable *);
345
341
346
342
// / Cast the input value to IndexType.
347
343
mlir::Value convertToIndexType (mlir::Location loc, mlir::Value val) {
@@ -515,6 +511,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
515
511
// / FastMathFlags that need to be set for operations that support
516
512
// / mlir::arith::FastMathAttr.
517
513
mlir::arith::FastMathFlags fastMathFlags{};
514
+
515
+ // / fir::GlobalOp and func::FuncOp symbol table to speed-up
516
+ // / lookups.
517
+ mlir::SymbolTable *symbolTable = nullptr ;
518
518
};
519
519
520
520
} // namespace fir
0 commit comments