Skip to content

Commit 9ba4103

Browse files
authored
[OpenMP]Update use_device_clause lowering (#101703)
This patch updates the use_device_ptr and use_device_addr clauses to use the mapInfoOps for lowering. This allows all the types that are handled by the map clauses, such as derived types, to also be supported by the use_device_clauses. This is patch 1/2 in a series of patches. Co-authored-by: Raghu Maddhipatla [email protected]
1 parent 43b8ae3 commit 9ba4103

File tree

6 files changed

+283
-252
lines changed

6 files changed

+283
-252
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 94 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -887,13 +887,64 @@ bool ClauseProcessor::processLink(
887887
});
888888
}
889889

890+
void ClauseProcessor::processMapObjects(
891+
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
892+
const omp::ObjectList &objects,
893+
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
894+
std::map<const semantics::Symbol *,
895+
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
896+
llvm::SmallVectorImpl<mlir::Value> &mapVars,
897+
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
898+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
899+
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
900+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
901+
for (const omp::Object &object : objects) {
902+
llvm::SmallVector<mlir::Value> bounds;
903+
std::stringstream asFortran;
904+
905+
lower::AddrAndBoundsInfo info =
906+
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
907+
mlir::omp::MapBoundsType>(
908+
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
909+
object.ref(), clauseLocation, asFortran, bounds,
910+
treatIndexAsSection);
911+
912+
// Explicit map captures are captured ByRef by default,
913+
// optimisation passes may alter this to ByCopy or other capture
914+
// types to optimise
915+
mlir::Value baseOp = info.rawInput;
916+
auto location = mlir::NameLoc::get(
917+
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
918+
baseOp.getLoc());
919+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
920+
firOpBuilder, location, baseOp,
921+
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
922+
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
923+
static_cast<
924+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
925+
mapTypeBits),
926+
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());
927+
928+
if (object.sym()->owner().IsDerivedType()) {
929+
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
930+
} else {
931+
mapVars.push_back(mapOp);
932+
if (mapSyms)
933+
mapSyms->push_back(object.sym());
934+
if (mapSymTypes)
935+
mapSymTypes->push_back(baseOp.getType());
936+
if (mapSymLocs)
937+
mapSymLocs->push_back(baseOp.getLoc());
938+
}
939+
}
940+
}
941+
890942
bool ClauseProcessor::processMap(
891943
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
892944
mlir::omp::MapClauseOps &result,
893945
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
894946
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
895947
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
896-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
897948
// We always require tracking of symbols, even if the caller does not,
898949
// so we create an optionally used local set of symbols when the mapSyms
899950
// argument is not present.
@@ -948,46 +999,10 @@ bool ClauseProcessor::processMap(
948999
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
9491000
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
9501001
}
951-
952-
for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) {
953-
llvm::SmallVector<mlir::Value> bounds;
954-
std::stringstream asFortran;
955-
956-
lower::AddrAndBoundsInfo info =
957-
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
958-
mlir::omp::MapBoundsType>(
959-
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
960-
object.ref(), clauseLocation, asFortran, bounds,
961-
treatIndexAsSection);
962-
963-
// Explicit map captures are captured ByRef by default,
964-
// optimisation passes may alter this to ByCopy or other capture
965-
// types to optimise
966-
mlir::Value baseOp = info.rawInput;
967-
auto location = mlir::NameLoc::get(
968-
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
969-
baseOp.getLoc());
970-
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
971-
firOpBuilder, location, baseOp,
972-
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
973-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
974-
static_cast<
975-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
976-
mapTypeBits),
977-
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());
978-
979-
if (object.sym()->owner().IsDerivedType()) {
980-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
981-
semaCtx);
982-
} else {
983-
result.mapVars.push_back(mapOp);
984-
ptrMapSyms->push_back(object.sym());
985-
if (mapSymTypes)
986-
mapSymTypes->push_back(baseOp.getType());
987-
if (mapSymLocs)
988-
mapSymLocs->push_back(baseOp.getLoc());
989-
}
990-
}
1002+
processMapObjects(stmtCtx, clauseLocation,
1003+
std::get<omp::ObjectList>(clause.t), mapTypeBits,
1004+
parentMemberIndices, result.mapVars, ptrMapSyms,
1005+
mapSymLocs, mapSymTypes);
9911006
});
9921007

9931008
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
@@ -1050,27 +1065,55 @@ bool ClauseProcessor::processEnter(
10501065
}
10511066

10521067
bool ClauseProcessor::processUseDeviceAddr(
1053-
mlir::omp::UseDeviceAddrClauseOps &result,
1068+
lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
10541069
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
10551070
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
10561071
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
1057-
return findRepeatableClause<omp::clause::UseDeviceAddr>(
1058-
[&](const omp::clause::UseDeviceAddr &clause, const parser::CharBlock &) {
1059-
addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars,
1060-
useDeviceTypes, useDeviceLocs, useDeviceSyms);
1072+
std::map<const semantics::Symbol *,
1073+
llvm::SmallVector<OmpMapMemberIndicesData>>
1074+
parentMemberIndices;
1075+
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
1076+
[&](const omp::clause::UseDeviceAddr &clause,
1077+
const parser::CharBlock &source) {
1078+
mlir::Location location = converter.genLocation(source);
1079+
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1080+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1081+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1082+
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1083+
parentMemberIndices, result.useDeviceAddrVars,
1084+
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
10611085
});
1086+
1087+
insertChildMapInfoIntoParent(converter, parentMemberIndices,
1088+
result.useDeviceAddrVars, useDeviceSyms,
1089+
&useDeviceTypes, &useDeviceLocs);
1090+
return clauseFound;
10621091
}
10631092

10641093
bool ClauseProcessor::processUseDevicePtr(
1065-
mlir::omp::UseDevicePtrClauseOps &result,
1094+
lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
10661095
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
10671096
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
10681097
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
1069-
return findRepeatableClause<omp::clause::UseDevicePtr>(
1070-
[&](const omp::clause::UseDevicePtr &clause, const parser::CharBlock &) {
1071-
addUseDeviceClause(converter, clause.v, result.useDevicePtrVars,
1072-
useDeviceTypes, useDeviceLocs, useDeviceSyms);
1098+
std::map<const semantics::Symbol *,
1099+
llvm::SmallVector<OmpMapMemberIndicesData>>
1100+
parentMemberIndices;
1101+
bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
1102+
[&](const omp::clause::UseDevicePtr &clause,
1103+
const parser::CharBlock &source) {
1104+
mlir::Location location = converter.genLocation(source);
1105+
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1106+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1107+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1108+
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1109+
parentMemberIndices, result.useDevicePtrVars,
1110+
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
10731111
});
1112+
1113+
insertChildMapInfoIntoParent(converter, parentMemberIndices,
1114+
result.useDevicePtrVars, useDeviceSyms,
1115+
&useDeviceTypes, &useDeviceLocs);
1116+
return clauseFound;
10741117
}
10751118

10761119
} // namespace omp

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,13 @@ class ClauseProcessor {
128128
nullptr) const;
129129
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
130130
bool processUseDeviceAddr(
131+
lower::StatementContext &stmtCtx,
131132
mlir::omp::UseDeviceAddrClauseOps &result,
132133
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
133134
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
134135
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
135136
bool processUseDevicePtr(
137+
lower::StatementContext &stmtCtx,
136138
mlir::omp::UseDevicePtrClauseOps &result,
137139
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
138140
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
@@ -172,6 +174,17 @@ class ClauseProcessor {
172174
template <typename T>
173175
bool markClauseOccurrence(mlir::UnitAttr &result) const;
174176

177+
void processMapObjects(
178+
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
179+
const omp::ObjectList &objects,
180+
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
181+
std::map<const semantics::Symbol *,
182+
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
183+
llvm::SmallVectorImpl<mlir::Value> &mapVars,
184+
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
185+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
186+
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
187+
175188
lower::AbstractConverter &converter;
176189
semantics::SemanticsContext &semaCtx;
177190
List<Clause> clauses;
@@ -188,7 +201,6 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
188201
bool clauseFound = findRepeatableClause<T>(
189202
[&](const T &clause, const parser::CharBlock &source) {
190203
mlir::Location clauseLocation = converter.genLocation(source);
191-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
192204

193205
static_assert(std::is_same_v<T, omp::clause::To> ||
194206
std::is_same_v<T, omp::clause::From>);
@@ -199,39 +211,9 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
199211
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
200212
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
201213

202-
auto &objects = std::get<ObjectList>(clause.t);
203-
for (const omp::Object &object : objects) {
204-
llvm::SmallVector<mlir::Value> bounds;
205-
std::stringstream asFortran;
206-
207-
lower::AddrAndBoundsInfo info =
208-
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
209-
mlir::omp::MapBoundsType>(
210-
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
211-
object.ref(), clauseLocation, asFortran, bounds,
212-
treatIndexAsSection);
213-
214-
// Explicit map captures are captured ByRef by default,
215-
// optimisation passes may alter this to ByCopy or other capture
216-
// types to optimise
217-
mlir::Value baseOp = info.rawInput;
218-
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
219-
firOpBuilder, clauseLocation, baseOp,
220-
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
221-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
222-
static_cast<
223-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
224-
mapTypeBits),
225-
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());
226-
227-
if (object.sym()->owner().IsDerivedType()) {
228-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
229-
semaCtx);
230-
} else {
231-
result.mapVars.push_back(mapOp);
232-
mapSymbols.push_back(object.sym());
233-
}
234-
}
214+
processMapObjects(stmtCtx, clauseLocation,
215+
std::get<ObjectList>(clause.t), mapTypeBits,
216+
parentMemberIndices, result.mapVars, &mapSymbols);
235217
});
236218

237219
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,

0 commit comments

Comments
 (0)