Skip to content

Commit 5bd8202

Browse files
committed
[Flang][OpenMP][MLIR] Initial derived type member map support
This patch is one in a series of four patches that seeks to refactor slightly and extend the current record type map support that was put in place for Fortran's descriptor types to handle explicit member mapping for record types at a single level of depth. For example, the below case where two members of a Fortran derived type are mapped explicitly: '''' type :: scalar_and_array real(4) :: real integer(4) :: array(10) integer(4) :: int end type scalar_and_array type(scalar_and_array) :: scalar_arr !$omp target map(tofrom: scalar_arr%int, scalar_arr%real) '''' Current cases of derived type mapping left for future work are: > explicit member mapping of nested members (e.g. two layers of record types where we explicitly map a member from the internal record type) > Fortran's automagical mapping of all elements and nested elements of a derived type > explicit member mapping of a derived type and then constituient members (redundant in Fortran due to former case but still legal as far as I am aware) > explicit member mapping of a record type (may be handled reasonably, just not fully tested in this iteration) > explicit member mapping for Fortran allocatable types (a variation of nested record types) This patch seeks to support this by extending the Flang-new OpenMP lowering to support generation of this newly required information, creating the neccessary parent <-to-> member map_info links, calculating the member indices and setting if it's a partial map. The OMPDescriptorMapInfoGen pass has also been generalized into a map finalization phase, now named OMPMapInfoFinalization. This pass was extended to support the insertion of member maps into the BlockArg and MapOperands of relevant map carrying operations. Similar to the method in which descriptor types are expanded and constituient members inserted. Pull Request: llvm#82853
1 parent 122f30d commit 5bd8202

24 files changed

+880
-254
lines changed

flang/docs/OpenMP-descriptor-management.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ Currently, Flang will lower these descriptor types in the OpenMP lowering (lower
4444
to all other map types, generating an omp.MapInfoOp containing relevant information required for lowering
4545
the OpenMP dialect to LLVM-IR during the final stages of the MLIR lowering. However, after
4646
the lowering to FIR/HLFIR has been performed an OpenMP dialect specific pass for Fortran,
47-
`OMPDescriptorMapInfoGenPass` (Optimizer/OMPDescriptorMapInfoGen.cpp) will expand the
47+
`OMPMapInfoFinalizationPass` (Optimizer/OMPMapInfoFinalization.cpp) will expand the
4848
`omp.MapInfoOp`'s containing descriptors (which currently will be a `BoxType` or `BoxAddrOp`) into multiple
4949
mappings, with one extra per pointer member in the descriptor that is supported on top of the original
5050
descriptor map operation. These pointers members are linked to the parent descriptor by adding them to
5151
the member field of the original descriptor map operation, they are then inserted into the relevant map
5252
owning operation's (`omp.TargetOp`, `omp.DataOp` etc.) map operand list and in cases where the owning operation
5353
is `IsolatedFromAbove`, it also inserts them as `BlockArgs` to canonicalize the mappings and simplify lowering.
5454
55-
An example transformation by the `OMPDescriptorMapInfoGenPass`:
55+
An example transformation by the `OMPMapInfoFinalizationPass`:
5656
5757
```
5858

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ std::unique_ptr<mlir::Pass>
7676
createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
7777
std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();
7878

79-
std::unique_ptr<mlir::Pass> createOMPDescriptorMapInfoGenPass();
79+
std::unique_ptr<mlir::Pass> createOMPMapInfoFinalizationPass();
8080
std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
8181
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
8282
createOMPMarkDeclareTargetPass();

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,15 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> {
318318
let dependentDialects = [ "fir::FIROpsDialect" ];
319319
}
320320

321-
def OMPDescriptorMapInfoGenPass
322-
: Pass<"omp-descriptor-map-info-gen", "mlir::func::FuncOp"> {
321+
def OMPMapInfoFinalizationPass
322+
: Pass<"omp-map-info-finalization", "mlir::func::FuncOp"> {
323323
let summary = "expands OpenMP MapInfo operations containing descriptors";
324324
let description = [{
325325
Expands MapInfo operations containing descriptor types into multiple
326326
MapInfo's for each pointer element in the descriptor that requires
327327
explicit individual mapping by the OpenMP runtime.
328328
}];
329-
let constructor = "::fir::createOMPDescriptorMapInfoGenPass()";
329+
let constructor = "::fir::createOMPMapInfoFinalizationPass()";
330330
let dependentDialects = ["mlir::omp::OpenMPDialect"];
331331
}
332332

flang/include/flang/Tools/CLOptions.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ inline void createHLFIRToFIRPassPipeline(
274274
/// rather than the host device.
275275
inline void createOpenMPFIRPassPipeline(
276276
mlir::PassManager &pm, bool isTargetDevice) {
277-
pm.addPass(fir::createOMPDescriptorMapInfoGenPass());
277+
pm.addPass(fir::createOMPMapInfoFinalizationPass());
278278
pm.addPass(fir::createOMPMarkDeclareTargetPass());
279279
if (isTargetDevice)
280280
pm.addPass(fir::createOMPFunctionFilteringPass());

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -843,9 +843,10 @@ mlir::omp::MapInfoOp
843843
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
844844
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
845845
mlir::SmallVector<mlir::Value> bounds,
846-
mlir::SmallVector<mlir::Value> members, uint64_t mapType,
846+
mlir::SmallVector<mlir::Value> members,
847+
mlir::ArrayAttr membersIndex, uint64_t mapType,
847848
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
848-
bool isVal) {
849+
bool partialMap) {
849850
if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
850851
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
851852
retTy = baseAddr.getType();
@@ -855,10 +856,10 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
855856
llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
856857

857858
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
858-
loc, retTy, baseAddr, varType, varPtrPtr, members, bounds,
859+
loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
859860
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
860861
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
861-
builder.getStringAttr(name));
862+
builder.getStringAttr(name), builder.getBoolAttr(partialMap));
862863

863864
return op;
864865
}
@@ -867,12 +868,16 @@ bool ClauseProcessor::processMap(
867868
mlir::Location currentLocation, const llvm::omp::Directive &directive,
868869
Fortran::lower::StatementContext &stmtCtx,
869870
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
871+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols,
870872
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
871-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
872-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
873-
const {
873+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs) const {
874874
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
875-
return findRepeatableClause<ClauseTy::Map>(
875+
876+
llvm::SmallVector<mlir::omp::MapInfoOp> memberMaps;
877+
llvm::SmallVector<mlir::Attribute> memberPlacementIndices;
878+
llvm::SmallVector<const Fortran::semantics::Symbol *> memberParentSyms;
879+
880+
bool clauseFound = findRepeatableClause<ClauseTy::Map>(
876881
[&](const ClauseTy::Map *mapClause,
877882
const Fortran::parser::CharBlock &source) {
878883
mlir::Location clauseLocation = converter.genLocation(source);
@@ -919,8 +924,27 @@ bool ClauseProcessor::processMap(
919924

920925
for (const Fortran::parser::OmpObject &ompObject :
921926
std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) {
927+
llvm::omp::OpenMPOffloadMappingFlags objectsMapTypeBits = mapTypeBits;
928+
checkAndApplyDeclTargetMapFlags(converter, objectsMapTypeBits,
929+
*getOmpObjectSymbol(ompObject));
930+
922931
llvm::SmallVector<mlir::Value> bounds;
923932
std::stringstream asFortran;
933+
const Fortran::semantics::Symbol *parentSym = nullptr;
934+
935+
if (getOmpObjectSymbol(ompObject)->owner().IsDerivedType()) {
936+
const auto *designator =
937+
Fortran::parser::Unwrap<Fortran::parser::Designator>(
938+
ompObject.u);
939+
assert(designator && "Expected a designator from derived type "
940+
"component during map clause processing");
941+
parentSym = GetFirstName(*designator).symbol;
942+
memberParentSyms.push_back(parentSym);
943+
memberPlacementIndices.push_back(
944+
firOpBuilder.getI64IntegerAttr(findComponentMemberPlacement(
945+
&parentSym->GetType()->derivedTypeSpec().typeSymbol(),
946+
getOmpObjectSymbol(ompObject))));
947+
}
924948

925949
Fortran::lower::AddrAndBoundsInfo info =
926950
Fortran::lower::gatherDataOperandAddrAndBounds<
@@ -938,24 +962,32 @@ bool ClauseProcessor::processMap(
938962
// Explicit map captures are captured ByRef by default,
939963
// optimisation passes may alter this to ByCopy or other capture
940964
// types to optimise
941-
mlir::Value mapOp = createMapInfoOp(
965+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
942966
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
943-
asFortran.str(), bounds, {},
967+
asFortran.str(), bounds, {}, mlir::ArrayAttr{},
944968
static_cast<
945969
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
946-
mapTypeBits),
970+
objectsMapTypeBits),
947971
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
948972

949-
mapOperands.push_back(mapOp);
950-
if (mapSymTypes)
951-
mapSymTypes->push_back(symAddr.getType());
952-
if (mapSymLocs)
953-
mapSymLocs->push_back(symAddr.getLoc());
954-
955-
if (mapSymbols)
973+
if (parentSym) {
974+
memberMaps.push_back(mapOp);
975+
} else {
976+
mapOperands.push_back(mapOp);
956977
mapSymbols->push_back(getOmpObjectSymbol(ompObject));
978+
if (mapSymTypes)
979+
mapSymTypes->push_back(symAddr.getType());
980+
if (mapSymLocs)
981+
mapSymLocs->push_back(symAddr.getLoc());
982+
}
957983
}
958984
});
985+
986+
insertChildMapInfoIntoParent(converter, memberParentSyms, memberMaps,
987+
memberPlacementIndices, mapOperands, mapSymTypes,
988+
mapSymLocs, mapSymbols);
989+
990+
return clauseFound;
959991
}
960992

961993
bool ClauseProcessor::processReduction(

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,13 @@ class ClauseProcessor {
114114
// store the original type, location and Fortran symbol for the map operands.
115115
// They may be used later on to create the block_arguments for some of the
116116
// target directives that require it.
117-
bool processMap(mlir::Location currentLocation,
118-
const llvm::omp::Directive &directive,
119-
Fortran::lower::StatementContext &stmtCtx,
120-
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
121-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
122-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
123-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
124-
*mapSymbols = nullptr) const;
117+
bool processMap(
118+
mlir::Location currentLocation, const llvm::omp::Directive &directive,
119+
Fortran::lower::StatementContext &stmtCtx,
120+
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
121+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols,
122+
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
123+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr) const;
125124
bool
126125
processReduction(mlir::Location currentLocation,
127126
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
@@ -188,7 +187,12 @@ template <typename T>
188187
bool ClauseProcessor::processMotionClauses(
189188
Fortran::lower::StatementContext &stmtCtx,
190189
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
191-
return findRepeatableClause<T>(
190+
llvm::SmallVector<mlir::omp::MapInfoOp> memberMaps;
191+
llvm::SmallVector<mlir::Attribute> memberPlacementIndices;
192+
llvm::SmallVector<const Fortran::semantics::Symbol *> memberParentSyms,
193+
mapSymbols;
194+
195+
bool clauseFound = findRepeatableClause<T>(
192196
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
193197
mlir::Location clauseLocation = converter.genLocation(source);
194198
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -203,8 +207,28 @@ bool ClauseProcessor::processMotionClauses(
203207
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
204208

205209
for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
210+
llvm::omp::OpenMPOffloadMappingFlags objectsMapTypeBits = mapTypeBits;
211+
checkAndApplyDeclTargetMapFlags(converter, objectsMapTypeBits,
212+
*getOmpObjectSymbol(ompObject));
213+
206214
llvm::SmallVector<mlir::Value> bounds;
207215
std::stringstream asFortran;
216+
const Fortran::semantics::Symbol *parentSym = nullptr;
217+
218+
if (getOmpObjectSymbol(ompObject)->owner().IsDerivedType()) {
219+
const auto *designator =
220+
Fortran::parser::Unwrap<Fortran::parser::Designator>(
221+
ompObject.u);
222+
assert(designator && "Expected a designator from derived type "
223+
"component during motion clause processing");
224+
parentSym = GetFirstName(*designator).symbol;
225+
memberParentSyms.push_back(parentSym);
226+
memberPlacementIndices.push_back(
227+
firOpBuilder.getI64IntegerAttr(findComponentMemberPlacement(
228+
&parentSym->GetType()->derivedTypeSpec().typeSymbol(),
229+
getOmpObjectSymbol(ompObject))));
230+
}
231+
208232
Fortran::lower::AddrAndBoundsInfo info =
209233
Fortran::lower::gatherDataOperandAddrAndBounds<
210234
Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
@@ -221,17 +245,27 @@ bool ClauseProcessor::processMotionClauses(
221245
// Explicit map captures are captured ByRef by default,
222246
// optimisation passes may alter this to ByCopy or other capture
223247
// types to optimise
224-
mlir::Value mapOp = createMapInfoOp(
248+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
225249
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
226-
asFortran.str(), bounds, {},
250+
asFortran.str(), bounds, {}, mlir::ArrayAttr{},
227251
static_cast<
228252
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
229-
mapTypeBits),
253+
objectsMapTypeBits),
230254
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
231255

232-
mapOperands.push_back(mapOp);
256+
if (parentSym) {
257+
memberMaps.push_back(mapOp);
258+
} else {
259+
mapOperands.push_back(mapOp);
260+
mapSymbols.push_back(getOmpObjectSymbol(ompObject));
261+
}
233262
}
234263
});
264+
265+
insertChildMapInfoIntoParent(converter, memberParentSyms, memberMaps,
266+
memberPlacementIndices, mapOperands, nullptr,
267+
nullptr, &mapSymbols);
268+
return clauseFound;
235269
}
236270

237271
template <typename... Ts>

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,8 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
805805
deviceAddrOperands;
806806
llvm::SmallVector<mlir::Type> useDeviceTypes;
807807
llvm::SmallVector<mlir::Location> useDeviceLocs;
808-
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
808+
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols,
809+
mapSymbols;
809810

810811
ClauseProcessor cp(converter, semaCtx, clauseList);
811812
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
@@ -816,7 +817,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
816817
cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
817818
useDeviceSymbols);
818819
cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data,
819-
stmtCtx, mapOperands);
820+
stmtCtx, mapOperands, &mapSymbols);
820821

821822
auto dataOp = converter.getFirOpBuilder().create<mlir::omp::DataOp>(
822823
currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
@@ -839,6 +840,7 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
839840
mlir::UnitAttr nowaitAttr;
840841
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
841842
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
843+
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
842844

843845
Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
844846
// GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
@@ -872,7 +874,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
872874
mapOperands);
873875

874876
} else {
875-
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
877+
cp.processMap(currentLocation, directive, stmtCtx, mapOperands,
878+
&mapSymbols);
876879
}
877880

878881
return firOpBuilder.create<OpTy>(
@@ -993,7 +996,7 @@ static void genBodyOfTargetOp(
993996
firOpBuilder.setInsertionPoint(targetOp);
994997
mlir::Value mapOp = createMapInfoOp(
995998
firOpBuilder, copyVal.getLoc(), copyVal, mlir::Value{}, name.str(),
996-
bounds, llvm::SmallVector<mlir::Value>{},
999+
bounds, llvm::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
9971000
static_cast<
9981001
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
9991002
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
@@ -1061,9 +1064,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
10611064
cp.processThreadLimit(stmtCtx, threadLimitOperand);
10621065
cp.processDepend(dependTypeOperands, dependOperands);
10631066
cp.processNowait(nowaitAttr);
1064-
cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
1065-
&mapSymLocs, &mapSymbols);
1066-
1067+
cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymbols,
1068+
&mapSymTypes, &mapSymLocs);
10671069
cp.processTODO<Fortran::parser::OmpClause::Private,
10681070
Fortran::parser::OmpClause::Firstprivate,
10691071
Fortran::parser::OmpClause::IsDevicePtr,
@@ -1142,9 +1144,11 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
11421144
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
11431145
}
11441146

1147+
checkAndApplyDeclTargetMapFlags(converter, mapFlag, sym);
1148+
11451149
mlir::Value mapOp = createMapInfoOp(
11461150
converter.getFirOpBuilder(), baseOp.getLoc(), baseOp, mlir::Value{},
1147-
name.str(), bounds, {},
1151+
name.str(), bounds, {}, mlir::ArrayAttr{},
11481152
static_cast<
11491153
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
11501154
mapFlag),

0 commit comments

Comments
 (0)