Skip to content

Commit d8d7b0b

Browse files
committed
Addressed reviewer comments.
1 parent be90dc5 commit d8d7b0b

File tree

4 files changed

+252
-316
lines changed

4 files changed

+252
-316
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 76 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -909,13 +909,68 @@ bool ClauseProcessor::processLink(
909909
});
910910
}
911911

912+
void ClauseProcessor::processMapObjects(
913+
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
914+
const omp::ObjectList &objects,
915+
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
916+
std::map<const semantics::Symbol *,
917+
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
918+
llvm::SmallVectorImpl<mlir::Value> &mapVars,
919+
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
920+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
921+
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
922+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
923+
for (const omp::Object &object : objects) {
924+
llvm::SmallVector<mlir::Value> bounds;
925+
std::stringstream asFortran;
926+
927+
lower::AddrAndBoundsInfo info =
928+
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
929+
mlir::omp::MapBoundsType>(
930+
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
931+
object.ref(), clauseLocation, asFortran, bounds,
932+
treatIndexAsSection);
933+
934+
auto origSymbol = converter.getSymbolAddress(*object.sym());
935+
mlir::Value symAddr = info.addr;
936+
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
937+
symAddr = origSymbol;
938+
939+
// Explicit map captures are captured ByRef by default,
940+
// optimisation passes may alter this to ByCopy or other capture
941+
// types to optimise
942+
auto location = mlir::NameLoc::get(
943+
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
944+
symAddr.getLoc());
945+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
946+
firOpBuilder, location, symAddr,
947+
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
948+
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
949+
static_cast<
950+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
951+
mapTypeBits),
952+
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
953+
954+
if (object.sym()->owner().IsDerivedType()) {
955+
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
956+
} else {
957+
mapVars.push_back(mapOp);
958+
if (mapSyms)
959+
mapSyms->push_back(object.sym());
960+
if (mapSymTypes)
961+
mapSymTypes->push_back(symAddr.getType());
962+
if (mapSymLocs)
963+
mapSymLocs->push_back(symAddr.getLoc());
964+
}
965+
}
966+
}
967+
912968
bool ClauseProcessor::processMap(
913969
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
914970
mlir::omp::MapClauseOps &result,
915971
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
916972
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
917973
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
918-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
919974
// We always require tracking of symbols, even if the caller does not,
920975
// so we create an optionally used local set of symbols when the mapSyms
921976
// argument is not present.
@@ -970,50 +1025,10 @@ bool ClauseProcessor::processMap(
9701025
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
9711026
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
9721027
}
973-
974-
for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) {
975-
llvm::SmallVector<mlir::Value> bounds;
976-
std::stringstream asFortran;
977-
978-
lower::AddrAndBoundsInfo info =
979-
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
980-
mlir::omp::MapBoundsType>(
981-
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
982-
object.ref(), clauseLocation, asFortran, bounds,
983-
treatIndexAsSection);
984-
985-
auto origSymbol = converter.getSymbolAddress(*object.sym());
986-
mlir::Value symAddr = info.addr;
987-
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
988-
symAddr = origSymbol;
989-
990-
// Explicit map captures are captured ByRef by default,
991-
// optimisation passes may alter this to ByCopy or other capture
992-
// types to optimise
993-
auto location = mlir::NameLoc::get(
994-
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
995-
symAddr.getLoc());
996-
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
997-
firOpBuilder, location, symAddr,
998-
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
999-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
1000-
static_cast<
1001-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1002-
mapTypeBits),
1003-
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
1004-
1005-
if (object.sym()->owner().IsDerivedType()) {
1006-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
1007-
semaCtx);
1008-
} else {
1009-
result.mapVars.push_back(mapOp);
1010-
ptrMapSyms->push_back(object.sym());
1011-
if (mapSymTypes)
1012-
mapSymTypes->push_back(symAddr.getType());
1013-
if (mapSymLocs)
1014-
mapSymLocs->push_back(symAddr.getLoc());
1015-
}
1016-
}
1028+
processMapObjects(stmtCtx, clauseLocation,
1029+
std::get<omp::ObjectList>(clause.t), mapTypeBits,
1030+
parentMemberIndices, result.mapVars, ptrMapSyms,
1031+
mapSymLocs, mapSymTypes);
10171032
});
10181033

10191034
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
@@ -1076,62 +1091,23 @@ bool ClauseProcessor::processEnter(
10761091
}
10771092

10781093
bool ClauseProcessor::processUseDeviceAddr(
1079-
Fortran::lower::StatementContext &stmtCtx,
1080-
mlir::omp::UseDeviceAddrClauseOps &result,
1094+
lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
10811095
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
10821096
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
1083-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
1084-
const {
1085-
std::map<const Fortran::semantics::Symbol *,
1097+
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
1098+
std::map<const semantics::Symbol *,
10861099
llvm::SmallVector<OmpMapMemberIndicesData>>
10871100
parentMemberIndices;
10881101
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
1089-
[&](const omp::clause::UseDeviceAddr &clause,
1090-
const Fortran::parser::CharBlock &) {
1091-
const Fortran::parser::CharBlock source;
1102+
[&](const omp::clause::UseDeviceAddr &clause, const parser::CharBlock &) {
1103+
const parser::CharBlock source;
10921104
mlir::Location location = converter.genLocation(source);
1093-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
10941105
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
10951106
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
10961107
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1097-
for (const omp::Object &object : clause.v) {
1098-
llvm::SmallVector<mlir::Value> bounds;
1099-
std::stringstream asFortran;
1100-
1101-
Fortran::lower::AddrAndBoundsInfo info =
1102-
Fortran::lower::gatherDataOperandAddrAndBounds<
1103-
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
1104-
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
1105-
object.ref(), location, asFortran, bounds,
1106-
treatIndexAsSection);
1107-
1108-
auto origSymbol = converter.getSymbolAddress(*object.sym());
1109-
mlir::Value symAddr = info.addr;
1110-
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
1111-
symAddr = origSymbol;
1112-
1113-
// Explicit map captures are captured ByRef by default,
1114-
// optimisation passes may alter this to ByCopy or other capture
1115-
// types to optimise
1116-
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
1117-
firOpBuilder, location, symAddr,
1118-
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
1119-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
1120-
static_cast<
1121-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1122-
mapTypeBits),
1123-
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
1124-
1125-
if (object.sym()->owner().IsDerivedType()) {
1126-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
1127-
semaCtx);
1128-
} else {
1129-
useDeviceSyms.push_back(object.sym());
1130-
useDeviceTypes.push_back(symAddr.getType());
1131-
useDeviceLocs.push_back(symAddr.getLoc());
1132-
result.useDeviceAddrVars.push_back(mapOp);
1133-
}
1134-
}
1108+
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1109+
parentMemberIndices, result.useDeviceAddrVars,
1110+
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
11351111
});
11361112

11371113
insertChildMapInfoIntoParent(converter, parentMemberIndices,
@@ -1141,62 +1117,23 @@ bool ClauseProcessor::processUseDeviceAddr(
11411117
}
11421118

11431119
bool ClauseProcessor::processUseDevicePtr(
1144-
Fortran::lower::StatementContext &stmtCtx,
1145-
mlir::omp::UseDevicePtrClauseOps &result,
1120+
lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
11461121
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
11471122
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
1148-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
1149-
const {
1150-
std::map<const Fortran::semantics::Symbol *,
1123+
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
1124+
std::map<const semantics::Symbol *,
11511125
llvm::SmallVector<OmpMapMemberIndicesData>>
11521126
parentMemberIndices;
11531127
bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
1154-
[&](const omp::clause::UseDevicePtr &clause,
1155-
const Fortran::parser::CharBlock &) {
1156-
const Fortran::parser::CharBlock source;
1128+
[&](const omp::clause::UseDevicePtr &clause, const parser::CharBlock &) {
1129+
const parser::CharBlock source;
11571130
mlir::Location location = converter.genLocation(source);
1158-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
11591131
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
11601132
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
11611133
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1162-
for (const omp::Object &object : clause.v) {
1163-
llvm::SmallVector<mlir::Value> bounds;
1164-
std::stringstream asFortran;
1165-
1166-
Fortran::lower::AddrAndBoundsInfo info =
1167-
Fortran::lower::gatherDataOperandAddrAndBounds<
1168-
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
1169-
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
1170-
object.ref(), location, asFortran, bounds,
1171-
treatIndexAsSection);
1172-
1173-
auto origSymbol = converter.getSymbolAddress(*object.sym());
1174-
mlir::Value symAddr = info.addr;
1175-
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
1176-
symAddr = origSymbol;
1177-
1178-
// Explicit map captures are captured ByRef by default,
1179-
// optimisation passes may alter this to ByCopy or other capture
1180-
// types to optimise
1181-
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
1182-
firOpBuilder, location, symAddr,
1183-
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
1184-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
1185-
static_cast<
1186-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1187-
mapTypeBits),
1188-
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
1189-
1190-
if (object.sym()->owner().IsDerivedType()) {
1191-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
1192-
semaCtx);
1193-
} else {
1194-
useDeviceSyms.push_back(object.sym());
1195-
useDeviceTypes.push_back(symAddr.getType());
1196-
useDeviceLocs.push_back(symAddr.getLoc());
1197-
result.useDevicePtrVars.push_back(mapOp);
1198-
}
1199-
}
1134+
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1135+
parentMemberIndices, result.useDevicePtrVars,
1136+
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
12001137
});
12011138

12021139
insertChildMapInfoIntoParent(converter, parentMemberIndices,

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 27 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,18 @@ class ClauseProcessor {
128128
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
129129
nullptr) const;
130130
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
131-
bool
132-
processUseDeviceAddr(Fortran::lower::StatementContext &stmtCtx,
133-
mlir::omp::UseDeviceAddrClauseOps &result,
134-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
135-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
136-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
137-
&useDeviceSyms) const;
138-
bool
139-
processUseDevicePtr(Fortran::lower::StatementContext &stmtCtx,
140-
mlir::omp::UseDevicePtrClauseOps &result,
141-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
142-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
143-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
144-
&useDeviceSyms) const;
131+
bool processUseDeviceAddr(
132+
lower::StatementContext &stmtCtx,
133+
mlir::omp::UseDeviceAddrClauseOps &result,
134+
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
135+
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
136+
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
137+
bool processUseDevicePtr(
138+
lower::StatementContext &stmtCtx,
139+
mlir::omp::UseDevicePtrClauseOps &result,
140+
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
141+
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
142+
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
145143

146144
template <typename T>
147145
bool processMotionClauses(lower::StatementContext &stmtCtx,
@@ -177,6 +175,17 @@ class ClauseProcessor {
177175
template <typename T>
178176
bool markClauseOccurrence(mlir::UnitAttr &result) const;
179177

178+
void processMapObjects(
179+
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
180+
const omp::ObjectList &objects,
181+
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
182+
std::map<const semantics::Symbol *,
183+
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
184+
llvm::SmallVectorImpl<mlir::Value> &mapVars,
185+
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
186+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
187+
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
188+
180189
lower::AbstractConverter &converter;
181190
semantics::SemanticsContext &semaCtx;
182191
List<Clause> clauses;
@@ -193,7 +202,6 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
193202
bool clauseFound = findRepeatableClause<T>(
194203
[&](const T &clause, const parser::CharBlock &source) {
195204
mlir::Location clauseLocation = converter.genLocation(source);
196-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
197205

198206
static_assert(std::is_same_v<T, omp::clause::To> ||
199207
std::is_same_v<T, omp::clause::From>);
@@ -204,43 +212,10 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
204212
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
205213
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
206214

207-
auto &objects = std::get<ObjectList>(clause.t);
208-
for (const omp::Object &object : objects) {
209-
llvm::SmallVector<mlir::Value> bounds;
210-
std::stringstream asFortran;
211-
212-
lower::AddrAndBoundsInfo info =
213-
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
214-
mlir::omp::MapBoundsType>(
215-
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
216-
object.ref(), clauseLocation, asFortran, bounds,
217-
treatIndexAsSection);
218-
219-
auto origSymbol = converter.getSymbolAddress(*object.sym());
220-
mlir::Value symAddr = info.addr;
221-
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
222-
symAddr = origSymbol;
223-
224-
// Explicit map captures are captured ByRef by default,
225-
// optimisation passes may alter this to ByCopy or other capture
226-
// types to optimise
227-
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
228-
firOpBuilder, clauseLocation, symAddr,
229-
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
230-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
231-
static_cast<
232-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
233-
mapTypeBits),
234-
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
235-
236-
if (object.sym()->owner().IsDerivedType()) {
237-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
238-
semaCtx);
239-
} else {
240-
result.mapVars.push_back(mapOp);
241-
mapSymbols.push_back(object.sym());
242-
}
243-
}
215+
processMapObjects(stmtCtx, clauseLocation,
216+
std::get<ObjectList>(clause.t), mapTypeBits,
217+
parentMemberIndices, result.mapVars, &mapSymbols,
218+
nullptr, nullptr);
244219
});
245220

246221
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,

0 commit comments

Comments
 (0)