
Commit 28f4bf9

[Flang][OpenMP] Derived type explicit allocatable member mapping
This PR is one of 3 in a PR stack. It is the primary change set, and it extends the current derived type explicit member mapping support to handle descriptor member mapping at arbitrary levels of nesting. The PR stack seems to do this reasonably (from testing so far), but as you can create quite complex mappings with derived types (in particular when adding allocatable derived types or arrays of allocatable derived types), I imagine there will be hiccups, which I am more than happy to address. There will also be further extensions to this work to handle the implicit auto-magical mapping of descriptor members in derived types, and a few other changes are planned for the future (with some ideas on optimizing things).

The changes in this PR primarily occur in the OpenMP lowering and the OMPMapInfoFinalization pass.

In the OpenMP lowering, several utility functions were added or extended to support the generation of appropriate intermediate member mappings, which are currently required when the parent (or multiple parents) of a mapped member are descriptor types. We need to map the entirety of these types, or do a "deep copy" for lack of a better term, where we map both the base address and the descriptor. Without the descriptor we lack the information needed to access the member or attach the pointer's data to the pointer, and without the base address we cannot map the chunk of data. Currently we do not segment descriptor based derived types as we do with regular non-descriptor derived types; we effectively map their entirety in all cases at the moment. I hope to address this at some point in the future, as it adds a fair bit of a performance penalty to, for example, nestings of allocatable derived types. The process of mapping all intermediate descriptor members in a member's path only occurs if a member has an allocatable or pointer parent in its symbol path, or the member itself is a pointer or allocatable.
This occurs in the createParentSymAndGenIntermediateMaps function, which will also generate the appropriate address for the allocatable member within the derived type to use as the varPtr field of the map (for intermediate allocatable maps and final allocatable mappings). This is necessary because we can't utilise the usual Fortran::lower functionality such as gatherDataOperandAddrAndBounds without causing issues later in the lowering, due to extra allocas being spawned which seem to affect the pointer attachment (at least this is my current assumption; it results in memory access errors on the device due to incorrect map information generation). This is similar to why we do not use the MLIR value generated for this and instead utilise the original symbol provided when mapping descriptor types external to derived types. Hopefully this can be rectified in the future so this function can be simplified and more closely aligned to the other type mappings.

We also make use of fir::CoordinateOp as opposed to the HLFIR version, as the HLFIR version doesn't support the appropriate lowering to FIR necessary at the moment. We also cannot use a single CoordinateOp (similarly to a single GEP), as when we index through a descriptor operation (BoxType) we encounter issues later in the lowering. In either case we need access to intermediate descriptors, so individual CoordinateOps aid this (although being able to compress them into a smaller number of CoordinateOps may simplify the IR and perhaps result in a better end product, something to consider for the future).

The other large change area was in the OMPMapInfoFinalization pass, which had to be extended to support the expansion of box types (or multiple nestings of box types) within derived types, or box type derived types.
This requires expanding each BoxType mapping from one into two maps and then modifying all of the existing member indices of the overarching parent mapping to account for these new members. A base address of a box type is currently considered a member of the box type at position 0, because when lowered to LLVM-IR it's a pointer contained at this position in the descriptor type. This means mapped children of the expanded descriptor type must additionally incorporate the new member index at the correct location in their own index lists. I believe there is a reasonable amount of comments that should aid in understanding this better, alongside the test alterations for the pass.

A subset of the changes were also aimed at making some of the utilities for packing and unpacking the DenseIntElementsAttr containing the member indices shareable across the lowering and OMPMapInfoFinalization. This required moving some functions to the Lower/Support/Utils.h header, and transforming the lowering structure containing the member index data into something more similar to the version used in OMPMapInfoFinalization. There were also some other attempts at tidying things up in relation to the member index data generation in the lowering, some of which required creating a comparison operator (operator<) for the OpenMP ID class so it can be utilised as a map key (it simply utilises the symbol address for the moment, as ordering isn't particularly important).

Otherwise I have added a set of new tests encompassing some of the mappings currently supported by this PR (unfortunately, as you can have arbitrary nestings of all shapes and types, it's not very feasible to cover them all).
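To make the member index adjustment concrete, here is a minimal standalone sketch. It is not the pass's actual code: `IndexPath` and `expandBoxMember` are hypothetical names, plain `std::vector` stands in for the MLIR attribute storage, and it assumes each mapped member is identified by its list of indices from the parent down.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical representation: one index list per mapped member,
// read from the outermost parent down to the member itself.
using IndexPath = std::vector<int64_t>;

// Hypothetical helper: treat the member at `boxPos` as a descriptor (box)
// and expand it into two entries, the descriptor and its base address.
// The base address is modelled as member 0 of the descriptor.
void expandBoxMember(std::vector<IndexPath> &paths, std::size_t boxPos) {
  const IndexPath boxPath = paths[boxPos]; // copy, since `paths` is mutated

  // Children of the descriptor now live behind the base address, so
  // insert index 0 directly after the descriptor's own prefix.
  for (std::size_t i = 0; i < paths.size(); ++i) {
    IndexPath &p = paths[i];
    if (i == boxPos || p.size() <= boxPath.size())
      continue;
    if (std::equal(boxPath.begin(), boxPath.end(), p.begin()))
      p.insert(p.begin() + static_cast<std::ptrdiff_t>(boxPath.size()), 0);
  }

  // Add the new base address member right after the descriptor.
  IndexPath baseAddr = boxPath;
  baseAddr.push_back(0);
  paths.insert(paths.begin() + static_cast<std::ptrdiff_t>(boxPos) + 1,
               baseAddr);
}
```

Under these assumptions, expanding the descriptor at position 0 of the paths `{1}`, `{1, 2}`, `{3}` gives `{1}`, `{1, 0}`, `{1, 0, 2}`, `{3}`: the new base address member is `{1, 0}`, and the existing child `{1, 2}` becomes `{1, 0, 2}`.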
1 parent 5192cb7 commit 28f4bf9

20 files changed: +1884 -397 lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 5 additions & 0 deletions
@@ -215,6 +215,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
                                 llvm::ArrayRef<mlir::Value> lenParams,
                                 bool asTarget = false);

+  /// Create a two dimensional ArrayAttr containing integer data as
+  /// IntegerAttrs, effectively: ArrayAttr<ArrayAttr<IntegerAttr>>>.
+  mlir::ArrayAttr create2DI64ArrayAttr(
+      llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &intData);
+
   /// Create a temporary using `fir.alloca`. This function does not hoist.
   /// It is the callers responsibility to set the insertion point if
   /// hoisting is required.

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 38 additions & 24 deletions
@@ -891,14 +891,15 @@ void ClauseProcessor::processMapObjects(
     lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
     const omp::ObjectList &objects,
     llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
-    std::map<const semantics::Symbol *,
-             llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
+    std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
     llvm::SmallVectorImpl<mlir::Value> &mapVars,
     llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
   for (const omp::Object &object : objects) {
     llvm::SmallVector<mlir::Value> bounds;
     std::stringstream asFortran;
+    std::optional<omp::Object> parentObj;

     lower::AddrAndBoundsInfo info =
         lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
@@ -907,24 +908,43 @@ void ClauseProcessor::processMapObjects(
             object.ref(), clauseLocation, asFortran, bounds,
             treatIndexAsSection);

+    mlir::Value baseOp = info.rawInput;
+    if (object.sym()->owner().IsDerivedType()) {
+      omp::ObjectList objectList = gatherObjectsOf(object, semaCtx);
+      assert(!objectList.empty() &&
+             "could not find parent objects of derived type member");
+      parentObj = objectList[0];
+      parentMemberIndices.emplace(parentObj.value(),
+                                  OmpMapParentAndMemberData{});
+
+      if (isMemberOrParentAllocatableOrPointer(object, semaCtx)) {
+        llvm::SmallVector<int64_t> indices;
+        generateMemberPlacementIndices(object, indices, semaCtx);
+        baseOp = createParentSymAndGenIntermediateMaps(
+            clauseLocation, converter, semaCtx, stmtCtx, objectList, indices,
+            parentMemberIndices[parentObj.value()], asFortran.str(),
+            mapTypeBits);
+      }
+    }
+
     // Explicit map captures are captured ByRef by default,
     // optimisation passes may alter this to ByCopy or other capture
     // types to optimise
-    mlir::Value baseOp = info.rawInput;
     auto location = mlir::NameLoc::get(
         mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
         baseOp.getLoc());
     mlir::omp::MapInfoOp mapOp = createMapInfoOp(
         firOpBuilder, location, baseOp,
         /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
-        /*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
+        /*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{},
         static_cast<
             std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
             mapTypeBits),
         mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());

-    if (object.sym()->owner().IsDerivedType()) {
-      addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
+    if (parentObj.has_value()) {
+      parentMemberIndices[parentObj.value()].addChildIndexAndMapToParent(
+          object, mapOp, semaCtx);
     } else {
       mapVars.push_back(mapOp);
       mapSyms.push_back(object.sym());
@@ -942,9 +962,7 @@ bool ClauseProcessor::processMap(
   llvm::SmallVector<const semantics::Symbol *> localMapSyms;
   llvm::SmallVectorImpl<const semantics::Symbol *> *ptrMapSyms =
       mapSyms ? mapSyms : &localMapSyms;
-  std::map<const semantics::Symbol *,
-           llvm::SmallVector<OmpMapMemberIndicesData>>
-      parentMemberIndices;
+  std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

   auto process = [&](const omp::clause::Map &clause,
                      const parser::CharBlock &source) {
@@ -1004,17 +1022,15 @@ bool ClauseProcessor::processMap(
   };

   bool clauseFound = findRepeatableClause<omp::clause::Map>(process);
-  insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
-                               *ptrMapSyms);
+  insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
+                               result.mapVars, *ptrMapSyms);

   return clauseFound;
 }

 bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
                                            mlir::omp::MapClauseOps &result) {
-  std::map<const semantics::Symbol *,
-           llvm::SmallVector<OmpMapMemberIndicesData>>
-      parentMemberIndices;
+  std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
   llvm::SmallVector<const semantics::Symbol *> mapSymbols;

   auto callbackFn = [&](const auto &clause, const parser::CharBlock &source) {
@@ -1035,8 +1051,9 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
   clauseFound =
       findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;

-  insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
-                               mapSymbols);
+  insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
+                               result.mapVars, mapSymbols);
+
   return clauseFound;
 }

@@ -1098,9 +1115,7 @@ bool ClauseProcessor::processEnter(
 bool ClauseProcessor::processUseDeviceAddr(
     lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
-  std::map<const semantics::Symbol *,
-           llvm::SmallVector<OmpMapMemberIndicesData>>
-      parentMemberIndices;
+  std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
   bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
       [&](const omp::clause::UseDeviceAddr &clause,
           const parser::CharBlock &source) {
@@ -1113,17 +1128,16 @@ bool ClauseProcessor::processUseDeviceAddr(
                           useDeviceSyms);
       });

-  insertChildMapInfoIntoParent(converter, parentMemberIndices,
+  insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
                                result.useDeviceAddrVars, useDeviceSyms);
   return clauseFound;
 }

 bool ClauseProcessor::processUseDevicePtr(
     lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
-  std::map<const semantics::Symbol *,
-           llvm::SmallVector<OmpMapMemberIndicesData>>
-      parentMemberIndices;
+  std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
+
   bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
       [&](const omp::clause::UseDevicePtr &clause,
           const parser::CharBlock &source) {
@@ -1136,7 +1150,7 @@ bool ClauseProcessor::processUseDevicePtr(
                           useDeviceSyms);
       });

-  insertChildMapInfoIntoParent(converter, parentMemberIndices,
+  insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
                                result.useDevicePtrVars, useDeviceSyms);
   return clauseFound;
 }

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 1 addition & 2 deletions
@@ -166,8 +166,7 @@ class ClauseProcessor {
       lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
       const omp::ObjectList &objects,
       llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
-      std::map<const semantics::Symbol *,
-               llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
+      std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
       llvm::SmallVectorImpl<mlir::Value> &mapVars,
       llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const;

flang/lib/Lower/OpenMP/Clauses.h

Lines changed: 11 additions & 0 deletions
@@ -51,6 +51,13 @@ struct IdTyTemplate {
     return designator == other.designator;
   }

+  // Defining an "ordering" which allows types derived from this to be
+  // utilised in maps and other containers that require comparison
+  // operators for ordering
+  bool operator<(const IdTyTemplate &other) const {
+    return symbol < other.symbol;
+  }
+
   operator bool() const { return symbol != nullptr; }
 };

@@ -72,6 +79,10 @@ struct ObjectT<Fortran::lower::omp::IdTyTemplate<Fortran::lower::omp::ExprTy>,
   Fortran::semantics::Symbol *sym() const { return identity.symbol; }
   const std::optional<ExprTy> &ref() const { return identity.designator; }

+  bool operator<(const ObjectT<IdTy, ExprTy> &other) const {
+    return identity < other.identity;
+  }
+
   IdTy identity;
 };
 } // namespace tomp::type
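
As a rough illustration of why the ordering above is enough for a map key, here is a minimal standalone sketch. `Symbol`, `Id`, and `groupMembersByParent` are hypothetical stand-ins, not Flang types. Note one deliberate difference: the commit compares the symbol pointers with plain `<`, while this sketch uses `std::less`, which guarantees a strict total order over pointers.

```cpp
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a semantics symbol.
struct Symbol {
  std::string name;
};

// Hypothetical stand-in for the ID type: identity is the symbol address.
struct Id {
  const Symbol *symbol = nullptr;

  bool operator==(const Id &other) const { return symbol == other.symbol; }

  // Order by symbol address. The resulting order is arbitrary but
  // consistent, which is all std::map needs to find and group keys.
  bool operator<(const Id &other) const {
    return std::less<const Symbol *>{}(symbol, other.symbol);
  }
};

// Group member names under their parent ID, similar in spirit to keeping
// per-parent member data in a std::map keyed by the parent object.
std::map<Id, std::vector<std::string>>
groupMembersByParent(const std::vector<std::pair<Id, std::string>> &members) {
  std::map<Id, std::vector<std::string>> grouped;
  for (const auto &[parent, member] : members)
    grouped[parent].push_back(member);
  return grouped;
}
```

Two IDs that refer to the same symbol land in the same map entry regardless of insertion order, so the iteration order of the map depends on addresses and should not be relied on.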

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 2 additions & 2 deletions
@@ -1011,7 +1011,7 @@ static void genBodyOfTargetOp(
           firOpBuilder, copyVal.getLoc(), copyVal,
           /*varPtrPtr=*/mlir::Value{}, name.str(), bounds,
           /*members=*/llvm::SmallVector<mlir::Value>{},
-          /*membersIndex=*/mlir::DenseIntElementsAttr{},
+          /*membersIndex=*/mlir::ArrayAttr{},
           static_cast<
               std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
               llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
@@ -1801,7 +1801,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
         mlir::Value mapOp = createMapInfoOp(
             firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{},
             name.str(), bounds, /*members=*/{},
-            /*membersIndex=*/mlir::DenseIntElementsAttr{},
+            /*membersIndex=*/mlir::ArrayAttr{},
             static_cast<
                 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
                 mapFlag),
