Skip to content

[flang] add extra component information in fir.type_info #96746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
fir::StringLitOp createStringLitOp(mlir::Location loc,
llvm::StringRef string);

std::pair<fir::TypeInfoOp, mlir::OpBuilder::InsertPoint>
createTypeInfoOp(mlir::Location loc, fir::RecordType recordType,
fir::RecordType parentType);

//===--------------------------------------------------------------------===//
// Linkage helpers (inline). The default linkage is external.
//===--------------------------------------------------------------------===//
Expand Down
24 changes: 23 additions & 1 deletion flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2956,7 +2956,10 @@ def fir_TypeInfoOp : fir_Op<"type_info",

let hasVerifier = 1;

let regions = (region MaxSizedRegion<1>:$dispatch_table);
let regions = (region
MaxSizedRegion<1>:$dispatch_table,
MaxSizedRegion<1>:$component_info
);

let builders = [
OpBuilder<(ins "fir::RecordType":$type, "fir::RecordType":$parent_type,
Expand All @@ -2967,6 +2970,7 @@ def fir_TypeInfoOp : fir_Op<"type_info",
$sym_name (`noinit` $no_init^)? (`nodestroy` $no_destroy^)?
(`nofinal` $no_final^)? (`extends` $parent_type^)? attr-dict `:` $type
(`dispatch_table` $dispatch_table^)?
(`component_info` $component_info^)?
}];

let extraClassDeclaration = [{
Expand Down Expand Up @@ -3010,6 +3014,24 @@ def fir_DTEntryOp : fir_Op<"dt_entry", [HasParent<"TypeInfoOp">]> {
}];
}

def fir_DTComponentOp : fir_Op<"dt_component", [HasParent<"TypeInfoOp">]> {
let summary = "define extra information about a component inside fir.type_info";

let description = [{
```
fir.dt_component i lbs [-1,2] init @init_val
```
}];

let arguments = (ins
StrAttr:$name,
OptionalAttr<DenseI64ArrayAttr>:$lower_bounds,
OptionalAttr<FlatSymbolRefAttr>:$init_val
);

let assemblyFormat = "$name (`lbs` $lower_bounds^)? (`init` $init_val^)? attr-dict";
}

def fir_AbsentOp : fir_OneResultOp<"absent", [NoMemoryEffect]> {
let summary = "create value to be passed for absent optional function argument";
let description = [{
Expand Down
6 changes: 6 additions & 0 deletions flang/include/flang/Optimizer/Support/InternalNames.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <optional>

static constexpr llvm::StringRef typeDescriptorSeparator = ".dt.";
static constexpr llvm::StringRef componentInitSeparator = ".di.";
static constexpr llvm::StringRef bindingTableSeparator = ".v.";
static constexpr llvm::StringRef boxprocSuffix = "UnboxProc";

Expand Down Expand Up @@ -156,6 +157,11 @@ struct NameUniquer {
static std::string
getTypeDescriptorBindingTableName(llvm::StringRef mangledTypeName);

/// Given a mangled derived type name and a component name, get the name of
/// the global object containing the component default initialization.
static std::string getComponentInitName(llvm::StringRef mangledTypeName,
llvm::StringRef componentName);

/// Remove markers that have been added when doing partial type
/// conversions. mlir::Type cannot be mutated in a pass, so new
/// fir::RecordType must be created when lowering member types.
Expand Down
23 changes: 23 additions & 0 deletions flang/include/flang/Optimizer/Support/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,29 @@ inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
" in " + intrinsicName);
}

/// Find the fir.type_info that was created for this \p recordType in \p module,
/// if any. \p symbolTable can be provided to speed-up the lookup. This tool
/// will match record type even if they have been "altered" in type conversion
/// passes.
fir::TypeInfoOp
lookupTypeInfoOp(fir::RecordType recordType, mlir::ModuleOp module,
const mlir::SymbolTable *symbolTable = nullptr);

/// Find the fir.type_info named \p name in \p module, if any. \p symbolTable
/// can be provided to speed-up the lookup. Prefer using the equivalent with a
/// RecordType argument unless it is certain \p name has not been altered by a
/// pass rewriting fir.type (see NameUniquer::dropTypeConversionMarkers).
fir::TypeInfoOp
lookupTypeInfoOp(llvm::StringRef name, mlir::ModuleOp module,
const mlir::SymbolTable *symbolTable = nullptr);

/// Returns all lower bounds of \p component if it is an array component of \p
/// recordType with non default lower bounds. Returns nullopt if this is not an
/// array componnet of \p recordType or if its lower bounds are all ones.
std::optional<llvm::ArrayRef<int64_t>> getComponentLowerBoundsIfNonDefault(
fir::RecordType recordType, llvm::StringRef component,
mlir::ModuleOp module, const mlir::SymbolTable *symbolTable = nullptr);

} // namespace fir

#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
119 changes: 104 additions & 15 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,71 @@ struct ConstructContext {
bool pushedScope = false; // was a scoped pushed for this construct?
};

/// Helper to gather the lower bounds of array components with non deferred
/// shape when they are not all ones. Return an empty array attribute otherwise.
static mlir::DenseI64ArrayAttr
gatherComponentNonDefaultLowerBounds(mlir::Location loc,
mlir::MLIRContext *mlirContext,
const Fortran::semantics::Symbol &sym) {
if (Fortran::semantics::IsAllocatableOrObjectPointer(&sym))
return {};
mlir::DenseI64ArrayAttr lbs_attr;
if (const auto *objDetails =
sym.detailsIf<Fortran::semantics::ObjectEntityDetails>()) {
llvm::SmallVector<std::int64_t> lbs;
bool hasNonDefaultLbs = false;
for (const Fortran::semantics::ShapeSpec &bounds : objDetails->shape())
if (auto lb = bounds.lbound().GetExplicit()) {
if (auto constant = Fortran::evaluate::ToInt64(*lb)) {
hasNonDefaultLbs |= (*constant != 1);
lbs.push_back(*constant);
} else {
TODO(loc, "generate fir.dt_component for length parametrized derived "
"types");
}
}
if (hasNonDefaultLbs) {
assert(static_cast<int>(lbs.size()) == sym.Rank() &&
"expected component bounds to be constant or deferred");
lbs_attr = mlir::DenseI64ArrayAttr::get(mlirContext, lbs);
}
}
return lbs_attr;
}

// Helper class to generate name of fir.global containing component explicit
// default value for objects, and initial procedure target for procedure pointer
// components.
static mlir::FlatSymbolRefAttr gatherComponentInit(
mlir::Location loc, Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &sym, fir::RecordType derivedType) {
mlir::MLIRContext *mlirContext = &converter.getMLIRContext();
// Return procedure target mangled name for procedure pointer components.
if (const auto *procPtr =
sym.detailsIf<Fortran::semantics::ProcEntityDetails>()) {
if (std::optional<const Fortran::semantics::Symbol *> maybeInitSym =
procPtr->init()) {
// So far, do not make distinction between p => NULL() and p without init,
// f18 always initialize pointers to NULL anyway.
if (!*maybeInitSym)
return {};
return mlir::FlatSymbolRefAttr::get(mlirContext,
converter.mangleName(**maybeInitSym));
}
}

const auto *objDetails =
sym.detailsIf<Fortran::semantics::ObjectEntityDetails>();
if (!objDetails || !objDetails->init().has_value())
return {};
// Object component initial value. Semantic package component object default
// value into compiler generated symbols that are lowered as read-only
// fir.global. Get the name of this global.
std::string name = fir::NameUniquer::getComponentInitName(
derivedType.getName(), toStringRef(sym.name()));
return mlir::FlatSymbolRefAttr::get(mlirContext, name);
}

/// Helper class to generate the runtime type info global data and the
/// fir.type_info operations that contain the dipatch tables (if any).
/// The type info global data is required to describe the derived type to the
Expand Down Expand Up @@ -213,15 +278,14 @@ class TypeInfoConverter {
parentType = mlir::cast<fir::RecordType>(converter.genType(*parent));

fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::ModuleOp module = builder.getModule();
fir::TypeInfoOp dt =
module.lookupSymbol<fir::TypeInfoOp>(info.type.getName());
if (dt)
return; // Already created.
auto insertPt = builder.saveInsertionPoint();
builder.setInsertionPoint(module.getBody(), module.getBody()->end());
dt = builder.create<fir::TypeInfoOp>(info.loc, info.type, parentType);

fir::TypeInfoOp dt;
mlir::OpBuilder::InsertPoint insertPointIfCreated;
std::tie(dt, insertPointIfCreated) =
builder.createTypeInfoOp(info.loc, info.type, parentType);
if (!insertPointIfCreated.isSet())
return; // fir.type_info was already built in a previous call.

// Set init, destroy, and nofinal attributes.
if (!info.typeSpec.HasDefaultInitialization(/*ignoreAllocatable=*/false,
/*ignorePointer=*/false))
dt->setAttr(dt.getNoInitAttrName(), builder.getUnitAttr());
Expand All @@ -230,13 +294,12 @@ class TypeInfoConverter {
if (!Fortran::semantics::MayRequireFinalization(info.typeSpec))
dt->setAttr(dt.getNoFinalAttrName(), builder.getUnitAttr());

const Fortran::semantics::Scope *scope = info.typeSpec.scope();
if (!scope)
scope = info.typeSpec.typeSymbol().scope();
assert(scope && "failed to find type scope");
const Fortran::semantics::Scope &derivedScope =
DEREF(info.typeSpec.GetScope());

// Fill binding table region if the derived type has bindings.
Fortran::semantics::SymbolVector bindings =
Fortran::semantics::CollectBindings(*scope);
Fortran::semantics::CollectBindings(derivedScope);
if (!bindings.empty()) {
builder.createBlock(&dt.getDispatchTable());
for (const Fortran::semantics::SymbolRef &binding : bindings) {
Expand All @@ -252,7 +315,33 @@ class TypeInfoConverter {
}
builder.create<fir::FirEndOp>(info.loc);
}
builder.restoreInsertionPoint(insertPt);
// Gather info about components that is not reflected in fir.type and may be
// needed later: component initial values and array component non default
// lower bounds.
mlir::Block *componentInfo = nullptr;
for (const auto &componentName :
info.typeSpec.typeSymbol()
.get<Fortran::semantics::DerivedTypeDetails>()
.componentNames()) {
auto scopeIter = derivedScope.find(componentName);
assert(scopeIter != derivedScope.cend() &&
"failed to find derived type component symbol");
const Fortran::semantics::Symbol &component = scopeIter->second.get();
mlir::FlatSymbolRefAttr init_val =
gatherComponentInit(info.loc, converter, component, info.type);
mlir::DenseI64ArrayAttr lbs = gatherComponentNonDefaultLowerBounds(
info.loc, builder.getContext(), component);
if (init_val || lbs) {
if (!componentInfo)
componentInfo = builder.createBlock(&dt.getComponentInfo());
auto compName = mlir::StringAttr::get(builder.getContext(),
toStringRef(component.name()));
builder.create<fir::DTComponentOp>(info.loc, compName, lbs, init_val);
}
}
if (componentInfo)
builder.create<fir::FirEndOp>(info.loc);
builder.restoreInsertionPoint(insertPointIfCreated);
}

/// Store the front-end data that will be required to generate the type info
Expand Down
17 changes: 17 additions & 0 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
Expand Down Expand Up @@ -364,6 +365,22 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
return glob;
}

std::pair<fir::TypeInfoOp, mlir::OpBuilder::InsertPoint>
fir::FirOpBuilder::createTypeInfoOp(mlir::Location loc,
fir::RecordType recordType,
fir::RecordType parentType) {
mlir::ModuleOp module = getModule();
if (fir::TypeInfoOp typeInfo =
fir::lookupTypeInfoOp(recordType.getName(), module, symbolTable))
return {typeInfo, InsertPoint{}};
InsertPoint insertPoint = saveInsertionPoint();
setInsertionPoint(module.getBody(), module.getBody()->end());
auto typeInfo = create<fir::TypeInfoOp>(loc, recordType, parentType);
if (symbolTable)
symbolTable->insert(typeInfo);
return {typeInfo, insertPoint};
}

mlir::Value fir::FirOpBuilder::convertWithSemantics(
mlir::Location loc, mlir::Type toTy, mlir::Value val,
bool allowCharacterConversion, bool allowRebox) {
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,7 @@ void fir::TypeInfoOp::build(mlir::OpBuilder &builder,
fir::RecordType parentType,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
result.addRegion();
result.addRegion();
result.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(type.getName()));
result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type));
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Support/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_flang_library(FIRSupport
DataLayout.cpp
InitFIR.cpp
InternalNames.cpp
Utils.cpp

DEPENDS
FIROpsIncGen
Expand Down
9 changes: 9 additions & 0 deletions flang/lib/Optimizer/Support/InternalNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,15 @@ std::string fir::NameUniquer::getTypeDescriptorBindingTableName(
return getDerivedTypeObjectName(mangledTypeName, bindingTableSeparator);
}

std::string
fir::NameUniquer::getComponentInitName(llvm::StringRef mangledTypeName,
llvm::StringRef componentName) {

std::string prefix =
getDerivedTypeObjectName(mangledTypeName, componentInitSeparator);
return prefix + "." + componentName.str();
}

llvm::StringRef
fir::NameUniquer::dropTypeConversionMarkers(llvm::StringRef mangledTypeName) {
if (mangledTypeName.ends_with(boxprocSuffix))
Expand Down
52 changes: 52 additions & 0 deletions flang/lib/Optimizer/Support/Utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===-- Utils.cpp ---------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Support/Utils.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/InternalNames.h"

fir::TypeInfoOp fir::lookupTypeInfoOp(fir::RecordType recordType,
mlir::ModuleOp module,
const mlir::SymbolTable *symbolTable) {
// fir.type_info was created with the mangled name of the derived type.
// It is the same as the name in the related fir.type, except when a pass
// lowered the fir.type (e.g., when lowering fir.boxproc type if the type has
// pointer procedure components), in which case suffix may have been added to
// the fir.type name. Get rid of them when looking up for the fir.type_info.
llvm::StringRef originalMangledTypeName =
fir::NameUniquer::dropTypeConversionMarkers(recordType.getName());
return fir::lookupTypeInfoOp(originalMangledTypeName, module, symbolTable);
}

fir::TypeInfoOp fir::lookupTypeInfoOp(llvm::StringRef name,
mlir::ModuleOp module,
const mlir::SymbolTable *symbolTable) {
if (symbolTable)
if (auto typeInfo = symbolTable->lookup<fir::TypeInfoOp>(name))
return typeInfo;
return module.lookupSymbol<fir::TypeInfoOp>(name);
}

std::optional<llvm::ArrayRef<int64_t>> fir::getComponentLowerBoundsIfNonDefault(
fir::RecordType recordType, llvm::StringRef component,
mlir::ModuleOp module, const mlir::SymbolTable *symbolTable) {
fir::TypeInfoOp typeInfo =
fir::lookupTypeInfoOp(recordType, module, symbolTable);
if (!typeInfo || typeInfo.getComponentInfo().empty())
return std::nullopt;
for (auto componentInfo :
typeInfo.getComponentInfo().getOps<fir::DTComponentOp>())
if (componentInfo.getName() == component)
return componentInfo.getLowerBounds();
return std::nullopt;
}
7 changes: 7 additions & 0 deletions flang/test/Fir/fir-ops.fir
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,13 @@ fir.type_info @dispatch_tbl : !fir.type<dispatch_tbl{i:i32}> dispatch_table {
// CHECK-LABEL: fir.type_info @test_type_info noinit nodestroy nofinal extends !fir.type<parent{i:i32}> : !fir.type<test_type_info{i:i32,j:f32}>
fir.type_info @test_type_info noinit nodestroy nofinal extends !fir.type<parent{i:i32}> : !fir.type<test_type_info{i:i32,j:f32}>

// CHECK-LABEL: fir.type_info @cpinfo : !fir.type<cpinfo{comp_i:!fir.array<10x20xi32>}> component_info {
// CHECK: fir.dt_component "component_info" lbs [2, 3]
// CHECK: }
fir.type_info @cpinfo : !fir.type<cpinfo{comp_i:!fir.array<10x20xi32>}> component_info {
fir.dt_component "component_info" lbs [2, 3]
}

// CHECK-LABEL: func @compare_complex(
// CHECK-SAME: [[VAL_151:%.*]]: !fir.complex<16>, [[VAL_152:%.*]]: !fir.complex<16>) {
func.func @compare_complex(%a : !fir.complex<16>, %b : !fir.complex<16>) {
Expand Down
Loading
Loading