Commit f4cf93f

[Flang][OpenMP] Align map clause generation and fix issue with non-shared allocations for assumed shape/size descriptor types (llvm#97855)
This PR aims to unify the map argument generation behavior across both the implicit capture (captured in a target region) and the explicit capture (process map). Currently, the varPtr field of the MapInfo for the same variable differs depending on how it is captured. This PR aligns that across the places where MapInfoOp is generated in the OpenMP lowering.

Currently, I have opted to utilise the rawInput (input memref to a HLFIR DeclareInfoOp) as opposed to the addr field, which includes more information. The side effect of this is that we have to deal with BoxTypes less often, which results in simpler maps in these cases. The negative side effect is that we don't have access to the bounds information through the resulting value. However, I believe the bounds information we require in our case is still appropriately stored in the map bounds, and this seems to be the case from testing so far.

The other fix is for cases where we end up with a BoxType argument into a function (certain assumed shape and assumed size cases do this) that has no fir.ref wrapping it. As we need the Box to be a reference type to utilise the operation that accesses the base address stored inside and create the correct mappings, we currently generate an intermediate allocation in these cases, store into it, and utilise this as the map argument instead of the original. However, as we were not sharing the same intermediate allocation across all of the maps for a variable, this resulted in errors in certain cases when detaching/attaching the data, e.g. via enter and exit. This PR adjusts this for these cases.

Currently we only track intermediate allocations for the current function scope, as opposed to module scope, primarily because the only case I am aware of where this is required is when we pass certain types of arguments to functions (so I opted to minimize the overhead of the pass for now). It could be extended to module scope if we find other cases where it's applicable and causing issues.
1 parent f607102 commit f4cf93f

11 files changed: +244 -142 lines changed

flang/include/flang/Optimizer/OpenMP/Passes.td

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 include "mlir/Pass/PassBase.td"

 def MapInfoFinalizationPass
-    : Pass<"omp-map-info-finalization"> {
+    : Pass<"omp-map-info-finalization", "mlir::ModuleOp"> {
   let summary = "expands OpenMP MapInfo operations containing descriptors";
   let description = [{
     Expands MapInfo operations containing descriptor types into multiple

flang/include/flang/Tools/CLOptions.inc

Lines changed: 1 addition & 2 deletions
@@ -367,8 +367,7 @@ inline void createHLFIRToFIRPassPipeline(
 /// rather than the host device.
 inline void createOpenMPFIRPassPipeline(
     mlir::PassManager &pm, bool isTargetDevice) {
-  addNestedPassToAllTopLevelOperations(
-      pm, flangomp::createMapInfoFinalizationPass);
+  pm.addPass(flangomp::createMapInfoFinalizationPass());
   pm.addPass(flangomp::createMarkDeclareTargetPass());
   if (isTargetDevice)
     pm.addPass(flangomp::createFunctionFiltering());

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 6 additions & 10 deletions
@@ -960,25 +960,21 @@ bool ClauseProcessor::processMap(
                 object.ref(), clauseLocation, asFortran, bounds,
                 treatIndexAsSection);

-        auto origSymbol = converter.getSymbolAddress(*object.sym());
-        mlir::Value symAddr = info.addr;
-        if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
-          symAddr = origSymbol;
-
         // Explicit map captures are captured ByRef by default,
         // optimisation passes may alter this to ByCopy or other capture
         // types to optimise
+        mlir::Value baseOp = info.rawInput;
         auto location = mlir::NameLoc::get(
             mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
-            symAddr.getLoc());
+            baseOp.getLoc());
         mlir::omp::MapInfoOp mapOp = createMapInfoOp(
-            firOpBuilder, location, symAddr,
+            firOpBuilder, location, baseOp,
             /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
             /*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
             static_cast<
                 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
                 mapTypeBits),
-            mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
+            mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());

         if (object.sym()->owner().IsDerivedType()) {
           addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
@@ -987,9 +983,9 @@ bool ClauseProcessor::processMap(
           result.mapVars.push_back(mapOp);
           ptrMapSyms->push_back(object.sym());
           if (mapSymTypes)
-            mapSymTypes->push_back(symAddr.getType());
+            mapSymTypes->push_back(baseOp.getType());
           if (mapSymLocs)
-            mapSymLocs->push_back(symAddr.getLoc());
+            mapSymLocs->push_back(baseOp.getLoc());
         }
       }
     });

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 3 additions & 7 deletions
@@ -211,22 +211,18 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
               object.ref(), clauseLocation, asFortran, bounds,
               treatIndexAsSection);

-      auto origSymbol = converter.getSymbolAddress(*object.sym());
-      mlir::Value symAddr = info.addr;
-      if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
-        symAddr = origSymbol;
-
       // Explicit map captures are captured ByRef by default,
       // optimisation passes may alter this to ByCopy or other capture
       // types to optimise
+      mlir::Value baseOp = info.rawInput;
       mlir::omp::MapInfoOp mapOp = createMapInfoOp(
-          firOpBuilder, clauseLocation, symAddr,
+          firOpBuilder, clauseLocation, baseOp,
           /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
           /*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
           static_cast<
               std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
               mapTypeBits),
-          mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
+          mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());

       if (object.sym()->owner().IsDerivedType()) {
         addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 79 additions & 73 deletions
@@ -1698,92 +1698,98 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     if (dsp.getAllSymbolsToPrivatize().contains(&sym))
       return;

+    // Structure component symbols don't have bindings, and can only be
+    // explicitly mapped individually. If a member is captured implicitly
+    // we map the entirety of the derived type when we find its symbol.
+    if (sym.owner().IsDerivedType())
+      return;
+
     // if the symbol is part of an already mapped common block, do not make a
     // map for it.
     if (const Fortran::semantics::Symbol *common =
             Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate()))
       if (llvm::is_contained(mapSyms, common))
         return;

-    if (!llvm::is_contained(mapSyms, &sym)) {
-      mlir::Value baseOp = converter.getSymbolAddress(sym);
-      if (!baseOp)
-        if (const auto *details =
-                sym.template detailsIf<semantics::HostAssocDetails>()) {
-          baseOp = converter.getSymbolAddress(details->symbol());
-          converter.copySymbolBinding(details->symbol(), sym);
-        }
+    // If we come across a symbol without a symbol address, we
+    // return as we cannot process it, this is intended as a
+    // catch all early exit for symbols that do not have a
+    // corresponding extended value. Such as subroutines,
+    // interfaces and named blocks.
+    if (!converter.getSymbolAddress(sym))
+      return;

-      if (baseOp) {
-        llvm::SmallVector<mlir::Value> bounds;
-        std::stringstream name;
-        fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
-        name << sym.name().ToString();
-
-        lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
-            converter, firOpBuilder, sym, converter.getCurrentLocation());
-        if (mlir::isa<fir::BaseBoxType>(
-                fir::unwrapRefType(info.addr.getType())))
-          bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                              mlir::omp::MapBoundsType>(
-              firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-        if (mlir::isa<fir::SequenceType>(
-                fir::unwrapRefType(info.addr.getType()))) {
-          bool dataExvIsAssumedSize =
-              semantics::IsAssumedSizeArray(sym.GetUltimate());
-          bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                           mlir::omp::MapBoundsType>(
-              firOpBuilder, converter.getCurrentLocation(), dataExv,
-              dataExvIsAssumedSize);
-        }
+    if (!llvm::is_contained(mapSyms, &sym)) {
+      if (const auto *details =
+              sym.template detailsIf<semantics::HostAssocDetails>())
+        converter.copySymbolBinding(details->symbol(), sym);
+      llvm::SmallVector<mlir::Value> bounds;
+      std::stringstream name;
+      fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
+      name << sym.name().ToString();
+
+      lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
+          converter, firOpBuilder, sym, converter.getCurrentLocation());
+      mlir::Value baseOp = info.rawInput;
+      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
+                                            mlir::omp::MapBoundsType>(
+            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
+      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+        bool dataExvIsAssumedSize =
+            semantics::IsAssumedSizeArray(sym.GetUltimate());
+        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
+                                         mlir::omp::MapBoundsType>(
+            firOpBuilder, converter.getCurrentLocation(), dataExv,
+            dataExvIsAssumedSize);
+      }

-        llvm::omp::OpenMPOffloadMappingFlags mapFlag =
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
-        mlir::omp::VariableCaptureKind captureKind =
-            mlir::omp::VariableCaptureKind::ByRef;
-
-        mlir::Type eleType = baseOp.getType();
-        if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
-          eleType = refType.getElementType();
-
-        // If a variable is specified in declare target link and if device
-        // type is not specified as `nohost`, it needs to be mapped tofrom
-        mlir::ModuleOp mod = firOpBuilder.getModule();
-        mlir::Operation *op = mod.lookupSymbol(converter.mangleName(sym));
-        auto declareTargetOp =
-            llvm::dyn_cast_if_present<mlir::omp::DeclareTargetInterface>(op);
-        if (declareTargetOp && declareTargetOp.isDeclareTarget()) {
-          if (declareTargetOp.getDeclareTargetCaptureClause() ==
-                  mlir::omp::DeclareTargetCaptureClause::link &&
-              declareTargetOp.getDeclareTargetDeviceType() !=
-                  mlir::omp::DeclareTargetDeviceType::nohost) {
-            mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
-            mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
-          }
-        } else if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) {
-          captureKind = mlir::omp::VariableCaptureKind::ByCopy;
-        } else if (!fir::isa_builtin_cptr_type(eleType)) {
+      llvm::omp::OpenMPOffloadMappingFlags mapFlag =
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+      mlir::omp::VariableCaptureKind captureKind =
+          mlir::omp::VariableCaptureKind::ByRef;
+
+      mlir::Type eleType = baseOp.getType();
+      if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
+        eleType = refType.getElementType();
+
+      // If a variable is specified in declare target link and if device
+      // type is not specified as `nohost`, it needs to be mapped tofrom
+      mlir::ModuleOp mod = firOpBuilder.getModule();
+      mlir::Operation *op = mod.lookupSymbol(converter.mangleName(sym));
+      auto declareTargetOp =
+          llvm::dyn_cast_if_present<mlir::omp::DeclareTargetInterface>(op);
+      if (declareTargetOp && declareTargetOp.isDeclareTarget()) {
+        if (declareTargetOp.getDeclareTargetCaptureClause() ==
+                mlir::omp::DeclareTargetCaptureClause::link &&
+            declareTargetOp.getDeclareTargetDeviceType() !=
+                mlir::omp::DeclareTargetDeviceType::nohost) {
           mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
           mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
         }
-        auto location =
-            mlir::NameLoc::get(mlir::StringAttr::get(firOpBuilder.getContext(),
-                                                     sym.name().ToString()),
-                               baseOp.getLoc());
-        mlir::Value mapOp = createMapInfoOp(
-            firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{},
-            name.str(), bounds, /*members=*/{},
-            /*membersIndex=*/mlir::DenseIntElementsAttr{},
-            static_cast<
-                std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
-                mapFlag),
-            captureKind, baseOp.getType());
-
-        clauseOps.mapVars.push_back(mapOp);
-        mapSyms.push_back(&sym);
-        mapLocs.push_back(baseOp.getLoc());
-        mapTypes.push_back(baseOp.getType());
+      } else if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) {
+        captureKind = mlir::omp::VariableCaptureKind::ByCopy;
+      } else if (!fir::isa_builtin_cptr_type(eleType)) {
+        mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+        mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
       }
+      auto location =
+          mlir::NameLoc::get(mlir::StringAttr::get(firOpBuilder.getContext(),
+                                                   sym.name().ToString()),
+                             baseOp.getLoc());
+      mlir::Value mapOp = createMapInfoOp(
+          firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{},
+          name.str(), bounds, /*members=*/{},
+          /*membersIndex=*/mlir::DenseIntElementsAttr{},
+          static_cast<
+              std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+              mapFlag),
+          captureKind, baseOp.getType());
+
+      clauseOps.mapVars.push_back(mapOp);
+      mapSyms.push_back(&sym);
+      mapLocs.push_back(baseOp.getLoc());
+      mapTypes.push_back(baseOp.getType());
     }
   };
   lower::pft::visitAllSymbols(eval, captureImplicitMap);
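
The branch structure that picks the map flags and capture kind for an implicitly captured symbol can be sketched in isolation. The following is a minimal, standalone approximation and not the Flang code: `MapFlag`, `CaptureKind`, `SymbolTraits`, `decideImplicitMap` and `exampleUsage` are hypothetical names, the flag values are chosen for illustration only, and the boolean fields stand in for the `DeclareTargetInterface` and `fir::isa_*` queries used in the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical flag set; values are illustrative, not the real encoding.
enum class MapFlag : std::uint64_t {
  None = 0x0,
  To = 0x1,
  From = 0x2,
  Implicit = 0x200,
};

constexpr MapFlag operator|(MapFlag a, MapFlag b) {
  using U = std::underlying_type_t<MapFlag>;
  return static_cast<MapFlag>(static_cast<U>(a) | static_cast<U>(b));
}

// Mirrors the static_cast<std::underlying_type_t<...>> seen in the diff.
constexpr std::uint64_t toBits(MapFlag f) {
  return static_cast<std::underlying_type_t<MapFlag>>(f);
}

enum class CaptureKind { ByRef, ByCopy };

// Hypothetical summary of the queries made on a symbol.
struct SymbolTraits {
  bool isDeclareTarget = false; // symbol is marked declare target
  bool isLink = false;          // declare target capture clause is `link`
  bool isNoHost = false;        // declare target device type is `nohost`
  bool isTrivialOrChar = false; // scalar or character element type
  bool isBuiltinCPtr = false;   // builtin c_ptr element type
};

struct Decision {
  MapFlag flags;
  CaptureKind capture;
};

// Every implicit map starts as IMPLICIT + ByRef and is refined from there.
Decision decideImplicitMap(const SymbolTraits &s) {
  Decision d{MapFlag::Implicit, CaptureKind::ByRef};
  if (s.isDeclareTarget) {
    // declare target link (unless nohost) is mapped tofrom.
    if (s.isLink && !s.isNoHost)
      d.flags = d.flags | MapFlag::To | MapFlag::From;
  } else if (s.isTrivialOrChar) {
    d.capture = CaptureKind::ByCopy;
  } else if (!s.isBuiltinCPtr) {
    d.flags = d.flags | MapFlag::To | MapFlag::From;
  }
  return d;
}

// Small usage example: an array-like symbol is mapped tofrom by reference.
void exampleUsage() {
  SymbolTraits array; // all defaults: not trivial, not c_ptr
  Decision d = decideImplicitMap(array);
  assert(d.capture == CaptureKind::ByRef);
  assert(toBits(d.flags) == 0x203);
}
```

In this sketch a scalar ends up `ByCopy` with only the implicit bit set, while anything else that is not a `c_ptr` is mapped `tofrom` by reference.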

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 62 additions & 28 deletions
@@ -50,6 +50,14 @@ class MapInfoFinalizationPass
     : public flangomp::impl::MapInfoFinalizationPassBase<
           MapInfoFinalizationPass> {

+  /// Tracks any intermediate function/subroutine local allocations we
+  /// generate for the descriptors of box type dummy arguments, so that
+  /// we can retrieve it for subsequent reuses within the functions
+  /// scope
+  std::map</*descriptor opaque pointer=*/void *,
+           /*corresponding local alloca=*/fir::AllocaOp>
+      localBoxAllocas;
+
   void genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
                                fir::FirOpBuilder &builder,
                                mlir::Operation *target) {
@@ -74,14 +82,26 @@ class MapInfoFinalizationPass
     // perform an alloca and then store to it and retrieve the data from the new
     // alloca.
     if (mlir::isa<fir::BaseBoxType>(descriptor.getType())) {
-      mlir::OpBuilder::InsertPoint insPt = builder.saveInsertionPoint();
-      mlir::Block *allocaBlock = builder.getAllocaBlock();
-      assert(allocaBlock && "No alloca block found for this top level op");
-      builder.setInsertionPointToStart(allocaBlock);
-      auto alloca = builder.create<fir::AllocaOp>(loc, descriptor.getType());
-      builder.restoreInsertionPoint(insPt);
-      builder.create<fir::StoreOp>(loc, descriptor, alloca);
-      descriptor = alloca;
+      // If we have already created a local allocation for this BoxType,
+      // we must be sure to re-use it so that we end up with the same
+      // allocations being utilised for the same descriptor across all map uses,
+      // this prevents runtime issues such as not appropriately releasing or
+      // deleting all mapped data.
+      auto find = localBoxAllocas.find(descriptor.getAsOpaquePointer());
+      if (find != localBoxAllocas.end()) {
+        builder.create<fir::StoreOp>(loc, descriptor, find->second);
+        descriptor = find->second;
+      } else {
+        mlir::OpBuilder::InsertPoint insPt = builder.saveInsertionPoint();
+        mlir::Block *allocaBlock = builder.getAllocaBlock();
+        assert(allocaBlock && "No alloca block found for this top level op");
+        builder.setInsertionPointToStart(allocaBlock);
+        auto alloca = builder.create<fir::AllocaOp>(loc, descriptor.getType());
+        builder.restoreInsertionPoint(insPt);
+        builder.create<fir::StoreOp>(loc, descriptor, alloca);
+        localBoxAllocas[descriptor.getAsOpaquePointer()] = alloca;
+        descriptor = alloca;
+      }
     }

     mlir::Value baseAddrAddr = builder.create<fir::BoxOffsetOp>(
@@ -234,27 +254,41 @@ class MapInfoFinalizationPass
     fir::KindMapping kindMap = fir::getKindMapping(module);
     fir::FirOpBuilder builder{module, std::move(kindMap)};

-    getOperation()->walk([&](mlir::omp::MapInfoOp op) {
-      // TODO: Currently only supports a single user for the MapInfoOp, this
-      // is fine for the moment as the Fortran Frontend will generate a
-      // new MapInfoOp per Target operation for the moment. However, when/if
-      // we optimise/cleanup the IR, it likely isn't too difficult to
-      // extend this function, it would require some modification to create a
-      // single new MapInfoOp per new MapInfoOp generated and share it across
-      // all users appropriately, making sure to only add a single member link
-      // per new generation for the original originating descriptor MapInfoOp.
-      assert(llvm::hasSingleElement(op->getUsers()) &&
-             "MapInfoFinalization currently only supports single users "
-             "of a MapInfoOp");
+    // We wish to maintain some function level scope (currently
+    // just local function scope variables used to load and store box
+    // variables into so we can access their base address, an
+    // quirk of box_offset requires us to have an in memory box, but Fortran
+    // in certain cases does not provide this) whilst not subjecting
+    // ourselves to the possibility of race conditions while this pass
+    // undergoes frequent re-iteration for the near future. So we loop
+    // over function in the module and then map.info inside of those.
+    getOperation()->walk([&](mlir::func::FuncOp func) {
+      // clear all local allocations we made for any boxes in any prior
+      // iterations from previous function scopes.
+      localBoxAllocas.clear();

-      if (!op.getMembers().empty()) {
-        addImplicitMembersToTarget(op, builder, *op->getUsers().begin());
-      } else if (fir::isTypeWithDescriptor(op.getVarType()) ||
-                 mlir::isa_and_present<fir::BoxAddrOp>(
-                     op.getVarPtr().getDefiningOp())) {
-        builder.setInsertionPoint(op);
-        genDescriptorMemberMaps(op, builder, *op->getUsers().begin());
-      }
+      func->walk([&](mlir::omp::MapInfoOp op) {
+        // TODO: Currently only supports a single user for the MapInfoOp, this
+        // is fine for the moment as the Fortran Frontend will generate a
+        // new MapInfoOp per Target operation for the moment. However, when/if
+        // we optimise/cleanup the IR, it likely isn't too difficult to
+        // extend this function, it would require some modification to create a
+        // single new MapInfoOp per new MapInfoOp generated and share it across
+        // all users appropriately, making sure to only add a single member link
+        // per new generation for the original originating descriptor MapInfoOp.
+        assert(llvm::hasSingleElement(op->getUsers()) &&
+               "OMPMapInfoFinalization currently only supports single users "
+               "of a MapInfoOp");
+
+        if (!op.getMembers().empty()) {
+          addImplicitMembersToTarget(op, builder, *op->getUsers().begin());
+        } else if (fir::isTypeWithDescriptor(op.getVarType()) ||
+                   mlir::isa_and_present<fir::BoxAddrOp>(
+                       op.getVarPtr().getDefiningOp())) {
+          builder.setInsertionPoint(op);
+          genDescriptorMemberMaps(op, builder, *op->getUsers().begin());
+        }
+      });
     });
   }
 };
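
The reuse of one intermediate allocation per descriptor, reset at each function boundary, can be illustrated without MLIR. This is a minimal sketch under stated assumptions: `FakeAlloca`, `BoxAllocaCache` and `exampleUsage` are hypothetical stand-ins (an integer id replaces `fir::AllocaOp`, and a plain pointer replaces the value returned by `getAsOpaquePointer()`); only the find-or-create-then-clear pattern is taken from the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <map>

// Hypothetical stand-in for fir::AllocaOp: an id naming one stack slot.
struct FakeAlloca {
  int id;
};

// Hypothetical stand-in for the pass state held in localBoxAllocas.
class BoxAllocaCache {
public:
  // Returns the slot already associated with `descriptor`, or creates
  // one on first use so every later map of that descriptor shares it.
  FakeAlloca getOrCreate(const void *descriptor) {
    auto found = slots.find(descriptor);
    if (found != slots.end())
      return found->second;
    FakeAlloca fresh{nextId++};
    slots.emplace(descriptor, fresh);
    return fresh;
  }

  // Called when moving on to the next function: slots belong to the
  // function they were created in and must not leak across scopes.
  void clear() { slots.clear(); }

  std::size_t size() const { return slots.size(); }

private:
  std::map<const void *, FakeAlloca> slots;
  int nextId = 0;
};

// Small usage example: two maps of one descriptor share a slot, and a
// new function scope starts from an empty cache.
void exampleUsage() {
  BoxAllocaCache cache;
  int descriptorA = 0, descriptorB = 0; // addresses act as opaque keys

  FakeAlloca enterMap = cache.getOrCreate(&descriptorA);
  FakeAlloca exitMap = cache.getOrCreate(&descriptorA);
  assert(enterMap.id == exitMap.id); // same descriptor, same slot

  FakeAlloca other = cache.getOrCreate(&descriptorB);
  assert(other.id != enterMap.id); // different descriptor, new slot

  cache.clear(); // next function scope
  assert(cache.size() == 0);
}
```

Keying on the descriptor's identity rather than its type is what lets an enter map and a later exit map of the same variable resolve to the same storage in this sketch.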
