Skip to content

Commit 8411549

Browse files
authored
[flang][Lower] Convert OMP Map and related functions to evaluate::Expr (#81626)
The related functions are `gatherDataOperandAddrAndBounds` and `genBoundsOps`. The former is used in OpenACC as well, and it was updated to pass evaluate::Expr instead of parser objects. The difference in the test case comes from unfolded conversions of index expressions, which are explicitly of type integer(kind=8). Delete now unused `findRepeatableClause2` and `findClause2`. Add `AsGenericExpr` that takes std::optional. It already returns optional Expr. Making it accept an optional Expr as input would reduce the number of necessary checks when handling frequent optional values in evaluator. [Clause representation 4/6]
1 parent 0177a95 commit 8411549

File tree

8 files changed

+335
-268
lines changed

8 files changed

+335
-268
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@ inline Expr<SomeType> AsGenericExpr(Expr<SomeType> &&x) { return std::move(x); }
148148
std::optional<Expr<SomeType>> AsGenericExpr(DataRef &&);
149149
std::optional<Expr<SomeType>> AsGenericExpr(const Symbol &);
150150

151+
// Propagate std::optional from input to output.
152+
template <typename A>
153+
std::optional<Expr<SomeType>> AsGenericExpr(std::optional<A> &&x) {
154+
if (!x)
155+
return std::nullopt;
156+
return AsGenericExpr(std::move(*x));
157+
}
158+
151159
template <typename A>
152160
common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(
153161
A &&x) {

flang/lib/Lower/DirectivesCommon.h

Lines changed: 234 additions & 155 deletions
Large diffs are not rendered by default.

flang/lib/Lower/OpenACC.cpp

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,11 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
269269
Fortran::parser::GetLastName(arrayElement->base);
270270
return *name.symbol;
271271
}
272+
if (const auto *component =
273+
Fortran::parser::Unwrap<Fortran::parser::StructureComponent>(
274+
*designator)) {
275+
return *component->component.symbol;
276+
}
272277
} else if (const auto *name =
273278
std::get_if<Fortran::parser::Name>(&accObject.u)) {
274279
return *name->symbol;
@@ -286,17 +291,20 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
286291
mlir::acc::DataClause dataClause, bool structured,
287292
bool implicit, bool setDeclareAttr = false) {
288293
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
294+
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
289295
for (const auto &accObject : objectList.v) {
290296
llvm::SmallVector<mlir::Value> bounds;
291297
std::stringstream asFortran;
292298
mlir::Location operandLocation = genOperandLocation(converter, accObject);
299+
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
300+
Fortran::semantics::MaybeExpr designator =
301+
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
293302
Fortran::lower::AddrAndBoundsInfo info =
294303
Fortran::lower::gatherDataOperandAddrAndBounds<
295-
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
296-
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
297-
stmtCtx, accObject, operandLocation,
298-
asFortran, bounds,
299-
/*treatIndexAsSection=*/true);
304+
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
305+
converter, builder, semanticsContext, stmtCtx, symbol, designator,
306+
operandLocation, asFortran, bounds,
307+
/*treatIndexAsSection=*/true);
300308

301309
// If the input value is optional and is not a descriptor, we use the
302310
// rawInput directly.
@@ -321,16 +329,19 @@ static void genDeclareDataOperandOperations(
321329
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
322330
mlir::acc::DataClause dataClause, bool structured, bool implicit) {
323331
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
332+
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
324333
for (const auto &accObject : objectList.v) {
325334
llvm::SmallVector<mlir::Value> bounds;
326335
std::stringstream asFortran;
327336
mlir::Location operandLocation = genOperandLocation(converter, accObject);
337+
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
338+
Fortran::semantics::MaybeExpr designator =
339+
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
328340
Fortran::lower::AddrAndBoundsInfo info =
329341
Fortran::lower::gatherDataOperandAddrAndBounds<
330-
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
331-
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
332-
stmtCtx, accObject, operandLocation,
333-
asFortran, bounds);
342+
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
343+
converter, builder, semanticsContext, stmtCtx, symbol, designator,
344+
operandLocation, asFortran, bounds);
334345
EntryOp op = createDataEntryOp<EntryOp>(
335346
builder, operandLocation, info.addr, asFortran, bounds, structured,
336347
implicit, dataClause, info.addr.getType());
@@ -339,8 +350,7 @@ static void genDeclareDataOperandOperations(
339350
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
340351
mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
341352
modBuilder.setInsertionPointAfter(builder.getFunction());
342-
std::string prefix =
343-
converter.mangleName(getSymbolFromAccObject(accObject));
353+
std::string prefix = converter.mangleName(symbol);
344354
createDeclareAllocFuncWithArg<EntryOp>(
345355
modBuilder, builder, operandLocation, info.addr.getType(), prefix,
346356
asFortran, dataClause);
@@ -770,16 +780,19 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
770780
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
771781
llvm::SmallVector<mlir::Attribute> &privatizations) {
772782
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
783+
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
773784
for (const auto &accObject : objectList.v) {
774785
llvm::SmallVector<mlir::Value> bounds;
775786
std::stringstream asFortran;
776787
mlir::Location operandLocation = genOperandLocation(converter, accObject);
788+
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
789+
Fortran::semantics::MaybeExpr designator =
790+
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
777791
Fortran::lower::AddrAndBoundsInfo info =
778792
Fortran::lower::gatherDataOperandAddrAndBounds<
779-
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
780-
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
781-
stmtCtx, accObject, operandLocation,
782-
asFortran, bounds);
793+
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
794+
converter, builder, semanticsContext, stmtCtx, symbol, designator,
795+
operandLocation, asFortran, bounds);
783796
RecipeOp recipe;
784797
mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
785798
if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
@@ -1340,16 +1353,19 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
13401353
const auto &op =
13411354
std::get<Fortran::parser::AccReductionOperator>(objectList.t);
13421355
mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
1356+
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
13431357
for (const auto &accObject : objects.v) {
13441358
llvm::SmallVector<mlir::Value> bounds;
13451359
std::stringstream asFortran;
13461360
mlir::Location operandLocation = genOperandLocation(converter, accObject);
1361+
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
1362+
Fortran::semantics::MaybeExpr designator =
1363+
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
13471364
Fortran::lower::AddrAndBoundsInfo info =
13481365
Fortran::lower::gatherDataOperandAddrAndBounds<
1349-
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
1350-
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
1351-
stmtCtx, accObject, operandLocation,
1352-
asFortran, bounds);
1366+
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
1367+
converter, builder, semanticsContext, stmtCtx, symbol, designator,
1368+
operandLocation, asFortran, bounds);
13531369

13541370
mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
13551371
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -818,65 +818,61 @@ bool ClauseProcessor::processMap(
818818
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
819819
const {
820820
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
821-
return findRepeatableClause2<ClauseTy::Map>(
822-
[&](const ClauseTy::Map *mapClause,
821+
return findRepeatableClause<omp::clause::Map>(
822+
[&](const omp::clause::Map &clause,
823823
const Fortran::parser::CharBlock &source) {
824+
using Map = omp::clause::Map;
824825
mlir::Location clauseLocation = converter.genLocation(source);
825-
const auto &oMapType =
826-
std::get<std::optional<Fortran::parser::OmpMapType>>(
827-
mapClause->v.t);
826+
const auto &oMapType = std::get<std::optional<Map::MapType>>(clause.t);
828827
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
829828
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
830829
// If the map type is specified, then process it else Tofrom is the
831830
// default.
832831
if (oMapType) {
833-
const Fortran::parser::OmpMapType::Type &mapType =
834-
std::get<Fortran::parser::OmpMapType::Type>(oMapType->t);
832+
const Map::MapType::Type &mapType =
833+
std::get<Map::MapType::Type>(oMapType->t);
835834
switch (mapType) {
836-
case Fortran::parser::OmpMapType::Type::To:
835+
case Map::MapType::Type::To:
837836
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
838837
break;
839-
case Fortran::parser::OmpMapType::Type::From:
838+
case Map::MapType::Type::From:
840839
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
841840
break;
842-
case Fortran::parser::OmpMapType::Type::Tofrom:
841+
case Map::MapType::Type::Tofrom:
843842
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
844843
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
845844
break;
846-
case Fortran::parser::OmpMapType::Type::Alloc:
847-
case Fortran::parser::OmpMapType::Type::Release:
845+
case Map::MapType::Type::Alloc:
846+
case Map::MapType::Type::Release:
848847
// alloc and release is the default map_type for the Target Data
849848
// Ops, i.e. if no bits for map_type is supplied then alloc/release
850849
// is implicitly assumed based on the target directive. Default
851850
// value for Target Data and Enter Data is alloc and for Exit Data
852851
// it is release.
853852
break;
854-
case Fortran::parser::OmpMapType::Type::Delete:
853+
case Map::MapType::Type::Delete:
855854
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
856855
}
857856

858-
if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
859-
oMapType->t))
857+
if (std::get<std::optional<Map::MapType::Always>>(oMapType->t))
860858
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
861859
} else {
862860
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
863861
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
864862
}
865863

866-
for (const Fortran::parser::OmpObject &ompObject :
867-
std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) {
864+
for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) {
868865
llvm::SmallVector<mlir::Value> bounds;
869866
std::stringstream asFortran;
870867

871868
Fortran::lower::AddrAndBoundsInfo info =
872869
Fortran::lower::gatherDataOperandAddrAndBounds<
873-
Fortran::parser::OmpObject, mlir::omp::MapBoundsOp,
874-
mlir::omp::MapBoundsType>(
875-
converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
876-
clauseLocation, asFortran, bounds, treatIndexAsSection);
870+
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
871+
converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
872+
object.ref(), clauseLocation, asFortran, bounds,
873+
treatIndexAsSection);
877874

878-
auto origSymbol =
879-
converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
875+
auto origSymbol = converter.getSymbolAddress(*object.id());
880876
mlir::Value symAddr = info.addr;
881877
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
882878
symAddr = origSymbol;
@@ -899,7 +895,7 @@ bool ClauseProcessor::processMap(
899895
mapSymLocs->push_back(symAddr.getLoc());
900896

901897
if (mapSymbols)
902-
mapSymbols->push_back(getOmpObjectSymbol(ompObject));
898+
mapSymbols->push_back(object.id());
903899
}
904900
});
905901
}

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 11 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,6 @@ class ClauseProcessor {
162162
/// Utility to find a clause within a range in the clause list.
163163
template <typename T>
164164
static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
165-
template <typename T>
166-
static ClauseIterator2 findClause2(ClauseIterator2 begin,
167-
ClauseIterator2 end);
168165

169166
/// Return the first instance of the given clause found in the clause list or
170167
/// `nullptr` if not present. If more than one instance is expected, use
@@ -179,10 +176,6 @@ class ClauseProcessor {
179176
bool findRepeatableClause(
180177
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
181178
callbackFn) const;
182-
template <typename T>
183-
bool findRepeatableClause2(
184-
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
185-
callbackFn) const;
186179

187180
/// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
188181
template <typename T>
@@ -198,32 +191,31 @@ template <typename T>
198191
bool ClauseProcessor::processMotionClauses(
199192
Fortran::lower::StatementContext &stmtCtx,
200193
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
201-
return findRepeatableClause2<T>(
202-
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
194+
return findRepeatableClause<T>(
195+
[&](const T &clause, const Fortran::parser::CharBlock &source) {
203196
mlir::Location clauseLocation = converter.genLocation(source);
204197
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
205198

206-
static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> ||
207-
std::is_same_v<T, ClauseProcessor::ClauseTy::From>);
199+
static_assert(std::is_same_v<T, omp::clause::To> ||
200+
std::is_same_v<T, omp::clause::From>);
208201

209202
// TODO Support motion modifiers: present, mapper, iterator.
210203
constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
211-
std::is_same_v<T, ClauseProcessor::ClauseTy::To>
204+
std::is_same_v<T, omp::clause::To>
212205
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
213206
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
214207

215-
for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
208+
for (const omp::Object &object : clause.v) {
216209
llvm::SmallVector<mlir::Value> bounds;
217210
std::stringstream asFortran;
218211
Fortran::lower::AddrAndBoundsInfo info =
219212
Fortran::lower::gatherDataOperandAddrAndBounds<
220-
Fortran::parser::OmpObject, mlir::omp::MapBoundsOp,
221-
mlir::omp::MapBoundsType>(
222-
converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
223-
clauseLocation, asFortran, bounds, treatIndexAsSection);
213+
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
214+
converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
215+
object.ref(), clauseLocation, asFortran, bounds,
216+
treatIndexAsSection);
224217

225-
auto origSymbol =
226-
converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
218+
auto origSymbol = converter.getSymbolAddress(*object.id());
227219
mlir::Value symAddr = info.addr;
228220
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
229221
symAddr = origSymbol;
@@ -273,17 +265,6 @@ ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
273265
return end;
274266
}
275267

276-
template <typename T>
277-
ClauseProcessor::ClauseIterator2
278-
ClauseProcessor::findClause2(ClauseIterator2 begin, ClauseIterator2 end) {
279-
for (ClauseIterator2 it = begin; it != end; ++it) {
280-
if (std::get_if<T>(&it->u))
281-
return it;
282-
}
283-
284-
return end;
285-
}
286-
287268
template <typename T>
288269
const T *ClauseProcessor::findUniqueClause(
289270
const Fortran::parser::CharBlock **source) const {
@@ -314,24 +295,6 @@ bool ClauseProcessor::findRepeatableClause(
314295
return found;
315296
}
316297

317-
template <typename T>
318-
bool ClauseProcessor::findRepeatableClause2(
319-
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
320-
callbackFn) const {
321-
bool found = false;
322-
ClauseIterator2 nextIt, endIt = clauses2.v.end();
323-
for (ClauseIterator2 it = clauses2.v.begin(); it != endIt; it = nextIt) {
324-
nextIt = findClause2<T>(it, endIt);
325-
326-
if (nextIt != endIt) {
327-
callbackFn(&std::get<T>(nextIt->u), nextIt->source);
328-
found = true;
329-
++nextIt;
330-
}
331-
}
332-
return found;
333-
}
334-
335298
template <typename T>
336299
bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
337300
if (findUniqueClause<T>()) {

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -930,11 +930,8 @@ static OpTy genTargetEnterExitDataUpdateOp(
930930
cp.processNowait(nowaitAttr);
931931

932932
if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) {
933-
cp.processMotionClauses<Fortran::parser::OmpClause::To>(stmtCtx,
934-
mapOperands);
935-
cp.processMotionClauses<Fortran::parser::OmpClause::From>(stmtCtx,
936-
mapOperands);
937-
933+
cp.processMotionClauses<clause::To>(stmtCtx, mapOperands);
934+
cp.processMotionClauses<clause::From>(stmtCtx, mapOperands);
938935
} else {
939936
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
940937
}

flang/test/Lower/OpenACC/acc-bounds.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ subroutine acc_optional_data3(a, n)
184184
! CHECK: fir.result %c0{{.*}} : index
185185
! CHECK: }
186186
! CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[STRIDE]] : index) startIdx(%c1 : index) {strideInBytes = true}
187-
! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%14) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
187+
! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
188188
! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref<!fir.array<?xf32>>) {
189189

190190
end module

0 commit comments

Comments
 (0)