Skip to content

Commit 74535f5

Browse files
committed
Addressed reviewer comments.
1 parent f2c2e0a commit 74535f5

File tree

4 files changed

+256
-321
lines changed

4 files changed

+256
-321
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 76 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -905,13 +905,68 @@ bool ClauseProcessor::processLink(
905905
});
906906
}
907907

908+
void ClauseProcessor::processMapObjects(
909+
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
910+
const omp::ObjectList &objects,
911+
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
912+
std::map<const semantics::Symbol *,
913+
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
914+
llvm::SmallVectorImpl<mlir::Value> &mapVars,
915+
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
916+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
917+
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
918+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
919+
for (const omp::Object &object : objects) {
920+
llvm::SmallVector<mlir::Value> bounds;
921+
std::stringstream asFortran;
922+
923+
lower::AddrAndBoundsInfo info =
924+
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
925+
mlir::omp::MapBoundsType>(
926+
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
927+
object.ref(), clauseLocation, asFortran, bounds,
928+
treatIndexAsSection);
929+
930+
auto origSymbol = converter.getSymbolAddress(*object.sym());
931+
mlir::Value symAddr = info.addr;
932+
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
933+
symAddr = origSymbol;
934+
935+
// Explicit map captures are captured ByRef by default,
936+
// optimisation passes may alter this to ByCopy or other capture
937+
// types to optimise
938+
auto location = mlir::NameLoc::get(
939+
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
940+
symAddr.getLoc());
941+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
942+
firOpBuilder, location, symAddr,
943+
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
944+
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
945+
static_cast<
946+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
947+
mapTypeBits),
948+
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
949+
950+
if (object.sym()->owner().IsDerivedType()) {
951+
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
952+
} else {
953+
mapVars.push_back(mapOp);
954+
if (mapSyms)
955+
mapSyms->push_back(object.sym());
956+
if (mapSymTypes)
957+
mapSymTypes->push_back(symAddr.getType());
958+
if (mapSymLocs)
959+
mapSymLocs->push_back(symAddr.getLoc());
960+
}
961+
}
962+
}
963+
908964
bool ClauseProcessor::processMap(
909965
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
910966
mlir::omp::MapClauseOps &result,
911967
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
912968
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
913969
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
914-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
915970
// We always require tracking of symbols, even if the caller does not,
916971
// so we create an optionally used local set of symbols when the mapSyms
917972
// argument is not present.
@@ -966,50 +1021,10 @@ bool ClauseProcessor::processMap(
9661021
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
9671022
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
9681023
}
969-
970-
for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) {
971-
llvm::SmallVector<mlir::Value> bounds;
972-
std::stringstream asFortran;
973-
974-
lower::AddrAndBoundsInfo info =
975-
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
976-
mlir::omp::MapBoundsType>(
977-
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
978-
object.ref(), clauseLocation, asFortran, bounds,
979-
treatIndexAsSection);
980-
981-
auto origSymbol = converter.getSymbolAddress(*object.sym());
982-
mlir::Value symAddr = info.addr;
983-
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
984-
symAddr = origSymbol;
985-
986-
// Explicit map captures are captured ByRef by default,
987-
// optimisation passes may alter this to ByCopy or other capture
988-
// types to optimise
989-
auto location = mlir::NameLoc::get(
990-
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
991-
symAddr.getLoc());
992-
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
993-
firOpBuilder, location, symAddr,
994-
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
995-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
996-
static_cast<
997-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
998-
mapTypeBits),
999-
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
1000-
1001-
if (object.sym()->owner().IsDerivedType()) {
1002-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
1003-
semaCtx);
1004-
} else {
1005-
result.mapVars.push_back(mapOp);
1006-
ptrMapSyms->push_back(object.sym());
1007-
if (mapSymTypes)
1008-
mapSymTypes->push_back(symAddr.getType());
1009-
if (mapSymLocs)
1010-
mapSymLocs->push_back(symAddr.getLoc());
1011-
}
1012-
}
1024+
processMapObjects(stmtCtx, clauseLocation,
1025+
std::get<omp::ObjectList>(clause.t), mapTypeBits,
1026+
parentMemberIndices, result.mapVars, ptrMapSyms,
1027+
mapSymLocs, mapSymTypes);
10131028
});
10141029

10151030
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
@@ -1072,62 +1087,23 @@ bool ClauseProcessor::processEnter(
10721087
}
10731088

10741089
bool ClauseProcessor::processUseDeviceAddr(
1075-
Fortran::lower::StatementContext &stmtCtx,
1076-
mlir::omp::UseDeviceAddrClauseOps &result,
1090+
lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
10771091
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
10781092
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
1079-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
1080-
const {
1081-
std::map<const Fortran::semantics::Symbol *,
1093+
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
1094+
std::map<const semantics::Symbol *,
10821095
llvm::SmallVector<OmpMapMemberIndicesData>>
10831096
parentMemberIndices;
10841097
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
1085-
[&](const omp::clause::UseDeviceAddr &clause,
1086-
const Fortran::parser::CharBlock &) {
1087-
const Fortran::parser::CharBlock source;
1098+
[&](const omp::clause::UseDeviceAddr &clause, const parser::CharBlock &) {
1099+
const parser::CharBlock source;
10881100
mlir::Location location = converter.genLocation(source);
1089-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
10901101
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
10911102
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
10921103
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1093-
for (const omp::Object &object : clause.v) {
1094-
llvm::SmallVector<mlir::Value> bounds;
1095-
std::stringstream asFortran;
1096-
1097-
Fortran::lower::AddrAndBoundsInfo info =
1098-
Fortran::lower::gatherDataOperandAddrAndBounds<
1099-
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
1100-
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
1101-
object.ref(), location, asFortran, bounds,
1102-
treatIndexAsSection);
1103-
1104-
auto origSymbol = converter.getSymbolAddress(*object.sym());
1105-
mlir::Value symAddr = info.addr;
1106-
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
1107-
symAddr = origSymbol;
1108-
1109-
// Explicit map captures are captured ByRef by default,
1110-
// optimisation passes may alter this to ByCopy or other capture
1111-
// types to optimise
1112-
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
1113-
firOpBuilder, location, symAddr,
1114-
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
1115-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
1116-
static_cast<
1117-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1118-
mapTypeBits),
1119-
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
1120-
1121-
if (object.sym()->owner().IsDerivedType()) {
1122-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
1123-
semaCtx);
1124-
} else {
1125-
useDeviceSyms.push_back(object.sym());
1126-
useDeviceTypes.push_back(symAddr.getType());
1127-
useDeviceLocs.push_back(symAddr.getLoc());
1128-
result.useDeviceAddrVars.push_back(mapOp);
1129-
}
1130-
}
1104+
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1105+
parentMemberIndices, result.useDeviceAddrVars,
1106+
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
11311107
});
11321108

11331109
insertChildMapInfoIntoParent(converter, parentMemberIndices,
@@ -1137,62 +1113,23 @@ bool ClauseProcessor::processUseDeviceAddr(
11371113
}
11381114

11391115
bool ClauseProcessor::processUseDevicePtr(
1140-
Fortran::lower::StatementContext &stmtCtx,
1141-
mlir::omp::UseDevicePtrClauseOps &result,
1116+
lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
11421117
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
11431118
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
1144-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
1145-
const {
1146-
std::map<const Fortran::semantics::Symbol *,
1119+
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
1120+
std::map<const semantics::Symbol *,
11471121
llvm::SmallVector<OmpMapMemberIndicesData>>
11481122
parentMemberIndices;
11491123
bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
1150-
[&](const omp::clause::UseDevicePtr &clause,
1151-
const Fortran::parser::CharBlock &) {
1152-
const Fortran::parser::CharBlock source;
1124+
[&](const omp::clause::UseDevicePtr &clause, const parser::CharBlock &) {
1125+
const parser::CharBlock source;
11531126
mlir::Location location = converter.genLocation(source);
1154-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
11551127
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
11561128
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
11571129
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1158-
for (const omp::Object &object : clause.v) {
1159-
llvm::SmallVector<mlir::Value> bounds;
1160-
std::stringstream asFortran;
1161-
1162-
Fortran::lower::AddrAndBoundsInfo info =
1163-
Fortran::lower::gatherDataOperandAddrAndBounds<
1164-
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
1165-
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
1166-
object.ref(), location, asFortran, bounds,
1167-
treatIndexAsSection);
1168-
1169-
auto origSymbol = converter.getSymbolAddress(*object.sym());
1170-
mlir::Value symAddr = info.addr;
1171-
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
1172-
symAddr = origSymbol;
1173-
1174-
// Explicit map captures are captured ByRef by default,
1175-
// optimisation passes may alter this to ByCopy or other capture
1176-
// types to optimise
1177-
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
1178-
firOpBuilder, location, symAddr,
1179-
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
1180-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
1181-
static_cast<
1182-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1183-
mapTypeBits),
1184-
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
1185-
1186-
if (object.sym()->owner().IsDerivedType()) {
1187-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
1188-
semaCtx);
1189-
} else {
1190-
useDeviceSyms.push_back(object.sym());
1191-
useDeviceTypes.push_back(symAddr.getType());
1192-
useDeviceLocs.push_back(symAddr.getLoc());
1193-
result.useDevicePtrVars.push_back(mapOp);
1194-
}
1195-
}
1130+
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1131+
parentMemberIndices, result.useDevicePtrVars,
1132+
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
11961133
});
11971134

11981135
insertChildMapInfoIntoParent(converter, parentMemberIndices,

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 27 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,18 @@ class ClauseProcessor {
128128
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
129129
nullptr) const;
130130
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
131-
bool
132-
processUseDeviceAddr(Fortran::lower::StatementContext &stmtCtx,
133-
mlir::omp::UseDeviceAddrClauseOps &result,
134-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
135-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
136-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
137-
&useDeviceSyms) const;
138-
bool
139-
processUseDevicePtr(Fortran::lower::StatementContext &stmtCtx,
140-
mlir::omp::UseDevicePtrClauseOps &result,
141-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
142-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
143-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
144-
&useDeviceSyms) const;
131+
bool processUseDeviceAddr(
132+
lower::StatementContext &stmtCtx,
133+
mlir::omp::UseDeviceAddrClauseOps &result,
134+
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
135+
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
136+
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
137+
bool processUseDevicePtr(
138+
lower::StatementContext &stmtCtx,
139+
mlir::omp::UseDevicePtrClauseOps &result,
140+
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
141+
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
142+
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
145143

146144
template <typename T>
147145
bool processMotionClauses(lower::StatementContext &stmtCtx,
@@ -177,6 +175,17 @@ class ClauseProcessor {
177175
template <typename T>
178176
bool markClauseOccurrence(mlir::UnitAttr &result) const;
179177

178+
void processMapObjects(
179+
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
180+
const omp::ObjectList &objects,
181+
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
182+
std::map<const semantics::Symbol *,
183+
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
184+
llvm::SmallVectorImpl<mlir::Value> &mapVars,
185+
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
186+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
187+
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
188+
180189
lower::AbstractConverter &converter;
181190
semantics::SemanticsContext &semaCtx;
182191
List<Clause> clauses;
@@ -193,7 +202,6 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
193202
bool clauseFound = findRepeatableClause<T>(
194203
[&](const T &clause, const parser::CharBlock &source) {
195204
mlir::Location clauseLocation = converter.genLocation(source);
196-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
197205

198206
static_assert(std::is_same_v<T, omp::clause::To> ||
199207
std::is_same_v<T, omp::clause::From>);
@@ -204,43 +212,10 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
204212
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
205213
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
206214

207-
auto &objects = std::get<ObjectList>(clause.t);
208-
for (const omp::Object &object : objects) {
209-
llvm::SmallVector<mlir::Value> bounds;
210-
std::stringstream asFortran;
211-
212-
lower::AddrAndBoundsInfo info =
213-
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
214-
mlir::omp::MapBoundsType>(
215-
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
216-
object.ref(), clauseLocation, asFortran, bounds,
217-
treatIndexAsSection);
218-
219-
auto origSymbol = converter.getSymbolAddress(*object.sym());
220-
mlir::Value symAddr = info.addr;
221-
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
222-
symAddr = origSymbol;
223-
224-
// Explicit map captures are captured ByRef by default,
225-
// optimisation passes may alter this to ByCopy or other capture
226-
// types to optimise
227-
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
228-
firOpBuilder, clauseLocation, symAddr,
229-
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
230-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
231-
static_cast<
232-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
233-
mapTypeBits),
234-
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
235-
236-
if (object.sym()->owner().IsDerivedType()) {
237-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
238-
semaCtx);
239-
} else {
240-
result.mapVars.push_back(mapOp);
241-
mapSymbols.push_back(object.sym());
242-
}
243-
}
215+
processMapObjects(stmtCtx, clauseLocation,
216+
std::get<ObjectList>(clause.t), mapTypeBits,
217+
parentMemberIndices, result.mapVars, &mapSymbols,
218+
nullptr, nullptr);
244219
});
245220

246221
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,

0 commit comments

Comments
 (0)