Skip to content

Commit de3d3bc

Browse files
committed
[Flang][OpenMP][MLIR] Initial derived type member map support
This patch is one in a series of four patches that seeks to refactor slightly and extend the current record type map support that was put in place for Fortran's descriptor types to handle explicit member mapping for record types at a single level of depth. For example, the below case where two members of a Fortran derived type are mapped explicitly: '''' type :: scalar_and_array real(4) :: real integer(4) :: array(10) integer(4) :: int end type scalar_and_array type(scalar_and_array) :: scalar_arr !$omp target map(tofrom: scalar_arr%int, scalar_arr%real) '''' Current cases of derived type mapping left for future work are: > explicit member mapping of nested members (e.g. two layers of record types where we explicitly map a member from the internal record type) > Fortran's automagical mapping of all elements and nested elements of a derived type > explicit member mapping of a derived type and then constituient members (redundant in Fortran due to former case but still legal as far as I am aware) > explicit member mapping of a record type (may be handled reasonably, just not fully tested in this iteration) > explicit member mapping for Fortran allocatable types (a variation of nested record types) This patch seeks to support this by extending the Flang-new OpenMP lowering to support generation of this newly required information, creating the neccessary parent <-to-> member map_info links, calculating the member indices and setting if it's a partial map. The OMPDescriptorMapInfoGen pass has also been generalized into a map finalization phase, now named OMPMapInfoFinalization. This pass was extended to support the insertion of member maps into the BlockArg and MapOperands of relevant map carrying operations. Similar to the method in which descriptor types are expanded and constituient members inserted. Pull Request: llvm#82853
1 parent f40b28e commit de3d3bc

25 files changed

+1183
-267
lines changed

flang/docs/OpenMP-descriptor-management.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Currently, Flang will lower these descriptor types in the OpenMP lowering (lower
4444
to all other map types, generating an omp.MapInfoOp containing relevant information required for lowering
4545
the OpenMP dialect to LLVM-IR during the final stages of the MLIR lowering. However, after
4646
the lowering to FIR/HLFIR has been performed an OpenMP dialect specific pass for Fortran,
47-
`OMPDescriptorMapInfoGenPass` (Optimizer/OMPDescriptorMapInfoGen.cpp) will expand the
47+
`OMPMapInfoFinalizationPass` (Optimizer/OMPMapInfoFinalization.cpp) will expand the
4848
`omp.MapInfoOp`'s containing descriptors (which currently will be a `BoxType` or `BoxAddrOp`) into multiple
4949
mappings, with one extra per pointer member in the descriptor that is supported on top of the original
5050
descriptor map operation. These pointers members are linked to the parent descriptor by adding them to
@@ -53,7 +53,7 @@ owning operation's (`omp.TargetOp`, `omp.TargetDataOp` etc.) map operand list an
5353
operation is `IsolatedFromAbove`, it also inserts them as `BlockArgs` to canonicalize the mappings and
5454
simplify lowering.
5555
56-
An example transformation by the `OMPDescriptorMapInfoGenPass`:
56+
An example transformation by the `OMPMapInfoFinalizationPass`:
5757
5858
```
5959

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ std::unique_ptr<mlir::Pass>
7575
createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
7676
std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();
7777

78-
std::unique_ptr<mlir::Pass> createOMPDescriptorMapInfoGenPass();
78+
std::unique_ptr<mlir::Pass> createOMPMapInfoFinalizationPass();
7979
std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
8080
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
8181
createOMPMarkDeclareTargetPass();

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,15 +328,15 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> {
328328
let dependentDialects = [ "fir::FIROpsDialect" ];
329329
}
330330

331-
def OMPDescriptorMapInfoGenPass
332-
: Pass<"omp-descriptor-map-info-gen", "mlir::func::FuncOp"> {
331+
def OMPMapInfoFinalizationPass
332+
: Pass<"omp-map-info-finalization", "mlir::func::FuncOp"> {
333333
let summary = "expands OpenMP MapInfo operations containing descriptors";
334334
let description = [{
335335
Expands MapInfo operations containing descriptor types into multiple
336336
MapInfo's for each pointer element in the descriptor that requires
337337
explicit individual mapping by the OpenMP runtime.
338338
}];
339-
let constructor = "::fir::createOMPDescriptorMapInfoGenPass()";
339+
let constructor = "::fir::createOMPMapInfoFinalizationPass()";
340340
let dependentDialects = ["mlir::omp::OpenMPDialect"];
341341
}
342342

flang/include/flang/Tools/CLOptions.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ inline void createHLFIRToFIRPassPipeline(
319319
/// rather than the host device.
320320
inline void createOpenMPFIRPassPipeline(
321321
mlir::PassManager &pm, bool isTargetDevice) {
322-
pm.addPass(fir::createOMPDescriptorMapInfoGenPass());
322+
pm.addPass(fir::createOMPMapInfoFinalizationPass());
323323
pm.addPass(fir::createOMPMarkDeclareTargetPass());
324324
if (isTargetDevice)
325325
pm.addPass(fir::createOMPFunctionFilteringPass());

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -811,9 +811,10 @@ mlir::omp::MapInfoOp
811811
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
812812
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
813813
llvm::ArrayRef<mlir::Value> bounds,
814-
llvm::ArrayRef<mlir::Value> members, uint64_t mapType,
814+
llvm::ArrayRef<mlir::Value> members,
815+
mlir::DenseIntElementsAttr membersIndex, uint64_t mapType,
815816
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
816-
bool isVal) {
817+
bool partialMap) {
817818
if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
818819
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
819820
retTy = baseAddr.getType();
@@ -823,10 +824,10 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
823824
llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
824825

825826
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
826-
loc, retTy, baseAddr, varType, varPtrPtr, members, bounds,
827+
loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
827828
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
828829
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
829-
builder.getStringAttr(name));
830+
builder.getStringAttr(name), builder.getBoolAttr(partialMap));
830831

831832
return op;
832833
}
@@ -838,7 +839,11 @@ bool ClauseProcessor::processMap(
838839
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
839840
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
840841
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
841-
return findRepeatableClause<omp::clause::Map>(
842+
std::map<const Fortran::semantics::Symbol *,
843+
llvm::SmallVector<OmpMapMemberIndicesData>>
844+
parentMemberIndices;
845+
846+
bool clauseFound = findRepeatableClause<omp::clause::Map>(
842847
[&](const omp::clause::Map &clause,
843848
const Fortran::parser::CharBlock &source) {
844849
using Map = omp::clause::Map;
@@ -903,24 +908,39 @@ bool ClauseProcessor::processMap(
903908
// Explicit map captures are captured ByRef by default,
904909
// optimisation passes may alter this to ByCopy or other capture
905910
// types to optimise
906-
mlir::Value mapOp = createMapInfoOp(
911+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
907912
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
908-
asFortran.str(), bounds, {},
913+
asFortran.str(), bounds, {}, mlir::DenseIntElementsAttr{},
909914
static_cast<
910915
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
911916
mapTypeBits),
912917
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
913918

914-
result.mapVars.push_back(mapOp);
915-
916-
if (mapSyms)
919+
if (object.id()->owner().IsDerivedType()) {
920+
if (auto dataRef{ExtractDataRef(object.designator)}) {
921+
const Fortran::semantics::Symbol *parentSym = parentSym =
922+
&dataRef->GetFirstSymbol();
923+
assert(parentSym &&
924+
"Could not find parent symbol during lower of "
925+
"a component member in OpenMP map clause");
926+
parentMemberIndices[parentSym].push_back(
927+
{generateMemberPlacementIndices(object, semaCtx), mapOp});
928+
}
929+
} else {
930+
result.mapVars.push_back(mapOp);
917931
mapSyms->push_back(object.id());
918-
if (mapSymLocs)
919-
mapSymLocs->push_back(symAddr.getLoc());
920-
if (mapSymTypes)
921-
mapSymTypes->push_back(symAddr.getType());
932+
if (mapSymTypes)
933+
mapSymTypes->push_back(symAddr.getType());
934+
if (mapSymLocs)
935+
mapSymLocs->push_back(symAddr.getLoc());
936+
}
922937
}
923938
});
939+
940+
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
941+
mapSymTypes, mapSymLocs, mapSyms);
942+
943+
return clauseFound;
924944
}
925945

926946
bool ClauseProcessor::processReduction(

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,9 @@ class ClauseProcessor {
115115
bool processMap(
116116
mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx,
117117
mlir::omp::MapClauseOps &result,
118-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms =
119-
nullptr,
120-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
121-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
118+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols,
119+
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
120+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr) const;
122121
bool processReduction(
123122
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
124123
llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
@@ -185,7 +184,12 @@ template <typename T>
185184
bool ClauseProcessor::processMotionClauses(
186185
Fortran::lower::StatementContext &stmtCtx,
187186
mlir::omp::MapClauseOps &result) {
188-
return findRepeatableClause<T>(
187+
std::map<const Fortran::semantics::Symbol *,
188+
llvm::SmallVector<OmpMapMemberIndicesData>>
189+
parentMemberIndices;
190+
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
191+
192+
bool clauseFound = findRepeatableClause<T>(
189193
[&](const T &clause, const Fortran::parser::CharBlock &source) {
190194
mlir::Location clauseLocation = converter.genLocation(source);
191195
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -203,6 +207,7 @@ bool ClauseProcessor::processMotionClauses(
203207
for (const omp::Object &object : objects) {
204208
llvm::SmallVector<mlir::Value> bounds;
205209
std::stringstream asFortran;
210+
206211
Fortran::lower::AddrAndBoundsInfo info =
207212
Fortran::lower::gatherDataOperandAddrAndBounds<
208213
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
@@ -218,17 +223,34 @@ bool ClauseProcessor::processMotionClauses(
218223
// Explicit map captures are captured ByRef by default,
219224
// optimisation passes may alter this to ByCopy or other capture
220225
// types to optimise
221-
mlir::Value mapOp = createMapInfoOp(
226+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
222227
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
223-
asFortran.str(), bounds, {},
228+
asFortran.str(), bounds, {}, mlir::DenseIntElementsAttr{},
224229
static_cast<
225230
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
226231
mapTypeBits),
227232
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
228233

229-
result.mapVars.push_back(mapOp);
234+
if (object.id()->owner().IsDerivedType()) {
235+
if (auto dataRef{ExtractDataRef(object.designator)}) {
236+
const Fortran::semantics::Symbol *parentSym =
237+
&dataRef->GetFirstSymbol();
238+
assert(parentSym &&
239+
"Could not find parent symbol during lower of "
240+
"a component member in OpenMP map clause");
241+
parentMemberIndices[parentSym].push_back(
242+
{generateMemberPlacementIndices(object, semaCtx), mapOp});
243+
}
244+
} else {
245+
result.mapVars.push_back(mapOp);
246+
mapSymbols.push_back(object.id());
247+
}
230248
}
231249
});
250+
251+
insertChildMapInfoIntoParent(converter, parentMemberIndices, mapOperands,
252+
nullptr, nullptr, &mapSymbols);
253+
return clauseFound;
232254
}
233255

234256
template <typename... Ts>

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,7 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
936936
mlir::Value mapOp = createMapInfoOp(
937937
firOpBuilder, copyVal.getLoc(), copyVal, mlir::Value{}, name.str(),
938938
bounds, llvm::SmallVector<mlir::Value>{},
939+
mlir::DenseIntElementsAttr{},
939940
static_cast<
940941
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
941942
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
@@ -1617,7 +1618,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
16171618

16181619
mlir::Value mapOp = createMapInfoOp(
16191620
firOpBuilder, baseOp.getLoc(), baseOp, mlir::Value{}, name.str(),
1620-
bounds, {},
1621+
bounds, {}, mlir::DenseIntElementsAttr{},
16211622
static_cast<
16221623
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
16231624
mapFlag),

0 commit comments

Comments
 (0)