Skip to content

Commit de7a50f

Browse files
authored
[flang] Fix lowering of host associated cray pointee symbols (#86121)
Cray pointee symbols can be host associated from a module or host procedure while the related cray pointer is not explicitly associated. This caused the "not yet implemented: lowering symbol to HLFIR" to fire when lowering a reference to the cray pointee and fetching the cray pointer. This patch: - Ensures cray pointers are always instantiated when instantiating a cray pointee. - Fix internal procedure lowering to deal with cray pointee host association like it does for pointers (the lowering strategy for cray pointee is to create a pointer that is updated with the cray pointer value before being fetched). This should fix the bug reported in #85420.
1 parent 465ea0b commit de7a50f

File tree

11 files changed

+176
-53
lines changed

11 files changed

+176
-53
lines changed

flang/include/flang/Lower/ConvertVariable.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
161161
fir::FortranVariableFlagsEnum::None,
162162
bool force = false);
163163

164-
/// For the given Cray pointee symbol return the corresponding
165-
/// Cray pointer symbol. Assert if the pointer symbol cannot be found.
166-
Fortran::semantics::SymbolRef getCrayPointer(Fortran::semantics::SymbolRef sym);
164+
/// Given the Fortran type of a Cray pointee, return the fir.box type used to
165+
/// track the cray pointee as Fortran pointer.
166+
mlir::Type getCrayPointeeBoxType(mlir::Type);
167167

168168
} // namespace lower
169169
} // namespace Fortran

flang/include/flang/Semantics/tools.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,9 @@ const Symbol *FindExternallyVisibleObject(
282282
// specific procedure of the same name, return it instead.
283283
const Symbol &BypassGeneric(const Symbol &);
284284

285+
// Given a cray pointee symbol, returns the related cray pointer symbol.
286+
const Symbol &GetCrayPointer(const Symbol &crayPointee);
287+
285288
using SomeExpr = evaluate::Expr<evaluate::SomeType>;
286289

287290
bool ExprHasTypeCategory(

flang/lib/Lower/Bridge.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3995,11 +3995,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
39953995
sym->Rank() == 0) {
39963996
// get the corresponding Cray pointer
39973997

3998-
auto ptrSym = Fortran::lower::getCrayPointer(*sym);
3998+
const Fortran::semantics::Symbol &ptrSym =
3999+
Fortran::semantics::GetCrayPointer(*sym);
39994000
fir::ExtendedValue ptr =
40004001
getSymbolExtendedValue(ptrSym, nullptr);
40014002
mlir::Value ptrVal = fir::getBase(ptr);
4002-
mlir::Type ptrTy = genType(*ptrSym);
4003+
mlir::Type ptrTy = genType(ptrSym);
40034004

40044005
fir::ExtendedValue pte =
40054006
getSymbolExtendedValue(*sym, nullptr);

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,8 @@ class ScalarExprLowering {
862862
addr);
863863
} else if (sym->test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
864864
// get the corresponding Cray pointer
865-
auto ptrSym = Fortran::lower::getCrayPointer(sym);
865+
Fortran::semantics::SymbolRef ptrSym{
866+
Fortran::semantics::GetCrayPointer(sym)};
866867
ExtValue ptr = gen(ptrSym);
867868
mlir::Value ptrVal = fir::getBase(ptr);
868869
mlir::Type ptrTy = converter.genType(*ptrSym);
@@ -1537,8 +1538,8 @@ class ScalarExprLowering {
15371538
auto baseSym = getFirstSym(aref);
15381539
if (baseSym.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
15391540
// get the corresponding Cray pointer
1540-
auto ptrSym = Fortran::lower::getCrayPointer(baseSym);
1541-
1541+
Fortran::semantics::SymbolRef ptrSym{
1542+
Fortran::semantics::GetCrayPointer(baseSym)};
15421543
fir::ExtendedValue ptr = gen(ptrSym);
15431544
mlir::Value ptrVal = fir::getBase(ptr);
15441545
mlir::Type ptrTy = ptrVal.getType();
@@ -6946,7 +6947,8 @@ class ArrayExprLowering {
69466947
ComponentPath &components) {
69476948
mlir::Value ptrVal = nullptr;
69486949
if (x.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
6949-
auto ptrSym = Fortran::lower::getCrayPointer(x);
6950+
Fortran::semantics::SymbolRef ptrSym{
6951+
Fortran::semantics::GetCrayPointer(x)};
69506952
ExtValue ptr = converter.getSymbolExtendedValue(ptrSym);
69516953
ptrVal = fir::getBase(ptr);
69526954
}

flang/lib/Lower/ConvertExprToHLFIR.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ class HlfirDesignatorBuilder {
284284
// value of the Cray pointer variable.
285285
fir::FirOpBuilder &builder = getBuilder();
286286
fir::FortranVariableOpInterface ptrVar =
287-
gen(Fortran::lower::getCrayPointer(symbolRef));
287+
gen(Fortran::semantics::GetCrayPointer(symbolRef));
288288
mlir::Value ptrAddr = ptrVar.getBase();
289289

290290
// Reinterpret the reference to a Cray pointer so that
@@ -306,9 +306,16 @@ class HlfirDesignatorBuilder {
306306
}
307307
return *varDef;
308308
}
309+
llvm::errs() << *symbolRef << "\n";
309310
TODO(getLoc(), "lowering symbol to HLFIR");
310311
}
311312

313+
fir::FortranVariableOpInterface
314+
gen(const Fortran::semantics::Symbol &symbol) {
315+
Fortran::evaluate::SymbolRef symref{symbol};
316+
return gen(symref);
317+
}
318+
312319
fir::FortranVariableOpInterface
313320
gen(const Fortran::evaluate::Component &component) {
314321
if (Fortran::semantics::IsAllocatableOrPointer(component.GetLastSymbol()))

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,6 +1554,11 @@ fir::FortranVariableFlagsAttr Fortran::lower::translateSymbolAttributes(
15541554
mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym,
15551555
fir::FortranVariableFlagsEnum extraFlags) {
15561556
fir::FortranVariableFlagsEnum flags = extraFlags;
1557+
if (sym.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
1558+
// CrayPointee are represented as pointers.
1559+
flags = flags | fir::FortranVariableFlagsEnum::pointer;
1560+
return fir::FortranVariableFlagsAttr::get(mlirContext, flags);
1561+
}
15571562
const auto &attrs = sym.attrs();
15581563
if (attrs.test(Fortran::semantics::Attr::ALLOCATABLE))
15591564
flags = flags | fir::FortranVariableFlagsEnum::allocatable;
@@ -1615,8 +1620,6 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
16151620
(!Fortran::semantics::IsProcedure(sym) ||
16161621
Fortran::semantics::IsPointer(sym)) &&
16171622
!sym.detailsIf<Fortran::semantics::CommonBlockDetails>()) {
1618-
bool isCrayPointee =
1619-
sym.test(Fortran::semantics::Symbol::Flag::CrayPointee);
16201623
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
16211624
const mlir::Location loc = genLocation(converter, sym);
16221625
mlir::Value shapeOrShift;
@@ -1636,31 +1639,21 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
16361639
Fortran::lower::translateSymbolCUDADataAttribute(builder.getContext(),
16371640
sym);
16381641

1639-
if (isCrayPointee) {
1640-
mlir::Type baseType =
1641-
hlfir::getFortranElementOrSequenceType(base.getType());
1642-
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(baseType)) {
1643-
// The pointer box's sequence type must be with unknown shape.
1644-
llvm::SmallVector<int64_t> shape(seqType.getDimension(),
1645-
fir::SequenceType::getUnknownExtent());
1646-
baseType = fir::SequenceType::get(shape, seqType.getEleTy());
1647-
}
1648-
fir::BoxType ptrBoxType =
1649-
fir::BoxType::get(fir::PointerType::get(baseType));
1642+
if (sym.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
1643+
mlir::Type ptrBoxType =
1644+
Fortran::lower::getCrayPointeeBoxType(base.getType());
16501645
mlir::Value boxAlloc = builder.createTemporary(loc, ptrBoxType);
16511646

16521647
// Declare a local pointer variable.
1653-
attributes = fir::FortranVariableFlagsAttr::get(
1654-
builder.getContext(), fir::FortranVariableFlagsEnum::pointer);
16551648
auto newBase = builder.create<hlfir::DeclareOp>(
16561649
loc, boxAlloc, name, /*shape=*/nullptr, lenParams, attributes);
1657-
mlir::Value nullAddr =
1658-
builder.createNullConstant(loc, ptrBoxType.getEleTy());
1650+
mlir::Value nullAddr = builder.createNullConstant(
1651+
loc, llvm::cast<fir::BaseBoxType>(ptrBoxType).getEleTy());
16591652

16601653
// If the element type is known-length character, then
16611654
// EmboxOp does not need the length parameters.
16621655
if (auto charType = mlir::dyn_cast<fir::CharacterType>(
1663-
fir::unwrapSequenceType(baseType)))
1656+
hlfir::getFortranElementType(base.getType())))
16641657
if (!charType.hasDynamicLen())
16651658
lenParams.clear();
16661659

@@ -2346,16 +2339,13 @@ void Fortran::lower::createRuntimeTypeInfoGlobal(
23462339
defineGlobal(converter, var, globalName, linkage);
23472340
}
23482341

2349-
Fortran::semantics::SymbolRef
2350-
Fortran::lower::getCrayPointer(Fortran::semantics::SymbolRef sym) {
2351-
assert(!sym->GetUltimate().owner().crayPointers().empty() &&
2352-
"empty Cray pointer/pointee map");
2353-
for (const auto &[pointee, pointer] :
2354-
sym->GetUltimate().owner().crayPointers()) {
2355-
if (pointee == sym->name()) {
2356-
Fortran::semantics::SymbolRef v{pointer.get()};
2357-
return v;
2358-
}
2342+
mlir::Type Fortran::lower::getCrayPointeeBoxType(mlir::Type fortranType) {
2343+
mlir::Type baseType = hlfir::getFortranElementOrSequenceType(fortranType);
2344+
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(baseType)) {
2345+
// The pointer box's sequence type must be with unknown shape.
2346+
llvm::SmallVector<int64_t> shape(seqType.getDimension(),
2347+
fir::SequenceType::getUnknownExtent());
2348+
baseType = fir::SequenceType::get(shape, seqType.getEleTy());
23592349
}
2360-
llvm_unreachable("corresponding Cray pointer cannot be found");
2350+
return fir::BoxType::get(fir::PointerType::get(baseType));
23612351
}

flang/lib/Lower/HostAssociations.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,11 @@ class CapturedAllocatableAndPointer
315315
public:
316316
static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
317317
const Fortran::semantics::Symbol &sym) {
318-
return fir::ReferenceType::get(converter.genType(sym));
318+
mlir::Type baseType = converter.genType(sym);
319+
if (sym.GetUltimate().test(Fortran::semantics::Symbol::Flag::CrayPointee))
320+
return fir::ReferenceType::get(
321+
Fortran::lower::getCrayPointeeBoxType(baseType));
322+
return fir::ReferenceType::get(baseType);
319323
}
320324
static void instantiateHostTuple(const InstantiateHostTuple &args,
321325
Fortran::lower::AbstractConverter &converter,
@@ -507,7 +511,8 @@ walkCaptureCategories(T visitor, Fortran::lower::AbstractConverter &converter,
507511
if (Fortran::semantics::IsProcedure(sym))
508512
return CapturedProcedure::visit(visitor, converter, sym, ba);
509513
ba.analyze(sym);
510-
if (Fortran::semantics::IsAllocatableOrPointer(sym))
514+
if (Fortran::semantics::IsAllocatableOrPointer(sym) ||
515+
sym.GetUltimate().test(Fortran::semantics::Symbol::Flag::CrayPointee))
511516
return CapturedAllocatableAndPointer::visit(visitor, converter, sym, ba);
512517
if (ba.isArray())
513518
return CapturedArrays::visit(visitor, converter, sym, ba);

flang/lib/Lower/PFTBuilder.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,11 @@ struct SymbolDependenceAnalysis {
15941594
if (!s->has<semantics::DerivedTypeDetails>())
15951595
depth = std::max(analyze(s) + 1, depth);
15961596
}
1597+
1598+
// Make sure cray pointer is instantiated even if it is not visible.
1599+
if (ultimate.test(Fortran::semantics::Symbol::Flag::CrayPointee))
1600+
depth = std::max(
1601+
analyze(Fortran::semantics::GetCrayPointer(ultimate)) + 1, depth);
15971602
adjustSize(depth + 1);
15981603
bool global = lower::symbolIsGlobal(sym);
15991604
layeredVarList[depth].emplace_back(sym, global, depth);
@@ -2002,6 +2007,10 @@ struct SymbolVisitor {
20022007
}
20032008
}
20042009
}
2010+
// - CrayPointer needs to be available whenever a CrayPointee is used.
2011+
if (symbol.GetUltimate().test(
2012+
Fortran::semantics::Symbol::Flag::CrayPointee))
2013+
visitSymbol(Fortran::semantics::GetCrayPointer(symbol));
20052014
}
20062015

20072016
template <typename A>

flang/lib/Semantics/tools.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,18 @@ const Symbol &BypassGeneric(const Symbol &symbol) {
403403
return symbol;
404404
}
405405

406+
const Symbol &GetCrayPointer(const Symbol &crayPointee) {
407+
const Symbol *found{nullptr};
408+
for (const auto &[pointee, pointer] :
409+
crayPointee.GetUltimate().owner().crayPointers()) {
410+
if (pointee == crayPointee.name()) {
411+
found = &pointer.get();
412+
break;
413+
}
414+
}
415+
return DEREF(found);
416+
}
417+
406418
bool ExprHasTypeCategory(
407419
const SomeExpr &expr, const common::TypeCategory &type) {
408420
auto dynamicType{expr.GetType()};

0 commit comments

Comments
 (0)