Skip to content

Commit f8843ef

Browse files
authored
[flang][hlfir] Lower Cray pointee references. (#65563)
A Cray pointee reference must be done using the characteristics (bounds, type params) of the original pointee declaration, but using the actual address value of the associated Cray pointer. There might be multiple Cray pointees associated with the same Cray pointer. The proposed solution is to lower each Cray pointee into a POINTER variable with a descriptor. The descriptor is initialized at the point of declaration of the pointee, though its base_addr is set to null. Before each reference of the Cray pointee its descriptor's base_addr is updated to the current value of the Cray pointer. The update of the base_addr is done using PointerAssociateScalar runtime call, which just updates the base_addr of the descriptor. This is a temporary solution just to make Cray pointers work to the same extent they work with FIR lowering.
1 parent 8835921 commit f8843ef

File tree

10 files changed

+351
-19
lines changed

10 files changed

+351
-19
lines changed

flang/include/flang/Lower/ConvertExpr.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,6 @@ inline mlir::NamedAttribute getAdaptToByRefAttr(fir::FirOpBuilder &builder) {
234234
builder.getUnitAttr()};
235235
}
236236

237-
Fortran::semantics::SymbolRef getPointer(Fortran::semantics::SymbolRef sym);
238237
mlir::Value addCrayPointerInst(mlir::Location loc, fir::FirOpBuilder &builder,
239238
mlir::Value ptrVal, mlir::Type ptrTy,
240239
mlir::Type pteTy);

flang/include/flang/Lower/ConvertVariable.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,5 +125,9 @@ void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
125125
const Fortran::semantics::Symbol &sym,
126126
const fir::ExtendedValue &exv, bool force = false);
127127

128+
/// For the given Cray pointee symbol return the corresponding
129+
/// Cray pointer symbol. Assert if the pointer symbol cannot be found.
130+
Fortran::semantics::SymbolRef getCrayPointer(Fortran::semantics::SymbolRef sym);
131+
128132
} // namespace Fortran::lower
129133
#endif // FORTRAN_LOWER_CONVERT_VARIABLE_H
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===-- Pointer.h - generate pointer runtime API calls-----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_POINTER_H
10+
#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_POINTER_H
11+
12+
#include "mlir/IR/Value.h"
13+
14+
namespace mlir {
15+
class Location;
16+
} // namespace mlir
17+
18+
namespace fir {
19+
class FirOpBuilder;
20+
}
21+
22+
namespace fir::runtime {
23+
24+
/// Generate runtime call to associate \p target address of scalar
25+
/// with the \p desc pointer descriptor.
26+
void genPointerAssociateScalar(fir::FirOpBuilder &builder, mlir::Location loc,
27+
mlir::Value desc, mlir::Value target);
28+
29+
} // namespace fir::runtime
30+
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_POINTER_H

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3629,7 +3629,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
36293629
sym->Rank() == 0) {
36303630
// get the corresponding Cray pointer
36313631

3632-
auto ptrSym = Fortran::lower::getPointer(*sym);
3632+
auto ptrSym = Fortran::lower::getCrayPointer(*sym);
36333633
fir::ExtendedValue ptr =
36343634
getSymbolExtendedValue(ptrSym, nullptr);
36353635
mlir::Value ptrVal = fir::getBase(ptr);

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,7 @@ class ScalarExprLowering {
863863
addr);
864864
} else if (sym->test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
865865
// get the corresponding Cray pointer
866-
auto ptrSym = Fortran::lower::getPointer(sym);
866+
auto ptrSym = Fortran::lower::getCrayPointer(sym);
867867
ExtValue ptr = gen(ptrSym);
868868
mlir::Value ptrVal = fir::getBase(ptr);
869869
mlir::Type ptrTy = converter.genType(*ptrSym);
@@ -1571,7 +1571,7 @@ class ScalarExprLowering {
15711571
auto baseSym = getFirstSym(aref);
15721572
if (baseSym.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
15731573
// get the corresponding Cray pointer
1574-
auto ptrSym = Fortran::lower::getPointer(baseSym);
1574+
auto ptrSym = Fortran::lower::getCrayPointer(baseSym);
15751575

15761576
fir::ExtendedValue ptr = gen(ptrSym);
15771577
mlir::Value ptrVal = fir::getBase(ptr);
@@ -6974,7 +6974,7 @@ class ArrayExprLowering {
69746974
ComponentPath &components) {
69756975
mlir::Value ptrVal = nullptr;
69766976
if (x.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
6977-
auto ptrSym = Fortran::lower::getPointer(x);
6977+
auto ptrSym = Fortran::lower::getCrayPointer(x);
69786978
ExtValue ptr = converter.getSymbolExtendedValue(ptrSym);
69796979
ptrVal = fir::getBase(ptr);
69806980
}
@@ -7629,19 +7629,6 @@ void Fortran::lower::createArrayMergeStores(
76297629
esp.incrementCounter();
76307630
}
76317631

7632-
Fortran::semantics::SymbolRef
7633-
Fortran::lower::getPointer(Fortran::semantics::SymbolRef sym) {
7634-
assert(!sym->owner().crayPointers().empty() &&
7635-
"empty Cray pointer/pointee map");
7636-
for (const auto &[pointee, pointer] : sym->owner().crayPointers()) {
7637-
if (pointee == sym->name()) {
7638-
Fortran::semantics::SymbolRef v{pointer.get()};
7639-
return v;
7640-
}
7641-
}
7642-
llvm_unreachable("corresponding Cray pointer cannot be found");
7643-
}
7644-
76457632
mlir::Value Fortran::lower::addCrayPointerInst(mlir::Location loc,
76467633
fir::FirOpBuilder &builder,
76477634
mlir::Value ptrVal,

flang/lib/Lower/ConvertExprToHLFIR.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "flang/Optimizer/Builder/MutableBox.h"
2929
#include "flang/Optimizer/Builder/Runtime/Character.h"
3030
#include "flang/Optimizer/Builder/Runtime/Derived.h"
31+
#include "flang/Optimizer/Builder/Runtime/Pointer.h"
3132
#include "flang/Optimizer/Builder/Todo.h"
3233
#include "flang/Optimizer/HLFIR/HLFIROps.h"
3334
#include "llvm/ADT/TypeSwitch.h"
@@ -268,8 +269,36 @@ class HlfirDesignatorBuilder {
268269
fir::FortranVariableOpInterface
269270
gen(const Fortran::evaluate::SymbolRef &symbolRef) {
270271
if (std::optional<fir::FortranVariableOpInterface> varDef =
271-
getSymMap().lookupVariableDefinition(symbolRef))
272+
getSymMap().lookupVariableDefinition(symbolRef)) {
273+
if (symbolRef->test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
274+
// The pointee is represented with a descriptor inheriting
275+
// the shape and type parameters of the pointee.
276+
// We have to update the base_addr to point to the current
277+
// value of the Cray pointer variable.
278+
fir::FirOpBuilder &builder = getBuilder();
279+
fir::FortranVariableOpInterface ptrVar =
280+
gen(Fortran::lower::getCrayPointer(symbolRef));
281+
mlir::Value ptrAddr = ptrVar.getBase();
282+
283+
// Reinterpret the reference to a Cray pointer so that
284+
// we have a pointer-compatible value after loading
285+
// the Cray pointer value.
286+
mlir::Type refPtrType = builder.getRefType(
287+
fir::PointerType::get(fir::dyn_cast_ptrEleTy(ptrAddr.getType())));
288+
mlir::Value cast = builder.createConvert(loc, refPtrType, ptrAddr);
289+
mlir::Value ptrVal = builder.create<fir::LoadOp>(loc, cast);
290+
291+
// Update the base_addr to the value of the Cray pointer.
292+
// This is a hacky way to do the update, and it may harm
293+
// performance around Cray pointer references.
294+
// TODO: we should introduce an operation that updates
295+
// just the base_addr of the given box. The CodeGen
296+
// will just convert it into a single store.
297+
fir::runtime::genPointerAssociateScalar(builder, loc, varDef->getBase(),
298+
ptrVal);
299+
}
272300
return *varDef;
301+
}
273302
TODO(getLoc(), "lowering symbol to HLFIR");
274303
}
275304

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,6 +1477,8 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
14771477
if (converter.getLoweringOptions().getLowerToHighLevelFIR() &&
14781478
!Fortran::semantics::IsProcedure(sym) &&
14791479
!sym.detailsIf<Fortran::semantics::CommonBlockDetails>()) {
1480+
bool isCrayPointee =
1481+
sym.test(Fortran::semantics::Symbol::Flag::CrayPointee);
14801482
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
14811483
const mlir::Location loc = genLocation(converter, sym);
14821484
mlir::Value shapeOrShift;
@@ -1492,6 +1494,51 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
14921494
auto name = converter.mangleName(sym);
14931495
fir::FortranVariableFlagsAttr attributes =
14941496
Fortran::lower::translateSymbolAttributes(builder.getContext(), sym);
1497+
1498+
if (isCrayPointee) {
1499+
mlir::Type baseType =
1500+
hlfir::getFortranElementOrSequenceType(base.getType());
1501+
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(baseType)) {
1502+
// The pointer box's sequence type must be with unknown shape.
1503+
llvm::SmallVector<int64_t> shape(seqType.getDimension(),
1504+
fir::SequenceType::getUnknownExtent());
1505+
baseType = fir::SequenceType::get(shape, seqType.getEleTy());
1506+
}
1507+
fir::BoxType ptrBoxType =
1508+
fir::BoxType::get(fir::PointerType::get(baseType));
1509+
mlir::Value boxAlloc = builder.createTemporary(loc, ptrBoxType);
1510+
1511+
// Declare a local pointer variable.
1512+
attributes = fir::FortranVariableFlagsAttr::get(
1513+
builder.getContext(), fir::FortranVariableFlagsEnum::pointer);
1514+
auto newBase = builder.create<hlfir::DeclareOp>(
1515+
loc, boxAlloc, name, /*shape=*/nullptr, lenParams, attributes);
1516+
mlir::Value nullAddr =
1517+
builder.createNullConstant(loc, ptrBoxType.getEleTy());
1518+
1519+
// If the element type is known-length character, then
1520+
// EmboxOp does not need the length parameters.
1521+
if (auto charType = mlir::dyn_cast<fir::CharacterType>(
1522+
fir::unwrapSequenceType(baseType)))
1523+
if (!charType.hasDynamicLen())
1524+
lenParams.clear();
1525+
1526+
// Inherit the shape (and maybe length parameters) from the pointee
1527+
// declaration.
1528+
mlir::Value initVal =
1529+
builder.create<fir::EmboxOp>(loc, ptrBoxType, nullAddr, shapeOrShift,
1530+
/*slice=*/nullptr, lenParams);
1531+
builder.create<fir::StoreOp>(loc, initVal, newBase.getBase());
1532+
1533+
// Any reference to the pointee is going to be using the pointer
1534+
// box from now on. The base_addr of the descriptor must be updated
1535+
// to hold the value of the Cray pointer at the point of the pointee
1536+
// access.
1537+
// Note that the same Cray pointer may be associated with
1538+
// multiple pointees and each of them has its own descriptor.
1539+
symMap.addVariableDefinition(sym, newBase, force);
1540+
return;
1541+
}
14951542
auto newBase = builder.create<hlfir::DeclareOp>(
14961543
loc, base, name, shapeOrShift, lenParams, attributes);
14971544
symMap.addVariableDefinition(sym, newBase, force);
@@ -2056,3 +2103,16 @@ void Fortran::lower::createRuntimeTypeInfoGlobal(
20562103
mlir::StringAttr linkage = getLinkageAttribute(builder, var);
20572104
defineGlobal(converter, var, globalName, linkage);
20582105
}
2106+
2107+
Fortran::semantics::SymbolRef
2108+
Fortran::lower::getCrayPointer(Fortran::semantics::SymbolRef sym) {
2109+
assert(!sym->owner().crayPointers().empty() &&
2110+
"empty Cray pointer/pointee map");
2111+
for (const auto &[pointee, pointer] : sym->owner().crayPointers()) {
2112+
if (pointee == sym->name()) {
2113+
Fortran::semantics::SymbolRef v{pointer.get()};
2114+
return v;
2115+
}
2116+
}
2117+
llvm_unreachable("corresponding Cray pointer cannot be found");
2118+
}

flang/lib/Optimizer/Builder/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ add_flang_library(FIRBuilder
2222
Runtime/Inquiry.cpp
2323
Runtime/Intrinsics.cpp
2424
Runtime/Numeric.cpp
25+
Runtime/Pointer.cpp
2526
Runtime/Ragged.cpp
2627
Runtime/Reduction.cpp
2728
Runtime/Stop.cpp
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===-- Pointer.cpp -- generate pointer runtime API calls------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Optimizer/Builder/Runtime/Pointer.h"
10+
#include "flang/Optimizer/Builder/FIRBuilder.h"
11+
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
12+
#include "flang/Runtime/pointer.h"
13+
14+
using namespace Fortran::runtime;
15+
16+
void fir::runtime::genPointerAssociateScalar(fir::FirOpBuilder &builder,
17+
mlir::Location loc,
18+
mlir::Value desc,
19+
mlir::Value target) {
20+
mlir::func::FuncOp func{
21+
fir::runtime::getRuntimeFunc<mkRTKey(PointerAssociateScalar)>(loc,
22+
builder)};
23+
mlir::FunctionType fTy{func.getFunctionType()};
24+
llvm::SmallVector<mlir::Value> args{
25+
fir::runtime::createArguments(builder, loc, fTy, desc, target)};
26+
builder.create<fir::CallOp>(loc, func, args);
27+
}

0 commit comments

Comments
 (0)