Skip to content

[flang][NFC] use mlir::SymbolTable in lowering #86673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions flang/include/flang/Lower/AbstractConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
#include "mlir/IR/Operation.h"
#include "llvm/ADT/ArrayRef.h"

namespace mlir {
class SymbolTable;
}

namespace fir {
class KindMapping;
class FirOpBuilder;
Expand Down Expand Up @@ -305,6 +309,15 @@ class AbstractConverter {
virtual Fortran::lower::SymbolBox
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;

/// Return the mlir::SymbolTable associated to the ModuleOp.
/// Look-ups are faster using it than using module.lookup<>,
/// but the module op should be queried in case of failure
/// because this symbol table is not guaranteed to contain
/// all the symbols from the ModuleOp (the symbol table should
/// always be provided to the builder helper creating globals and
/// functions in order to be in sync).
virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;

private:
/// Options controlling lowering behavior.
const Fortran::lower::LoweringOptions &loweringOptions;
Expand Down
66 changes: 33 additions & 33 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
#include <optional>
#include <utility>

namespace mlir {
class SymbolTable;
}

namespace fir {
class AbstractArrayBox;
class ExtendedValue;
Expand All @@ -42,8 +46,10 @@ class BoxValue;
/// patterns.
class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
public:
explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap)
: OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)} {}
explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap,
mlir::SymbolTable *symbolTable = nullptr)
: OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)},
symbolTable{symbolTable} {}
explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap)
: OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} {
setListener(this);
Expand All @@ -69,13 +75,14 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
// The listener self-reference has to be updated in case of copy-construction.
FirOpBuilder(const FirOpBuilder &other)
: OpBuilder(other), OpBuilder::Listener(), kindMap{other.kindMap},
fastMathFlags{other.fastMathFlags} {
fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
setListener(this);
}

FirOpBuilder(FirOpBuilder &&other)
: OpBuilder(other), OpBuilder::Listener(),
kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags} {
kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags},
symbolTable{other.symbolTable} {
setListener(this);
}

Expand All @@ -95,6 +102,9 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// Get a reference to the kind map.
const fir::KindMapping &getKindMap() { return kindMap; }

/// Get func.func/fir.global symbol table attached to this builder if any.
mlir::SymbolTable *getMLIRSymbolTable() { return symbolTable; }

/// Get the default integer type
[[maybe_unused]] mlir::IntegerType getDefaultIntegerType() {
return getIntegerType(
Expand Down Expand Up @@ -280,24 +290,27 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// Get a function by name. If the function exists in the current module, it
/// is returned. Otherwise, a null FuncOp is returned.
mlir::func::FuncOp getNamedFunction(llvm::StringRef name) {
return getNamedFunction(getModule(), name);
return getNamedFunction(getModule(), getMLIRSymbolTable(), name);
}
static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
llvm::StringRef name);
static mlir::func::FuncOp
getNamedFunction(mlir::ModuleOp module, const mlir::SymbolTable *symbolTable,
llvm::StringRef name);

/// Get a function by symbol name. The result will be null if there is no
/// function with the given symbol in the module.
mlir::func::FuncOp getNamedFunction(mlir::SymbolRefAttr symbol) {
return getNamedFunction(getModule(), symbol);
return getNamedFunction(getModule(), getMLIRSymbolTable(), symbol);
}
static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
mlir::SymbolRefAttr symbol);
static mlir::func::FuncOp
getNamedFunction(mlir::ModuleOp module, const mlir::SymbolTable *symbolTable,
mlir::SymbolRefAttr symbol);

fir::GlobalOp getNamedGlobal(llvm::StringRef name) {
return getNamedGlobal(getModule(), name);
return getNamedGlobal(getModule(), getMLIRSymbolTable(), name);
}

static fir::GlobalOp getNamedGlobal(mlir::ModuleOp module,
const mlir::SymbolTable *symbolTable,
llvm::StringRef name);

/// Lazy creation of fir.convert op.
Expand All @@ -313,35 +326,18 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// result of the load if it was created, otherwise return \p val
mlir::Value loadIfRef(mlir::Location loc, mlir::Value val);

/// Create a new FuncOp. If the function may have already been created, use
/// `addNamedFunction` instead.
/// Determine if the named function is already in the module. Return the
/// instance if found, otherwise add a new named function to the module.
mlir::func::FuncOp createFunction(mlir::Location loc, llvm::StringRef name,
mlir::FunctionType ty) {
return createFunction(loc, getModule(), name, ty);
return createFunction(loc, getModule(), name, ty, getMLIRSymbolTable());
}

static mlir::func::FuncOp createFunction(mlir::Location loc,
mlir::ModuleOp module,
llvm::StringRef name,
mlir::FunctionType ty);

/// Determine if the named function is already in the module. Return the
/// instance if found, otherwise add a new named function to the module.
mlir::func::FuncOp addNamedFunction(mlir::Location loc, llvm::StringRef name,
mlir::FunctionType ty) {
if (auto func = getNamedFunction(name))
return func;
return createFunction(loc, name, ty);
}

static mlir::func::FuncOp addNamedFunction(mlir::Location loc,
mlir::ModuleOp module,
llvm::StringRef name,
mlir::FunctionType ty) {
if (auto func = getNamedFunction(module, name))
return func;
return createFunction(loc, module, name, ty);
}
mlir::FunctionType ty,
mlir::SymbolTable *);

/// Cast the input value to IndexType.
mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) {
Expand Down Expand Up @@ -515,6 +511,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// FastMathFlags that need to be set for operations that support
/// mlir::arith::FastMathAttr.
mlir::arith::FastMathFlags fastMathFlags{};

/// fir::GlobalOp and func::FuncOp symbol table to speed-up
/// lookups.
mlir::SymbolTable *symbolTable = nullptr;
};

} // namespace fir
Expand Down
19 changes: 11 additions & 8 deletions flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,19 @@ inline bool pureCall(mlir::Operation *op) {
/// Get or create a FuncOp in a module.
///
/// If `module` already contains FuncOp `name`, it is returned. Otherwise, a new
/// FuncOp is created, and that new FuncOp is returned.
mlir::func::FuncOp
createFuncOp(mlir::Location loc, mlir::ModuleOp module, llvm::StringRef name,
mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs = {});

/// Get or create a GlobalOp in a module.
/// FuncOp is created, and that new FuncOp is returned. A symbol table can
/// be provided to speed-up the lookups.
mlir::func::FuncOp createFuncOp(mlir::Location loc, mlir::ModuleOp module,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
const mlir::SymbolTable *symbolTable = nullptr);

/// Get or create a GlobalOp in a module. A symbol table can be provided to
/// speed-up the lookups.
fir::GlobalOp createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
llvm::StringRef name, mlir::Type type,
llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
const mlir::SymbolTable *symbolTable = nullptr);

/// Attribute to mark Fortran entities with the CONTIGUOUS attribute.
constexpr llvm::StringRef getContiguousAttrName() { return "fir.contiguous"; }
Expand Down
23 changes: 17 additions & 6 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
public:
explicit FirConverter(Fortran::lower::LoweringBridge &bridge)
: Fortran::lower::AbstractConverter(bridge.getLoweringOptions()),
bridge{bridge}, foldingContext{bridge.createFoldingContext()} {}
bridge{bridge}, foldingContext{bridge.createFoldingContext()},
mlirSymbolTable{bridge.getModule()} {}
virtual ~FirConverter() = default;

/// Convert the PFT to FIR.
Expand Down Expand Up @@ -329,8 +330,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
builder = new fir::FirOpBuilder(bridge.getModule(),
bridge.getKindMap());
builder = new fir::FirOpBuilder(
bridge.getModule(), bridge.getKindMap(), &mlirSymbolTable);
Fortran::lower::genOpenACCRoutineConstruct(
*this, bridge.getSemanticsContext(), bridge.getModule(),
d.routine, accRoutineInfos);
Expand Down Expand Up @@ -1036,6 +1037,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return {};
}

mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }

/// Add the symbol to the local map and return `true`. If the symbol is
/// already in the map and \p forced is `false`, the map is not updated.
/// Instead the value `false` is returned.
Expand Down Expand Up @@ -4570,7 +4573,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
llvm::dbgs() << "\n");
Fortran::lower::CalleeInterface callee(funit, *this);
mlir::func::FuncOp func = callee.addEntryBlockAndMapArguments();
builder = new fir::FirOpBuilder(func, bridge.getKindMap());
builder =
new fir::FirOpBuilder(func, bridge.getKindMap(), &mlirSymbolTable);
assert(builder && "FirOpBuilder did not instantiate");
builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
builder->setInsertionPointToStart(&func.front());
Expand Down Expand Up @@ -4838,12 +4842,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// FIXME: get rid of the bogus function context and instantiate the
// globals directly into the module.
mlir::MLIRContext *context = &getMLIRContext();
mlir::SymbolTable *symbolTable = getMLIRSymbolTable();
mlir::func::FuncOp func = fir::FirOpBuilder::createFunction(
mlir::UnknownLoc::get(context), getModuleOp(),
fir::NameUniquer::doGenerated("Sham"),
mlir::FunctionType::get(context, std::nullopt, std::nullopt));
mlir::FunctionType::get(context, std::nullopt, std::nullopt),
symbolTable);
func.addEntryBlock();
builder = new fir::FirOpBuilder(func, bridge.getKindMap());
builder = new fir::FirOpBuilder(func, bridge.getKindMap(), symbolTable);
assert(builder && "FirOpBuilder did not instantiate");
builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
createGlobals();
Expand Down Expand Up @@ -5335,6 +5341,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// utilities to deal with procedure pointer components whose arguments have
/// the type of the containing derived type.
Fortran::lower::TypeConstructionStack typeConstructionStack;
/// MLIR symbol table of the fir.global/func.func operations. Note that it is
/// not guaranteed to contain all operations of the ModuleOp with Symbol
/// attribute since mlirSymbolTable must pro-actively be maintained when
/// new Symbol operations are created.
mlir::SymbolTable mlirSymbolTable;
};

} // namespace
Expand Down
9 changes: 6 additions & 3 deletions flang/lib/Lower/CallInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,11 +667,13 @@ void Fortran::lower::CallInterface<T>::declare() {
if (!side().isIndirectCall()) {
std::string name = side().getMangledName();
mlir::ModuleOp module = converter.getModuleOp();
func = fir::FirOpBuilder::getNamedFunction(module, name);
mlir::SymbolTable *symbolTable = converter.getMLIRSymbolTable();
func = fir::FirOpBuilder::getNamedFunction(module, symbolTable, name);
if (!func) {
mlir::Location loc = side().getCalleeLocation();
mlir::FunctionType ty = genFunctionType();
func = fir::FirOpBuilder::createFunction(loc, module, name, ty);
func =
fir::FirOpBuilder::createFunction(loc, module, name, ty, symbolTable);
if (const Fortran::semantics::Symbol *sym = side().getProcedureSymbol()) {
if (side().isMainProgram()) {
func->setAttr(fir::getSymbolAttrName(),
Expand Down Expand Up @@ -1644,7 +1646,8 @@ mlir::func::FuncOp Fortran::lower::getOrDeclareFunction(
Fortran::lower::AbstractConverter &converter) {
mlir::ModuleOp module = converter.getModuleOp();
std::string name = getProcMangledName(proc, converter);
mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(module, name);
mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(
module, converter.getMLIRSymbolTable(), name);
if (func)
return func;

Expand Down
6 changes: 4 additions & 2 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3809,7 +3809,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
std::string funcName;
if (name) {
funcName = converter.mangleName(*name->symbol);
funcOp = builder.getNamedFunction(mod, funcName);
funcOp =
builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
} else {
Fortran::semantics::Scope &scope =
semanticsContext.FindScope(routineConstruct.source);
Expand All @@ -3821,7 +3822,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
: nullptr};
if (subpDetails && subpDetails->isInterface()) {
funcName = converter.mangleName(*progUnit.symbol());
funcOp = builder.getNamedFunction(mod, funcName);
funcOp =
builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
} else {
funcOp = builder.getFunction();
funcName = funcOp.getName();
Expand Down
Loading