Skip to content

Commit a4798bb

Browse files
authored
[flang][NFC] use mlir::SymbolTable in lowering (#86673)
Whenever lowering is checking if a function or global already exists in the mlir::Module, it was doing module->lookup. On big programs (~5000 globals and functions), this causes important slowdowns because these lookups are linear. Use mlir::SymbolTable to speed-up these lookups. The SymbolTable has to be created from the ModuleOp and maintained in sync. It is therefore placed in the converter, and FirOPBuilders can take a pointer to it to speed-up the lookups. This patch does not bring mlir::SymbolTable to FIR/HLFIR passes, but some passes creating a lot of runtime calls could benefit from it too. More analysis will be needed. As an example of the speed-ups, this patch speeds-up compilation of Whizard compare_amplitude_UFO.F90 from 5 mins to 2 mins on my machine (there is still room for speed-ups).
1 parent 2950283 commit a4798bb

File tree

12 files changed

+206
-127
lines changed

12 files changed

+206
-127
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
#include "mlir/IR/Operation.h"
2424
#include "llvm/ADT/ArrayRef.h"
2525

26+
namespace mlir {
27+
class SymbolTable;
28+
}
29+
2630
namespace fir {
2731
class KindMapping;
2832
class FirOpBuilder;
@@ -305,6 +309,15 @@ class AbstractConverter {
305309
virtual Fortran::lower::SymbolBox
306310
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
307311

312+
/// Return the mlir::SymbolTable associated to the ModuleOp.
313+
/// Look-ups are faster using it than using module.lookup<>,
314+
/// but the module op should be queried in case of failure
315+
/// because this symbol table is not guaranteed to contain
316+
/// all the symbols from the ModuleOp (the symbol table should
317+
/// always be provided to the builder helper creating globals and
318+
/// functions in order to be in sync).
319+
virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
320+
308321
private:
309322
/// Options controlling lowering behavior.
310323
const Fortran::lower::LoweringOptions &loweringOptions;

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
#include <optional>
2929
#include <utility>
3030

31+
namespace mlir {
32+
class SymbolTable;
33+
}
34+
3135
namespace fir {
3236
class AbstractArrayBox;
3337
class ExtendedValue;
@@ -42,8 +46,10 @@ class BoxValue;
4246
/// patterns.
4347
class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
4448
public:
45-
explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap)
46-
: OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)} {}
49+
explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap,
50+
mlir::SymbolTable *symbolTable = nullptr)
51+
: OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)},
52+
symbolTable{symbolTable} {}
4753
explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap)
4854
: OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} {
4955
setListener(this);
@@ -69,13 +75,14 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
6975
// The listener self-reference has to be updated in case of copy-construction.
7076
FirOpBuilder(const FirOpBuilder &other)
7177
: OpBuilder(other), OpBuilder::Listener(), kindMap{other.kindMap},
72-
fastMathFlags{other.fastMathFlags} {
78+
fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
7379
setListener(this);
7480
}
7581

7682
FirOpBuilder(FirOpBuilder &&other)
7783
: OpBuilder(other), OpBuilder::Listener(),
78-
kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags} {
84+
kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags},
85+
symbolTable{other.symbolTable} {
7986
setListener(this);
8087
}
8188

@@ -95,6 +102,9 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
95102
/// Get a reference to the kind map.
96103
const fir::KindMapping &getKindMap() { return kindMap; }
97104

105+
/// Get func.func/fir.global symbol table attached to this builder if any.
106+
mlir::SymbolTable *getMLIRSymbolTable() { return symbolTable; }
107+
98108
/// Get the default integer type
99109
[[maybe_unused]] mlir::IntegerType getDefaultIntegerType() {
100110
return getIntegerType(
@@ -280,24 +290,27 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
280290
/// Get a function by name. If the function exists in the current module, it
281291
/// is returned. Otherwise, a null FuncOp is returned.
282292
mlir::func::FuncOp getNamedFunction(llvm::StringRef name) {
283-
return getNamedFunction(getModule(), name);
293+
return getNamedFunction(getModule(), getMLIRSymbolTable(), name);
284294
}
285-
static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
286-
llvm::StringRef name);
295+
static mlir::func::FuncOp
296+
getNamedFunction(mlir::ModuleOp module, const mlir::SymbolTable *symbolTable,
297+
llvm::StringRef name);
287298

288299
/// Get a function by symbol name. The result will be null if there is no
289300
/// function with the given symbol in the module.
290301
mlir::func::FuncOp getNamedFunction(mlir::SymbolRefAttr symbol) {
291-
return getNamedFunction(getModule(), symbol);
302+
return getNamedFunction(getModule(), getMLIRSymbolTable(), symbol);
292303
}
293-
static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
294-
mlir::SymbolRefAttr symbol);
304+
static mlir::func::FuncOp
305+
getNamedFunction(mlir::ModuleOp module, const mlir::SymbolTable *symbolTable,
306+
mlir::SymbolRefAttr symbol);
295307

296308
fir::GlobalOp getNamedGlobal(llvm::StringRef name) {
297-
return getNamedGlobal(getModule(), name);
309+
return getNamedGlobal(getModule(), getMLIRSymbolTable(), name);
298310
}
299311

300312
static fir::GlobalOp getNamedGlobal(mlir::ModuleOp module,
313+
const mlir::SymbolTable *symbolTable,
301314
llvm::StringRef name);
302315

303316
/// Lazy creation of fir.convert op.
@@ -313,35 +326,18 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
313326
/// result of the load if it was created, otherwise return \p val
314327
mlir::Value loadIfRef(mlir::Location loc, mlir::Value val);
315328

316-
/// Create a new FuncOp. If the function may have already been created, use
317-
/// `addNamedFunction` instead.
329+
/// Determine if the named function is already in the module. Return the
330+
/// instance if found, otherwise add a new named function to the module.
318331
mlir::func::FuncOp createFunction(mlir::Location loc, llvm::StringRef name,
319332
mlir::FunctionType ty) {
320-
return createFunction(loc, getModule(), name, ty);
333+
return createFunction(loc, getModule(), name, ty, getMLIRSymbolTable());
321334
}
322335

323336
static mlir::func::FuncOp createFunction(mlir::Location loc,
324337
mlir::ModuleOp module,
325338
llvm::StringRef name,
326-
mlir::FunctionType ty);
327-
328-
/// Determine if the named function is already in the module. Return the
329-
/// instance if found, otherwise add a new named function to the module.
330-
mlir::func::FuncOp addNamedFunction(mlir::Location loc, llvm::StringRef name,
331-
mlir::FunctionType ty) {
332-
if (auto func = getNamedFunction(name))
333-
return func;
334-
return createFunction(loc, name, ty);
335-
}
336-
337-
static mlir::func::FuncOp addNamedFunction(mlir::Location loc,
338-
mlir::ModuleOp module,
339-
llvm::StringRef name,
340-
mlir::FunctionType ty) {
341-
if (auto func = getNamedFunction(module, name))
342-
return func;
343-
return createFunction(loc, module, name, ty);
344-
}
339+
mlir::FunctionType ty,
340+
mlir::SymbolTable *);
345341

346342
/// Cast the input value to IndexType.
347343
mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) {
@@ -515,6 +511,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
515511
/// FastMathFlags that need to be set for operations that support
516512
/// mlir::arith::FastMathAttr.
517513
mlir::arith::FastMathFlags fastMathFlags{};
514+
515+
/// fir::GlobalOp and func::FuncOp symbol table to speed-up
516+
/// lookups.
517+
mlir::SymbolTable *symbolTable = nullptr;
518518
};
519519

520520
} // namespace fir

flang/include/flang/Optimizer/Dialect/FIROpsSupport.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,19 @@ inline bool pureCall(mlir::Operation *op) {
5252
/// Get or create a FuncOp in a module.
5353
///
5454
/// If `module` already contains FuncOp `name`, it is returned. Otherwise, a new
55-
/// FuncOp is created, and that new FuncOp is returned.
56-
mlir::func::FuncOp
57-
createFuncOp(mlir::Location loc, mlir::ModuleOp module, llvm::StringRef name,
58-
mlir::FunctionType type,
59-
llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
60-
61-
/// Get or create a GlobalOp in a module.
55+
/// FuncOp is created, and that new FuncOp is returned. A symbol table can
56+
/// be provided to speed-up the lookups.
57+
mlir::func::FuncOp createFuncOp(mlir::Location loc, mlir::ModuleOp module,
58+
llvm::StringRef name, mlir::FunctionType type,
59+
llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
60+
const mlir::SymbolTable *symbolTable = nullptr);
61+
62+
/// Get or create a GlobalOp in a module. A symbol table can be provided to
63+
/// speed-up the lookups.
6264
fir::GlobalOp createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
6365
llvm::StringRef name, mlir::Type type,
64-
llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
66+
llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
67+
const mlir::SymbolTable *symbolTable = nullptr);
6568

6669
/// Attribute to mark Fortran entities with the CONTIGUOUS attribute.
6770
constexpr llvm::StringRef getContiguousAttrName() { return "fir.contiguous"; }

flang/lib/Lower/Bridge.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
273273
public:
274274
explicit FirConverter(Fortran::lower::LoweringBridge &bridge)
275275
: Fortran::lower::AbstractConverter(bridge.getLoweringOptions()),
276-
bridge{bridge}, foldingContext{bridge.createFoldingContext()} {}
276+
bridge{bridge}, foldingContext{bridge.createFoldingContext()},
277+
mlirSymbolTable{bridge.getModule()} {}
277278
virtual ~FirConverter() = default;
278279

279280
/// Convert the PFT to FIR.
@@ -329,8 +330,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
329330
[&](Fortran::lower::pft::BlockDataUnit &b) {},
330331
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
331332
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
332-
builder = new fir::FirOpBuilder(bridge.getModule(),
333-
bridge.getKindMap());
333+
builder = new fir::FirOpBuilder(
334+
bridge.getModule(), bridge.getKindMap(), &mlirSymbolTable);
334335
Fortran::lower::genOpenACCRoutineConstruct(
335336
*this, bridge.getSemanticsContext(), bridge.getModule(),
336337
d.routine, accRoutineInfos);
@@ -1036,6 +1037,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
10361037
return {};
10371038
}
10381039

1040+
mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
1041+
10391042
/// Add the symbol to the local map and return `true`. If the symbol is
10401043
/// already in the map and \p forced is `false`, the map is not updated.
10411044
/// Instead the value `false` is returned.
@@ -4571,7 +4574,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
45714574
llvm::dbgs() << "\n");
45724575
Fortran::lower::CalleeInterface callee(funit, *this);
45734576
mlir::func::FuncOp func = callee.addEntryBlockAndMapArguments();
4574-
builder = new fir::FirOpBuilder(func, bridge.getKindMap());
4577+
builder =
4578+
new fir::FirOpBuilder(func, bridge.getKindMap(), &mlirSymbolTable);
45754579
assert(builder && "FirOpBuilder did not instantiate");
45764580
builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
45774581
builder->setInsertionPointToStart(&func.front());
@@ -4839,12 +4843,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
48394843
// FIXME: get rid of the bogus function context and instantiate the
48404844
// globals directly into the module.
48414845
mlir::MLIRContext *context = &getMLIRContext();
4846+
mlir::SymbolTable *symbolTable = getMLIRSymbolTable();
48424847
mlir::func::FuncOp func = fir::FirOpBuilder::createFunction(
48434848
mlir::UnknownLoc::get(context), getModuleOp(),
48444849
fir::NameUniquer::doGenerated("Sham"),
4845-
mlir::FunctionType::get(context, std::nullopt, std::nullopt));
4850+
mlir::FunctionType::get(context, std::nullopt, std::nullopt),
4851+
symbolTable);
48464852
func.addEntryBlock();
4847-
builder = new fir::FirOpBuilder(func, bridge.getKindMap());
4853+
builder = new fir::FirOpBuilder(func, bridge.getKindMap(), symbolTable);
48484854
assert(builder && "FirOpBuilder did not instantiate");
48494855
builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
48504856
createGlobals();
@@ -5336,6 +5342,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
53365342
/// utilities to deal with procedure pointer components whose arguments have
53375343
/// the type of the containing derived type.
53385344
Fortran::lower::TypeConstructionStack typeConstructionStack;
5345+
/// MLIR symbol table of the fir.global/func.func operations. Note that it is
5346+
/// not guaranteed to contain all operations of the ModuleOp with Symbol
5347+
/// attribute since mlirSymbolTable must pro-actively be maintained when
5348+
/// new Symbol operations are created.
5349+
mlir::SymbolTable mlirSymbolTable;
53395350
};
53405351

53415352
} // namespace

flang/lib/Lower/CallInterface.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -667,11 +667,13 @@ void Fortran::lower::CallInterface<T>::declare() {
667667
if (!side().isIndirectCall()) {
668668
std::string name = side().getMangledName();
669669
mlir::ModuleOp module = converter.getModuleOp();
670-
func = fir::FirOpBuilder::getNamedFunction(module, name);
670+
mlir::SymbolTable *symbolTable = converter.getMLIRSymbolTable();
671+
func = fir::FirOpBuilder::getNamedFunction(module, symbolTable, name);
671672
if (!func) {
672673
mlir::Location loc = side().getCalleeLocation();
673674
mlir::FunctionType ty = genFunctionType();
674-
func = fir::FirOpBuilder::createFunction(loc, module, name, ty);
675+
func =
676+
fir::FirOpBuilder::createFunction(loc, module, name, ty, symbolTable);
675677
if (const Fortran::semantics::Symbol *sym = side().getProcedureSymbol()) {
676678
if (side().isMainProgram()) {
677679
func->setAttr(fir::getSymbolAttrName(),
@@ -1644,7 +1646,8 @@ mlir::func::FuncOp Fortran::lower::getOrDeclareFunction(
16441646
Fortran::lower::AbstractConverter &converter) {
16451647
mlir::ModuleOp module = converter.getModuleOp();
16461648
std::string name = getProcMangledName(proc, converter);
1647-
mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(module, name);
1649+
mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(
1650+
module, converter.getMLIRSymbolTable(), name);
16481651
if (func)
16491652
return func;
16501653

flang/lib/Lower/OpenACC.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3821,7 +3821,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
38213821
std::string funcName;
38223822
if (name) {
38233823
funcName = converter.mangleName(*name->symbol);
3824-
funcOp = builder.getNamedFunction(mod, funcName);
3824+
funcOp =
3825+
builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
38253826
} else {
38263827
Fortran::semantics::Scope &scope =
38273828
semanticsContext.FindScope(routineConstruct.source);
@@ -3833,7 +3834,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
38333834
: nullptr};
38343835
if (subpDetails && subpDetails->isInterface()) {
38353836
funcName = converter.mangleName(*progUnit.symbol());
3836-
funcOp = builder.getNamedFunction(mod, funcName);
3837+
funcOp =
3838+
builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
38373839
} else {
38383840
funcOp = builder.getFunction();
38393841
funcName = funcOp.getName();

0 commit comments

Comments
 (0)