Commit 476e443

[Flang][OpenMP] Align map clause generation and fix issue with non-shared allocations for assumed shape/size descriptor types
This PR unifies map argument generation across the implicit capture (variables captured in a target region) and the explicit capture (processMap) paths. Currently the varPtr field of the MapInfo for the same variable differs depending on how the variable is captured. This PR aligns that across the places that generate a MapInfoOp in the OpenMP lowering.

Currently, I have opted to utilise the rawInput (the input memref to a HLFIR DeclareInfoOp) as opposed to the addr field, which includes more information. The side effect of this is that we have to deal with BoxTypes less often, which results in simpler maps in these cases. The negative side effect is that we don't have access to the bounds information through the resulting value. However, I believe the bounds information we require is still appropriately stored in the map bounds, and this seems to be the case from testing so far.

The other fix is for cases where we end up with a BoxType argument into a function (certain assumed shape and assumed size cases do this) that has no fir.ref wrapping it. As we need the Box to be a reference type to utilise the operation that accesses the base address stored inside it and to create the correct mappings, we currently generate an intermediate allocation in these cases, store into it, and utilise it as the map argument instead of the original. However, as we were not sharing the same intermediate allocation across all of the maps for a variable, this resulted in errors in certain cases when detaching/attaching the data, e.g. via enter and exit. This PR adjusts this for these cases by reusing a single intermediate allocation per descriptor.

Currently we only track intermediate allocations for the current function scope, as opposed to module scope, primarily because the only case I am aware of that requires this is passing certain types of arguments to functions (so I opted to minimise the overhead of the pass for now). It could likely be extended to module scope if we find other cases where it is applicable and causing issues.
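The alignment described above can be pictured with a small standalone model. This is only a sketch: `Value`, `selectMapBase`, `explicitCaptureVarPtr` and `implicitCaptureVarPtr` are hypothetical names that stand in for the real MLIR values and lowering entry points, and the two-field struct is a simplified picture of `lower::AddrAndBoundsInfo`. The point it illustrates is that both capture paths now pick the same field, so the same variable yields the same varPtr.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for an SSA value; only its identity matters here.
struct Value {
  std::string name;
  bool operator==(const Value &other) const { return name == other.name; }
};

// Simplified model of lower::AddrAndBoundsInfo: `addr` may be a boxed form
// of the variable, `rawInput` is the value this commit prefers for maps.
struct AddrAndBoundsInfo {
  Value addr;
  Value rawInput;
};

// Single selection point shared by both capture paths (assumed helper,
// not a function that exists in flang).
inline Value selectMapBase(const AddrAndBoundsInfo &info) {
  assert(!info.rawInput.name.empty() && "expected a raw input value");
  return info.rawInput;
}

// Explicit capture, i.e. a variable named in a map clause.
inline Value explicitCaptureVarPtr(const AddrAndBoundsInfo &info) {
  return selectMapBase(info);
}

// Implicit capture, i.e. a variable referenced inside a target region.
inline Value implicitCaptureVarPtr(const AddrAndBoundsInfo &info) {
  return selectMapBase(info);
}
```

Routing both paths through one selection function is a design choice of the sketch, not of the commit, which simply uses `info.rawInput` at each site.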
1 parent 788731c commit 476e443

9 files changed

+141
-127
lines changed

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 3 additions & 3 deletions
@@ -341,11 +341,11 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> {
 }

 def OMPMapInfoFinalizationPass
-    : Pass<"omp-map-info-finalization"> {
+    : Pass<"omp-map-info-finalization", "mlir::func::FuncOp"> {
   let summary = "expands OpenMP MapInfo operations containing descriptors";
   let description = [{
-    Expands MapInfo operations containing descriptor types into multiple
-    MapInfo's for each pointer element in the descriptor that requires
+    Expands MapInfo operations containing descriptor types into multiple
+    MapInfo's for each pointer element in the descriptor that requires
     explicit individual mapping by the OpenMP runtime.
   }];
   let dependentDialects = ["mlir::omp::OpenMPDialect"];

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 6 additions & 10 deletions
@@ -970,25 +970,21 @@ bool ClauseProcessor::processMap(
                 object.ref(), clauseLocation, asFortran, bounds,
                 treatIndexAsSection);

-        auto origSymbol = converter.getSymbolAddress(*object.sym());
-        mlir::Value symAddr = info.addr;
-        if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
-          symAddr = origSymbol;
-
         // Explicit map captures are captured ByRef by default,
         // optimisation passes may alter this to ByCopy or other capture
         // types to optimise
+        mlir::Value baseOp = info.rawInput;
         auto location = mlir::NameLoc::get(
             mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
-            symAddr.getLoc());
+            baseOp.getLoc());
         mlir::omp::MapInfoOp mapOp = createMapInfoOp(
-            firOpBuilder, location, symAddr,
+            firOpBuilder, location, baseOp,
             /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
             /*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
             static_cast<
                 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
                 mapTypeBits),
-            mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
+            mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());

         if (object.sym()->owner().IsDerivedType()) {
           addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
@@ -997,9 +993,9 @@ bool ClauseProcessor::processMap(
           result.mapVars.push_back(mapOp);
           ptrMapSyms->push_back(object.sym());
           if (mapSymTypes)
-            mapSymTypes->push_back(symAddr.getType());
+            mapSymTypes->push_back(baseOp.getType());
           if (mapSymLocs)
-            mapSymLocs->push_back(symAddr.getLoc());
+            mapSymLocs->push_back(baseOp.getLoc());
         }
       }
     });

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 3 additions & 7 deletions
@@ -212,22 +212,18 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
               object.ref(), clauseLocation, asFortran, bounds,
               treatIndexAsSection);

-      auto origSymbol = converter.getSymbolAddress(*object.sym());
-      mlir::Value symAddr = info.addr;
-      if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
-        symAddr = origSymbol;
-
       // Explicit map captures are captured ByRef by default,
       // optimisation passes may alter this to ByCopy or other capture
       // types to optimise
+      mlir::Value baseOp = info.rawInput;
       mlir::omp::MapInfoOp mapOp = createMapInfoOp(
-          firOpBuilder, clauseLocation, symAddr,
+          firOpBuilder, clauseLocation, baseOp,
           /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
           /*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
           static_cast<
               std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
               mapTypeBits),
-          mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
+          mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());

       if (object.sym()->owner().IsDerivedType()) {
         addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 79 additions & 73 deletions
@@ -1670,92 +1670,98 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     if (dsp.getAllSymbolsToPrivatize().contains(&sym))
       return;

+    // Structure component symbols don't have bindings, and can only be
+    // explicitly mapped individually. If a member is captured implicitly
+    // we map the entirety of the derived type when we find its symbol.
+    if (sym.owner().IsDerivedType())
+      return;
+
     // if the symbol is part of an already mapped common block, do not make a
     // map for it.
     if (const Fortran::semantics::Symbol *common =
             Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate()))
       if (llvm::find(mapSyms, common) != mapSyms.end())
         return;

-    if (llvm::find(mapSyms, &sym) == mapSyms.end()) {
-      mlir::Value baseOp = converter.getSymbolAddress(sym);
-      if (!baseOp)
-        if (const auto *details =
-                sym.template detailsIf<semantics::HostAssocDetails>()) {
-          baseOp = converter.getSymbolAddress(details->symbol());
-          converter.copySymbolBinding(details->symbol(), sym);
-        }
+    // If we come across a symbol without a symbol address, we
+    // return as we cannot process it, this is intended as a
+    // catch all early exit for symbols that do not have a
+    // corresponding extended value. Such as subroutines,
+    // interfaces and named blocks.
+    if (!converter.getSymbolAddress(sym))
+      return;

-      if (baseOp) {
-        llvm::SmallVector<mlir::Value> bounds;
-        std::stringstream name;
-        fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
-        name << sym.name().ToString();
-
-        lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
-            converter, firOpBuilder, sym, converter.getCurrentLocation());
-        if (mlir::isa<fir::BaseBoxType>(
-                fir::unwrapRefType(info.addr.getType())))
-          bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                              mlir::omp::MapBoundsType>(
-              firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-        if (mlir::isa<fir::SequenceType>(
-                fir::unwrapRefType(info.addr.getType()))) {
-          bool dataExvIsAssumedSize =
-              semantics::IsAssumedSizeArray(sym.GetUltimate());
-          bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                           mlir::omp::MapBoundsType>(
-              firOpBuilder, converter.getCurrentLocation(), dataExv,
-              dataExvIsAssumedSize);
-        }
+    if (llvm::find(mapSyms, &sym) == mapSyms.end()) {
+      if (const auto *details =
+              sym.template detailsIf<semantics::HostAssocDetails>())
+        converter.copySymbolBinding(details->symbol(), sym);
+      llvm::SmallVector<mlir::Value> bounds;
+      std::stringstream name;
+      fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
+      name << sym.name().ToString();
+
+      lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
+          converter, firOpBuilder, sym, converter.getCurrentLocation());
+      mlir::Value baseOp = info.rawInput;
+      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
+                                            mlir::omp::MapBoundsType>(
+            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
+      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+        bool dataExvIsAssumedSize =
+            semantics::IsAssumedSizeArray(sym.GetUltimate());
+        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
+                                         mlir::omp::MapBoundsType>(
+            firOpBuilder, converter.getCurrentLocation(), dataExv,
+            dataExvIsAssumedSize);
+      }

-        llvm::omp::OpenMPOffloadMappingFlags mapFlag =
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
-        mlir::omp::VariableCaptureKind captureKind =
-            mlir::omp::VariableCaptureKind::ByRef;
-
-        mlir::Type eleType = baseOp.getType();
-        if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
-          eleType = refType.getElementType();
-
-        // If a variable is specified in declare target link and if device
-        // type is not specified as `nohost`, it needs to be mapped tofrom
-        mlir::ModuleOp mod = firOpBuilder.getModule();
-        mlir::Operation *op = mod.lookupSymbol(converter.mangleName(sym));
-        auto declareTargetOp =
-            llvm::dyn_cast_if_present<mlir::omp::DeclareTargetInterface>(op);
-        if (declareTargetOp && declareTargetOp.isDeclareTarget()) {
-          if (declareTargetOp.getDeclareTargetCaptureClause() ==
-                  mlir::omp::DeclareTargetCaptureClause::link &&
-              declareTargetOp.getDeclareTargetDeviceType() !=
-                  mlir::omp::DeclareTargetDeviceType::nohost) {
-            mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
-            mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
-          }
-        } else if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) {
-          captureKind = mlir::omp::VariableCaptureKind::ByCopy;
-        } else if (!fir::isa_builtin_cptr_type(eleType)) {
+      llvm::omp::OpenMPOffloadMappingFlags mapFlag =
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+      mlir::omp::VariableCaptureKind captureKind =
+          mlir::omp::VariableCaptureKind::ByRef;
+
+      mlir::Type eleType = baseOp.getType();
+      if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
+        eleType = refType.getElementType();
+
+      // If a variable is specified in declare target link and if device
+      // type is not specified as `nohost`, it needs to be mapped tofrom
+      mlir::ModuleOp mod = firOpBuilder.getModule();
+      mlir::Operation *op = mod.lookupSymbol(converter.mangleName(sym));
+      auto declareTargetOp =
+          llvm::dyn_cast_if_present<mlir::omp::DeclareTargetInterface>(op);
+      if (declareTargetOp && declareTargetOp.isDeclareTarget()) {
+        if (declareTargetOp.getDeclareTargetCaptureClause() ==
+                mlir::omp::DeclareTargetCaptureClause::link &&
+            declareTargetOp.getDeclareTargetDeviceType() !=
+                mlir::omp::DeclareTargetDeviceType::nohost) {
           mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
           mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
         }
-        auto location =
-            mlir::NameLoc::get(mlir::StringAttr::get(firOpBuilder.getContext(),
-                                                     sym.name().ToString()),
-                               baseOp.getLoc());
-        mlir::Value mapOp = createMapInfoOp(
-            firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{},
-            name.str(), bounds, /*members=*/{},
-            /*membersIndex=*/mlir::DenseIntElementsAttr{},
-            static_cast<
-                std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
-                mapFlag),
-            captureKind, baseOp.getType());
-
-        clauseOps.mapVars.push_back(mapOp);
-        mapSyms.push_back(&sym);
-        mapLocs.push_back(baseOp.getLoc());
-        mapTypes.push_back(baseOp.getType());
+      } else if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) {
+        captureKind = mlir::omp::VariableCaptureKind::ByCopy;
+      } else if (!fir::isa_builtin_cptr_type(eleType)) {
+        mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+        mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
       }
+      auto location =
+          mlir::NameLoc::get(mlir::StringAttr::get(firOpBuilder.getContext(),
+                                                   sym.name().ToString()),
+                             baseOp.getLoc());
+      mlir::Value mapOp = createMapInfoOp(
+          firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{},
+          name.str(), bounds, /*members=*/{},
+          /*membersIndex=*/mlir::DenseIntElementsAttr{},
+          static_cast<
+              std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+              mapFlag),
+          captureKind, baseOp.getType());
+
+      clauseOps.mapVars.push_back(mapOp);
+      mapSyms.push_back(&sym);
+      mapLocs.push_back(baseOp.getLoc());
+      mapTypes.push_back(baseOp.getType());
     }
   };
   lower::pft::visitAllSymbols(eval, captureImplicitMap);
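
The flag and capture-kind selection in the implicit capture lambda above can be read as a small decision function. The following is a standalone sketch of that branching only, under stated assumptions: `ElementKind`, `SymbolTraits`, `ImplicitMap` and `classifyImplicitCapture` are hypothetical names, the type categories stand in for the `fir::isa_trivial`, `fir::isa_char` and `fir::isa_builtin_cptr_type` queries, and the numeric flag values are illustrative constants named after the `OpenMPOffloadMappingFlags` members used in the diff.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative flag constants, named after the flags used in the diff.
constexpr std::uint64_t OMP_MAP_TO = 0x01;
constexpr std::uint64_t OMP_MAP_FROM = 0x02;
constexpr std::uint64_t OMP_MAP_IMPLICIT = 0x200;

enum class CaptureKind { ByRef, ByCopy };

// Hypothetical coarse categories standing in for the fir:: type queries.
enum class ElementKind { Trivial, Character, BuiltinCPtr, Other };

// Hypothetical summary of what the lambda inspects for one symbol.
struct SymbolTraits {
  ElementKind elementKind;
  bool isDeclareTarget; // symbol has a declare target attribute
  bool isLinkClause;    // declare target capture clause is `link`
  bool isNoHost;        // declare target device type is `nohost`
};

struct ImplicitMap {
  std::uint64_t flags;
  CaptureKind capture;
};

// Mirrors the if / else-if chain of the implicit capture path.
inline ImplicitMap classifyImplicitCapture(const SymbolTraits &s) {
  ImplicitMap result{OMP_MAP_IMPLICIT, CaptureKind::ByRef};
  if (s.isDeclareTarget) {
    // Only `link` variables that are not `nohost` are mapped tofrom.
    if (s.isLinkClause && !s.isNoHost)
      result.flags |= OMP_MAP_TO | OMP_MAP_FROM;
  } else if (s.elementKind == ElementKind::Trivial ||
             s.elementKind == ElementKind::Character) {
    // Scalars and characters are passed by copy with no extra flags.
    result.capture = CaptureKind::ByCopy;
  } else if (s.elementKind != ElementKind::BuiltinCPtr) {
    // Everything else except builtin c_ptr is mapped tofrom by reference.
    result.flags |= OMP_MAP_TO | OMP_MAP_FROM;
  }
  return result;
}
```

Note that the commit does not change this branching. It only changes which value (`info.rawInput`) supplies the type that feeds the element-type checks.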

flang/lib/Optimizer/Transforms/OMPMapInfoFinalization.cpp

Lines changed: 31 additions & 13 deletions
@@ -51,6 +51,14 @@ class OMPMapInfoFinalizationPass
     : public fir::impl::OMPMapInfoFinalizationPassBase<
           OMPMapInfoFinalizationPass> {

+  /// Tracks any intermediate function/subroutine local allocations we
+  /// generate for the descriptors of box type dummy arguments, so that
+  /// we can retrieve it for subsequent reuses within the functions
+  /// scope
+  std::map</*descriptor opaque pointer=*/void *,
+           /*corresponding local alloca=*/fir::AllocaOp>
+      localBoxAllocas;
+
   void genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
                                fir::FirOpBuilder &builder,
                                mlir::Operation *target) {
@@ -75,14 +83,26 @@ class OMPMapInfoFinalizationPass
     // perform an alloca and then store to it and retrieve the data from the new
     // alloca.
     if (mlir::isa<fir::BaseBoxType>(descriptor.getType())) {
-      mlir::OpBuilder::InsertPoint insPt = builder.saveInsertionPoint();
-      mlir::Block *allocaBlock = builder.getAllocaBlock();
-      assert(allocaBlock && "No alloca block found for this top level op");
-      builder.setInsertionPointToStart(allocaBlock);
-      auto alloca = builder.create<fir::AllocaOp>(loc, descriptor.getType());
-      builder.restoreInsertionPoint(insPt);
-      builder.create<fir::StoreOp>(loc, descriptor, alloca);
-      descriptor = alloca;
+      // if we have already created a local allocation for this BoxType,
+      // we must be sure to re-use it so that we end up with the same
+      // allocations being utilised for the same descriptor across all map uses,
+      // this prevents runtime issues such as not appropriately releasing or
+      // deleting all mapped data.
+      auto find = localBoxAllocas.find(descriptor.getAsOpaquePointer());
+      if (find != localBoxAllocas.end()) {
+        builder.create<fir::StoreOp>(loc, descriptor, find->second);
+        descriptor = find->second;
+      } else {
+        mlir::OpBuilder::InsertPoint insPt = builder.saveInsertionPoint();
+        mlir::Block *allocaBlock = builder.getAllocaBlock();
+        assert(allocaBlock && "No alloca block found for this top level op");
+        builder.setInsertionPointToStart(allocaBlock);
+        auto alloca = builder.create<fir::AllocaOp>(loc, descriptor.getType());
+        builder.restoreInsertionPoint(insPt);
+        builder.create<fir::StoreOp>(loc, descriptor, alloca);
+        localBoxAllocas[descriptor.getAsOpaquePointer()] = alloca;
+        descriptor = alloca;
+      }
     }

     mlir::Value baseAddrAddr = builder.create<fir::BoxOffsetOp>(
@@ -228,14 +248,12 @@ class OMPMapInfoFinalizationPass
   // operation (usually function) containing the MapInfoOp because this pass
   // will mutate siblings of MapInfoOp.
   void runOnOperation() override {
-    mlir::ModuleOp module =
-        mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation());
-    if (!module)
-      module = getOperation()->getParentOfType<mlir::ModuleOp>();
+    mlir::func::FuncOp func = getOperation();
+    mlir::ModuleOp module = func->getParentOfType<mlir::ModuleOp>();
     fir::KindMapping kindMap = fir::getKindMapping(module);
     fir::FirOpBuilder builder{module, std::move(kindMap)};

-    getOperation()->walk([&](mlir::omp::MapInfoOp op) {
+    func->walk([&](mlir::omp::MapInfoOp op) {
       // TODO: Currently only supports a single user for the MapInfoOp, this
       // is fine for the moment as the Fortran Frontend will generate a
       // new MapInfoOp per Target operation for the moment. However, when/if