Skip to content

[flang] Fix lowering of host associated cray pointee symbols #86121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flang/include/flang/Lower/ConvertVariable.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
fir::FortranVariableFlagsEnum::None,
bool force = false);

/// For the given Cray pointee symbol return the corresponding
/// Cray pointer symbol. Assert if the pointer symbol cannot be found.
Fortran::semantics::SymbolRef getCrayPointer(Fortran::semantics::SymbolRef sym);
/// Given the Fortran type of a Cray pointee, return the fir.box type used to
/// track the cray pointee as Fortran pointer.
mlir::Type getCrayPointeeBoxType(mlir::Type);

} // namespace lower
} // namespace Fortran
Expand Down
3 changes: 3 additions & 0 deletions flang/include/flang/Semantics/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ const Symbol *FindExternallyVisibleObject(
// specific procedure of the same name, return it instead.
const Symbol &BypassGeneric(const Symbol &);

// Given a cray pointee symbol, returns the related cray pointer symbol.
const Symbol &GetCrayPointer(const Symbol &crayPointee);

using SomeExpr = evaluate::Expr<evaluate::SomeType>;

bool ExprHasTypeCategory(
Expand Down
5 changes: 3 additions & 2 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3995,11 +3995,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
sym->Rank() == 0) {
// get the corresponding Cray pointer

auto ptrSym = Fortran::lower::getCrayPointer(*sym);
const Fortran::semantics::Symbol &ptrSym =
Fortran::semantics::GetCrayPointer(*sym);
fir::ExtendedValue ptr =
getSymbolExtendedValue(ptrSym, nullptr);
mlir::Value ptrVal = fir::getBase(ptr);
mlir::Type ptrTy = genType(*ptrSym);
mlir::Type ptrTy = genType(ptrSym);

fir::ExtendedValue pte =
getSymbolExtendedValue(*sym, nullptr);
Expand Down
10 changes: 6 additions & 4 deletions flang/lib/Lower/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,8 @@ class ScalarExprLowering {
addr);
} else if (sym->test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
// get the corresponding Cray pointer
auto ptrSym = Fortran::lower::getCrayPointer(sym);
Fortran::semantics::SymbolRef ptrSym{
Fortran::semantics::GetCrayPointer(sym)};
ExtValue ptr = gen(ptrSym);
mlir::Value ptrVal = fir::getBase(ptr);
mlir::Type ptrTy = converter.genType(*ptrSym);
Expand Down Expand Up @@ -1537,8 +1538,8 @@ class ScalarExprLowering {
auto baseSym = getFirstSym(aref);
if (baseSym.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
// get the corresponding Cray pointer
auto ptrSym = Fortran::lower::getCrayPointer(baseSym);

Fortran::semantics::SymbolRef ptrSym{
Fortran::semantics::GetCrayPointer(baseSym)};
fir::ExtendedValue ptr = gen(ptrSym);
mlir::Value ptrVal = fir::getBase(ptr);
mlir::Type ptrTy = ptrVal.getType();
Expand Down Expand Up @@ -6946,7 +6947,8 @@ class ArrayExprLowering {
ComponentPath &components) {
mlir::Value ptrVal = nullptr;
if (x.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
auto ptrSym = Fortran::lower::getCrayPointer(x);
Fortran::semantics::SymbolRef ptrSym{
Fortran::semantics::GetCrayPointer(x)};
ExtValue ptr = converter.getSymbolExtendedValue(ptrSym);
ptrVal = fir::getBase(ptr);
}
Expand Down
9 changes: 8 additions & 1 deletion flang/lib/Lower/ConvertExprToHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ class HlfirDesignatorBuilder {
// value of the Cray pointer variable.
fir::FirOpBuilder &builder = getBuilder();
fir::FortranVariableOpInterface ptrVar =
gen(Fortran::lower::getCrayPointer(symbolRef));
gen(Fortran::semantics::GetCrayPointer(symbolRef));
mlir::Value ptrAddr = ptrVar.getBase();

// Reinterpret the reference to a Cray pointer so that
Expand All @@ -306,9 +306,16 @@ class HlfirDesignatorBuilder {
}
return *varDef;
}
llvm::errs() << *symbolRef << "\n";
TODO(getLoc(), "lowering symbol to HLFIR");
}

fir::FortranVariableOpInterface
gen(const Fortran::semantics::Symbol &symbol) {
Fortran::evaluate::SymbolRef symref{symbol};
return gen(symref);
}

fir::FortranVariableOpInterface
gen(const Fortran::evaluate::Component &component) {
if (Fortran::semantics::IsAllocatableOrPointer(component.GetLastSymbol()))
Expand Down
48 changes: 19 additions & 29 deletions flang/lib/Lower/ConvertVariable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,11 @@ fir::FortranVariableFlagsAttr Fortran::lower::translateSymbolAttributes(
mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym,
fir::FortranVariableFlagsEnum extraFlags) {
fir::FortranVariableFlagsEnum flags = extraFlags;
if (sym.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
// CrayPointee are represented as pointers.
flags = flags | fir::FortranVariableFlagsEnum::pointer;
return fir::FortranVariableFlagsAttr::get(mlirContext, flags);
}
const auto &attrs = sym.attrs();
if (attrs.test(Fortran::semantics::Attr::ALLOCATABLE))
flags = flags | fir::FortranVariableFlagsEnum::allocatable;
Expand Down Expand Up @@ -1615,8 +1620,6 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
(!Fortran::semantics::IsProcedure(sym) ||
Fortran::semantics::IsPointer(sym)) &&
!sym.detailsIf<Fortran::semantics::CommonBlockDetails>()) {
bool isCrayPointee =
sym.test(Fortran::semantics::Symbol::Flag::CrayPointee);
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
const mlir::Location loc = genLocation(converter, sym);
mlir::Value shapeOrShift;
Expand All @@ -1636,31 +1639,21 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
Fortran::lower::translateSymbolCUDADataAttribute(builder.getContext(),
sym);

if (isCrayPointee) {
mlir::Type baseType =
hlfir::getFortranElementOrSequenceType(base.getType());
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(baseType)) {
// The pointer box's sequence type must be with unknown shape.
llvm::SmallVector<int64_t> shape(seqType.getDimension(),
fir::SequenceType::getUnknownExtent());
baseType = fir::SequenceType::get(shape, seqType.getEleTy());
}
fir::BoxType ptrBoxType =
fir::BoxType::get(fir::PointerType::get(baseType));
if (sym.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
mlir::Type ptrBoxType =
Fortran::lower::getCrayPointeeBoxType(base.getType());
mlir::Value boxAlloc = builder.createTemporary(loc, ptrBoxType);

// Declare a local pointer variable.
attributes = fir::FortranVariableFlagsAttr::get(
builder.getContext(), fir::FortranVariableFlagsEnum::pointer);
auto newBase = builder.create<hlfir::DeclareOp>(
loc, boxAlloc, name, /*shape=*/nullptr, lenParams, attributes);
mlir::Value nullAddr =
builder.createNullConstant(loc, ptrBoxType.getEleTy());
mlir::Value nullAddr = builder.createNullConstant(
loc, llvm::cast<fir::BaseBoxType>(ptrBoxType).getEleTy());

// If the element type is known-length character, then
// EmboxOp does not need the length parameters.
if (auto charType = mlir::dyn_cast<fir::CharacterType>(
fir::unwrapSequenceType(baseType)))
hlfir::getFortranElementType(base.getType())))
if (!charType.hasDynamicLen())
lenParams.clear();

Expand Down Expand Up @@ -2346,16 +2339,13 @@ void Fortran::lower::createRuntimeTypeInfoGlobal(
defineGlobal(converter, var, globalName, linkage);
}

Fortran::semantics::SymbolRef
Fortran::lower::getCrayPointer(Fortran::semantics::SymbolRef sym) {
assert(!sym->GetUltimate().owner().crayPointers().empty() &&
"empty Cray pointer/pointee map");
for (const auto &[pointee, pointer] :
sym->GetUltimate().owner().crayPointers()) {
if (pointee == sym->name()) {
Fortran::semantics::SymbolRef v{pointer.get()};
return v;
}
mlir::Type Fortran::lower::getCrayPointeeBoxType(mlir::Type fortranType) {
mlir::Type baseType = hlfir::getFortranElementOrSequenceType(fortranType);
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(baseType)) {
// The pointer box's sequence type must be with unknown shape.
llvm::SmallVector<int64_t> shape(seqType.getDimension(),
fir::SequenceType::getUnknownExtent());
baseType = fir::SequenceType::get(shape, seqType.getEleTy());
}
llvm_unreachable("corresponding Cray pointer cannot be found");
return fir::BoxType::get(fir::PointerType::get(baseType));
}
9 changes: 7 additions & 2 deletions flang/lib/Lower/HostAssociations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,11 @@ class CapturedAllocatableAndPointer
public:
static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &sym) {
return fir::ReferenceType::get(converter.genType(sym));
mlir::Type baseType = converter.genType(sym);
if (sym.GetUltimate().test(Fortran::semantics::Symbol::Flag::CrayPointee))
return fir::ReferenceType::get(
Fortran::lower::getCrayPointeeBoxType(baseType));
return fir::ReferenceType::get(baseType);
}
static void instantiateHostTuple(const InstantiateHostTuple &args,
Fortran::lower::AbstractConverter &converter,
Expand Down Expand Up @@ -507,7 +511,8 @@ walkCaptureCategories(T visitor, Fortran::lower::AbstractConverter &converter,
if (Fortran::semantics::IsProcedure(sym))
return CapturedProcedure::visit(visitor, converter, sym, ba);
ba.analyze(sym);
if (Fortran::semantics::IsAllocatableOrPointer(sym))
if (Fortran::semantics::IsAllocatableOrPointer(sym) ||
sym.GetUltimate().test(Fortran::semantics::Symbol::Flag::CrayPointee))
return CapturedAllocatableAndPointer::visit(visitor, converter, sym, ba);
if (ba.isArray())
return CapturedArrays::visit(visitor, converter, sym, ba);
Expand Down
9 changes: 9 additions & 0 deletions flang/lib/Lower/PFTBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,11 @@ struct SymbolDependenceAnalysis {
if (!s->has<semantics::DerivedTypeDetails>())
depth = std::max(analyze(s) + 1, depth);
}

// Make sure cray pointer is instantiated even if it is not visible.
if (ultimate.test(Fortran::semantics::Symbol::Flag::CrayPointee))
depth = std::max(
analyze(Fortran::semantics::GetCrayPointer(ultimate)) + 1, depth);
adjustSize(depth + 1);
bool global = lower::symbolIsGlobal(sym);
layeredVarList[depth].emplace_back(sym, global, depth);
Expand Down Expand Up @@ -2002,6 +2007,10 @@ struct SymbolVisitor {
}
}
}
// - CrayPointer needs to be available whenever a CrayPointee is used.
if (symbol.GetUltimate().test(
Fortran::semantics::Symbol::Flag::CrayPointee))
visitSymbol(Fortran::semantics::GetCrayPointer(symbol));
}

template <typename A>
Expand Down
12 changes: 12 additions & 0 deletions flang/lib/Semantics/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,18 @@ const Symbol &BypassGeneric(const Symbol &symbol) {
return symbol;
}

const Symbol &GetCrayPointer(const Symbol &crayPointee) {
const Symbol *found{nullptr};
for (const auto &[pointee, pointer] :
crayPointee.GetUltimate().owner().crayPointers()) {
if (pointee == crayPointee.name()) {
found = &pointer.get();
break;
}
}
return DEREF(found);
}

bool ExprHasTypeCategory(
const SomeExpr &expr, const common::TypeCategory &type) {
auto dynamicType{expr.GetType()};
Expand Down
Loading