Skip to content

[flang] lower select rank #93967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions flang/include/flang/Lower/ConvertExprToHLFIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@ convertExprToHLFIR(mlir::Location loc, Fortran::lower::AbstractConverter &,
const Fortran::lower::SomeExpr &, Fortran::lower::SymMap &,
Fortran::lower::StatementContext &);

inline fir::ExtendedValue
translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity,
Fortran::lower::StatementContext &context) {
inline fir::ExtendedValue translateToExtendedValue(
mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity entity,
Fortran::lower::StatementContext &context, bool contiguityHint = false) {
auto [exv, exvCleanup] =
hlfir::translateToExtendedValue(loc, builder, entity);
hlfir::translateToExtendedValue(loc, builder, entity, contiguityHint);
if (exvCleanup)
context.attachCleanup(*exvCleanup);
return exv;
Expand Down
10 changes: 8 additions & 2 deletions flang/lib/Evaluate/check-expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -801,8 +801,14 @@ class IsContiguousHelper
// simple contiguity to allow their use in contexts like
// data targets in pointer assignments with remapping.
return true;
} else if (ultimate.has<semantics::AssocEntityDetails>()) {
return Base::operator()(ultimate); // use expr
} else if (const auto *details{
ultimate.detailsIf<semantics::AssocEntityDetails>()}) {
// RANK(*) associating entity is contiguous.
if (details->IsAssumedSize()) {
return true;
} else {
return Base::operator()(ultimate); // use expr
}
} else if (semantics::IsPointer(ultimate) ||
semantics::IsAssumedShape(ultimate) || IsAssumedRank(ultimate)) {
return std::nullopt;
Expand Down
204 changes: 198 additions & 6 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ struct ConstructContext {

Fortran::lower::pft::Evaluation &eval; // construct eval
Fortran::lower::StatementContext &stmtCtx; // construct exit code
std::optional<hlfir::Entity> selector; // construct selector, if any.
bool pushedScope = false; // was a scoped pushed for this construct?
};

/// Helper class to generate the runtime type info global data and the
Expand Down Expand Up @@ -1468,6 +1470,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
void popActiveConstruct() {
assert(!activeConstructStack.empty() && "invalid active construct stack");
activeConstructStack.back().eval.activeConstruct = false;
if (activeConstructStack.back().pushedScope)
localSymbols.popScope();
activeConstructStack.pop_back();
}

Expand Down Expand Up @@ -2181,7 +2185,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}

void genFIR(const Fortran::parser::CaseConstruct &) {
void genCaseOrRankConstruct() {
Fortran::lower::pft::Evaluation &eval = getEval();
Fortran::lower::StatementContext stmtCtx;
pushActiveConstruct(eval, stmtCtx);
Expand All @@ -2203,6 +2207,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
popActiveConstruct();
}
void genFIR(const Fortran::parser::CaseConstruct &) {
genCaseOrRankConstruct();
}

template <typename A>
void genNestedStatement(const Fortran::parser::Statement<A> &stmt) {
Expand Down Expand Up @@ -3032,13 +3039,198 @@ class FirConverter : public Fortran::lower::AbstractConverter {

void genFIR(const Fortran::parser::SelectRankConstruct &selectRankConstruct) {
setCurrentPositionAt(selectRankConstruct);
TODO(toLocation(), "coarray: SelectRankConstruct");
genCaseOrRankConstruct();
}
void genFIR(const Fortran::parser::SelectRankStmt &) {
TODO(toLocation(), "coarray: SelectRankStmt");

void genFIR(const Fortran::parser::SelectRankStmt &selectRankStmt) {
// Generate a fir.select_case with the selector rank. The RANK(*) case,
// if any, is handles with a conditional branch before the fir.select_case.
mlir::Type rankType = builder->getIntegerType(8);
mlir::MLIRContext *context = builder->getContext();
mlir::Location loc = toLocation();
// Build block list for fir.select_case, and identify RANK(*) block, if any.
// Default block must be placed last in the fir.select_case block list.
mlir::Block *rankStarBlock = nullptr;
Fortran::lower::pft::Evaluation &eval = getEval();
mlir::Block *defaultBlock = eval.parentConstruct->constructExit->block;
llvm::SmallVector<mlir::Attribute> attrList;
llvm::SmallVector<mlir::Value> valueList;
llvm::SmallVector<mlir::Block *> blockList;
for (Fortran::lower::pft::Evaluation *e = eval.controlSuccessor; e;
e = e->controlSuccessor) {
if (const auto *rankCaseStmt =
e->getIf<Fortran::parser::SelectRankCaseStmt>()) {
const auto &rank = std::get<Fortran::parser::SelectRankCaseStmt::Rank>(
rankCaseStmt->t);
assert(e->block && "missing SelectRankCaseStmt block");
std::visit(
Fortran::common::visitors{
[&](const Fortran::parser::ScalarIntConstantExpr &rankExpr) {
blockList.emplace_back(e->block);
attrList.emplace_back(fir::PointIntervalAttr::get(context));
std::optional<std::int64_t> rankCst =
Fortran::evaluate::ToInt64(
Fortran::semantics::GetExpr(rankExpr));
assert(rankCst.has_value() &&
"rank expr must be constant integer");
valueList.emplace_back(
builder->createIntegerConstant(loc, rankType, *rankCst));
},
[&](const Fortran::parser::Star &) {
rankStarBlock = e->block;
},
[&](const Fortran::parser::Default &) {
defaultBlock = e->block;
}},
rank.u);
}
}
attrList.push_back(mlir::UnitAttr::get(context));
blockList.push_back(defaultBlock);

// Lower selector.
assert(!activeConstructStack.empty() && "must be inside construct");
assert(!activeConstructStack.back().selector &&
"selector should not yet be set");
Fortran::lower::StatementContext &stmtCtx =
activeConstructStack.back().stmtCtx;
const Fortran::lower::SomeExpr *selectorExpr =
std::visit([](const auto &x) { return Fortran::semantics::GetExpr(x); },
std::get<Fortran::parser::Selector>(selectRankStmt.t).u);
assert(selectorExpr && "failed to retrieve selector expr");
hlfir::Entity selector = Fortran::lower::convertExprToHLFIR(
loc, *this, *selectorExpr, localSymbols, stmtCtx);
activeConstructStack.back().selector = selector;

// Deal with assumed-size first. They must fall into RANK(*) if present, or
// the default case (F'2023 11.1.10.2.). The selector cannot be an
// assumed-size if it is allocatable or pointer, so the check is skipped.
if (!Fortran::evaluate::IsAllocatableOrPointerObject(*selectorExpr)) {
mlir::Value isAssumedSize = builder->create<fir::IsAssumedSizeOp>(
loc, builder->getI1Type(), selector);
// Create new block to hold the fir.select_case for the non assumed-size
// cases.
mlir::Block *selectCaseBlock = insertBlock(blockList[0]);
mlir::Block *assumedSizeBlock =
rankStarBlock ? rankStarBlock : defaultBlock;
builder->create<mlir::cf::CondBranchOp>(loc, isAssumedSize,
assumedSizeBlock, std::nullopt,
selectCaseBlock, std::nullopt);
startBlock(selectCaseBlock);
}
// Create fir.select_case for the other rank cases.
mlir::Value rank = builder->create<fir::BoxRankOp>(loc, rankType, selector);
stmtCtx.finalizeAndReset();
builder->create<fir::SelectCaseOp>(loc, rank, attrList, valueList,
blockList);
}

// Get associating entity symbol inside case statement scope.
static const Fortran::semantics::Symbol &
getAssociatingEntitySymbol(const Fortran::semantics::Scope &scope) {
const Fortran::semantics::Symbol *assocSym = nullptr;
for (const auto &sym : scope.GetSymbols()) {
if (sym->has<Fortran::semantics::AssocEntityDetails>()) {
assert(!assocSym &&
"expect only one associating entity symbol in this scope");
assocSym = &*sym;
}
}
assert(assocSym && "should contain associating entity symbol");
return *assocSym;
}
void genFIR(const Fortran::parser::SelectRankCaseStmt &) {
TODO(toLocation(), "coarray: SelectRankCaseStmt");

void genFIR(const Fortran::parser::SelectRankCaseStmt &stmt) {
assert(!activeConstructStack.empty() &&
"must be inside select rank construct");
// Pop previous associating entity mapping, if any, and push scope for new
// mapping.
if (activeConstructStack.back().pushedScope)
localSymbols.popScope();
localSymbols.pushScope();
activeConstructStack.back().pushedScope = true;
const Fortran::semantics::Symbol &assocEntitySymbol =
getAssociatingEntitySymbol(
bridge.getSemanticsContext().FindScope(getEval().position));
const auto &details =
assocEntitySymbol.get<Fortran::semantics::AssocEntityDetails>();
assert(!activeConstructStack.empty() &&
activeConstructStack.back().selector.has_value() &&
"selector must have been created");
// Get lowered value for the selector.
hlfir::Entity selector = *activeConstructStack.back().selector;
assert(selector.isVariable() && "assumed-rank selector are variables");
// Cook selector mlir::Value according to rank case and map it to
// associating entity symbol.
Fortran::lower::StatementContext stmtCtx;
mlir::Location loc = toLocation();
if (details.IsAssumedRank()) {
fir::ExtendedValue selectorExv = Fortran::lower::translateToExtendedValue(
loc, *builder, selector, stmtCtx);
addSymbol(assocEntitySymbol, selectorExv);
} else if (details.IsAssumedSize()) {
// Create rank-1 assumed-size from descriptor. Assumed-size are contiguous
// so a new entity can be built from scratch using the base address, type
// parameters and dynamic type. The selector cannot be a
// POINTER/ALLOCATBLE as per F'2023 C1160.
fir::ExtendedValue newExv;
llvm::SmallVector assumeSizeExtents{
builder->createMinusOneInteger(loc, builder->getIndexType())};
mlir::Value baseAddr =
hlfir::genVariableRawAddress(loc, *builder, selector);
mlir::Type eleType =
fir::unwrapSequenceType(fir::unwrapRefType(baseAddr.getType()));
mlir::Type rank1Type =
fir::ReferenceType::get(builder->getVarLenSeqTy(eleType, 1));
baseAddr = builder->createConvert(loc, rank1Type, baseAddr);
if (selector.isCharacter()) {
mlir::Value len = hlfir::genCharLength(loc, *builder, selector);
newExv = fir::CharArrayBoxValue{baseAddr, len, assumeSizeExtents};
} else if (selector.isDerivedWithLengthParameters()) {
TODO(loc, "RANK(*) with parameterized derived type selector");
} else if (selector.isPolymorphic()) {
TODO(loc, "RANK(*) with polymorphic selector");
} else {
// Simple intrinsic or derived type.
newExv = fir::ArrayBoxValue{baseAddr, assumeSizeExtents};
}
addSymbol(assocEntitySymbol, newExv);
} else {
int rank = details.rank().value();
auto boxTy =
mlir::cast<fir::BaseBoxType>(fir::unwrapRefType(selector.getType()));
mlir::Type newBoxType = boxTy.getBoxTypeWithNewShape(rank);
if (fir::isa_ref_type(selector.getType()))
newBoxType = fir::ReferenceType::get(newBoxType);
// Give rank info to value via cast, and get rid of the box if not needed
// (simple scalars, contiguous arrays... This is done by
// translateVariableToExtendedValue).
hlfir::Entity rankedBox{
builder->createConvert(loc, newBoxType, selector)};
bool isSimplyContiguous = Fortran::evaluate::IsSimplyContiguous(
assocEntitySymbol, getFoldingContext());
fir::ExtendedValue newExv = Fortran::lower::translateToExtendedValue(
loc, *builder, rankedBox, stmtCtx, isSimplyContiguous);

// Non deferred length parameters of character allocatable/pointer
// MutableBoxValue should be properly set before binding it to a symbol in
// order to get correct assignment semantics.
if (const fir::MutableBoxValue *mutableBox =
newExv.getBoxOf<fir::MutableBoxValue>()) {
if (selector.isCharacter()) {
auto dynamicType =
Fortran::evaluate::DynamicType::From(assocEntitySymbol);
if (!dynamicType.value().HasDeferredTypeParameter()) {
llvm::SmallVector<mlir::Value> lengthParams;
hlfir::genLengthParameters(loc, *builder, selector, lengthParams);
newExv = fir::MutableBoxValue{rankedBox, lengthParams,
mutableBox->getMutableProperties()};
}
}
}
addSymbol(assocEntitySymbol, newExv);
}
// Statements inside rank case are lowered by SelectRankConstruct visit.
}

void genFIR(const Fortran::parser::SelectTypeConstruct &selectTypeConstruct) {
Expand Down
13 changes: 9 additions & 4 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ genLboundsAndExtentsFromBox(mlir::Location loc, fir::FirOpBuilder &builder,
static llvm::SmallVector<mlir::Value>
getNonDefaultLowerBounds(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity) {
assert(!entity.isAssumedRank() &&
"cannot compute assumed rank bounds statically");
if (!entity.mayHaveNonDefaultLowerBounds())
return {};
if (auto varIface = entity.getIfVariableInterface()) {
Expand Down Expand Up @@ -889,11 +891,14 @@ static fir::ExtendedValue translateVariableToExtendedValue(
fir::MutableProperties{});

if (mlir::isa<fir::BaseBoxType>(base.getType())) {
bool contiguous = variable.isSimplyContiguous() || contiguousHint;
const bool contiguous = variable.isSimplyContiguous() || contiguousHint;
const bool isAssumedRank = variable.isAssumedRank();
if (!contiguous || variable.isPolymorphic() ||
variable.isDerivedWithLengthParameters() || variable.isOptional()) {
llvm::SmallVector<mlir::Value> nonDefaultLbounds =
getNonDefaultLowerBounds(loc, builder, variable);
variable.isDerivedWithLengthParameters() || variable.isOptional() ||
isAssumedRank) {
llvm::SmallVector<mlir::Value> nonDefaultLbounds;
if (!isAssumedRank)
nonDefaultLbounds = getNonDefaultLowerBounds(loc, builder, variable);
return fir::BoxValue(base, nonDefaultLbounds,
getExplicitTypeParams(variable));
}
Expand Down
Loading
Loading