Skip to content

Commit 7340263

Browse files
authored
Reapply "[Flang][OpenMP][Lower] NFC: Move clause processing helpers into the ClauseProcessor (#85258)" (#85807)
This patch contains slight modifications to the reverted PR #85258 to avoid issues with constructs containing multiple reduction clauses, uncovered by a test on the gfortran testsuite. This reverts commit 9f80444.
1 parent 02cb89b commit 7340263

File tree

6 files changed

+171
-80
lines changed

6 files changed

+171
-80
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,25 @@ addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
208208
useDeviceSymbols.push_back(object.id());
209209
}
210210

211+
static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
212+
mlir::Location loc,
213+
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
214+
llvm::SmallVectorImpl<mlir::Value> &upperBound,
215+
llvm::SmallVectorImpl<mlir::Value> &step,
216+
std::size_t loopVarTypeSize) {
217+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
218+
// The types of lower bound, upper bound, and step are converted into the
219+
// type of the loop variable if necessary.
220+
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
221+
for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
222+
lowerBound[it] =
223+
firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
224+
upperBound[it] =
225+
firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
226+
step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
227+
}
228+
}
229+
211230
//===----------------------------------------------------------------------===//
212231
// ClauseProcessor unique clauses
213232
//===----------------------------------------------------------------------===//
@@ -217,8 +236,7 @@ bool ClauseProcessor::processCollapse(
217236
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
218237
llvm::SmallVectorImpl<mlir::Value> &upperBound,
219238
llvm::SmallVectorImpl<mlir::Value> &step,
220-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
221-
std::size_t &loopVarTypeSize) const {
239+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const {
222240
bool found = false;
223241
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
224242

@@ -236,7 +254,7 @@ bool ClauseProcessor::processCollapse(
236254
found = true;
237255
}
238256

239-
loopVarTypeSize = 0;
257+
std::size_t loopVarTypeSize = 0;
240258
do {
241259
Fortran::lower::pft::Evaluation *doLoop =
242260
&doConstructEval->getFirstNestedEvaluation();
@@ -267,6 +285,9 @@ bool ClauseProcessor::processCollapse(
267285
&*std::next(doConstructEval->getNestedEvaluations().begin());
268286
} while (collapseValue > 0);
269287

288+
convertLoopBounds(converter, currentLocation, lowerBound, upperBound, step,
289+
loopVarTypeSize);
290+
270291
return found;
271292
}
272293

@@ -902,17 +923,39 @@ bool ClauseProcessor::processMap(
902923

903924
bool ClauseProcessor::processReduction(
904925
mlir::Location currentLocation,
905-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
906-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
907-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
908-
const {
926+
llvm::SmallVectorImpl<mlir::Value> &outReductionVars,
927+
llvm::SmallVectorImpl<mlir::Type> &outReductionTypes,
928+
llvm::SmallVectorImpl<mlir::Attribute> &outReductionDeclSymbols,
929+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
930+
*outReductionSymbols) const {
909931
return findRepeatableClause<omp::clause::Reduction>(
910932
[&](const omp::clause::Reduction &clause,
911933
const Fortran::parser::CharBlock &) {
934+
// Use local lists of reductions to prevent variables from other
935+
// already-processed reduction clauses from impacting this reduction.
936+
// For example, the whole `reductionVars` array is queried to decide
937+
// whether to do the reduction byref.
938+
llvm::SmallVector<mlir::Value> reductionVars;
939+
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
940+
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
912941
ReductionProcessor rp;
913942
rp.addDeclareReduction(currentLocation, converter, clause,
914943
reductionVars, reductionDeclSymbols,
915-
reductionSymbols);
944+
outReductionSymbols ? &reductionSymbols
945+
: nullptr);
946+
947+
// Copy local lists into the output.
948+
llvm::copy(reductionVars, std::back_inserter(outReductionVars));
949+
llvm::copy(reductionDeclSymbols,
950+
std::back_inserter(outReductionDeclSymbols));
951+
if (outReductionSymbols)
952+
llvm::copy(reductionSymbols,
953+
std::back_inserter(*outReductionSymbols));
954+
955+
outReductionTypes.reserve(outReductionTypes.size() +
956+
reductionVars.size());
957+
llvm::transform(reductionVars, std::back_inserter(outReductionTypes),
958+
[](mlir::Value v) { return v.getType(); });
916959
});
917960
}
918961

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,12 @@ class ClauseProcessor {
5656
clauses(makeList(clauses, semaCtx)) {}
5757

5858
// 'Unique' clauses: They can appear at most once in the clause list.
59-
bool
60-
processCollapse(mlir::Location currentLocation,
61-
Fortran::lower::pft::Evaluation &eval,
62-
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
63-
llvm::SmallVectorImpl<mlir::Value> &upperBound,
64-
llvm::SmallVectorImpl<mlir::Value> &step,
65-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
66-
std::size_t &loopVarTypeSize) const;
59+
bool processCollapse(
60+
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
61+
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
62+
llvm::SmallVectorImpl<mlir::Value> &upperBound,
63+
llvm::SmallVectorImpl<mlir::Value> &step,
64+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const;
6765
bool processDefault() const;
6866
bool processDevice(Fortran::lower::StatementContext &stmtCtx,
6967
mlir::Value &result) const;
@@ -126,6 +124,7 @@ class ClauseProcessor {
126124
bool
127125
processReduction(mlir::Location currentLocation,
128126
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
127+
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
129128
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
130129
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
131130
*reductionSymbols = nullptr) const;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 10 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -214,24 +214,6 @@ static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter,
214214
firOpBuilder.restoreInsertionPoint(insPt);
215215
}
216216

217-
static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
218-
std::size_t loopVarTypeSize) {
219-
// OpenMP runtime requires 32-bit or 64-bit loop variables.
220-
loopVarTypeSize = loopVarTypeSize * 8;
221-
if (loopVarTypeSize < 32) {
222-
loopVarTypeSize = 32;
223-
} else if (loopVarTypeSize > 64) {
224-
loopVarTypeSize = 64;
225-
mlir::emitWarning(converter.getCurrentLocation(),
226-
"OpenMP loop iteration variable cannot have more than 64 "
227-
"bits size and will be narrowed into 64 bits.");
228-
}
229-
assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
230-
"OpenMP loop iteration variable size must be transformed into 32-bit "
231-
"or 64-bit");
232-
return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
233-
}
234-
235217
static mlir::Operation *
236218
createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
237219
mlir::Location loc, mlir::Value indexVal,
@@ -568,6 +550,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
568550
mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
569551
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
570552
reductionVars;
553+
llvm::SmallVector<mlir::Type> reductionTypes;
571554
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
572555
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
573556

@@ -578,13 +561,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
578561
cp.processDefault();
579562
cp.processAllocate(allocatorOperands, allocateOperands);
580563
if (!outerCombined)
581-
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
582-
&reductionSymbols);
583-
584-
llvm::SmallVector<mlir::Type> reductionTypes;
585-
reductionTypes.reserve(reductionVars.size());
586-
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
587-
[](mlir::Value v) { return v.getType(); });
564+
cp.processReduction(currentLocation, reductionVars, reductionTypes,
565+
reductionDeclSymbols, &reductionSymbols);
588566

589567
auto reductionCallback = [&](mlir::Operation *op) {
590568
llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
@@ -1465,25 +1443,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
14651443
standaloneConstruct.u);
14661444
}
14671445

1468-
static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
1469-
mlir::Location loc,
1470-
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
1471-
llvm::SmallVectorImpl<mlir::Value> &upperBound,
1472-
llvm::SmallVectorImpl<mlir::Value> &step,
1473-
std::size_t loopVarTypeSize) {
1474-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1475-
// The types of lower bound, upper bound, and step are converted into the
1476-
// type of the loop variable if necessary.
1477-
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
1478-
for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
1479-
lowerBound[it] =
1480-
firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
1481-
upperBound[it] =
1482-
firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
1483-
step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
1484-
}
1485-
}
1486-
14871446
static llvm::SmallVector<const Fortran::semantics::Symbol *>
14881447
genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
14891448
mlir::Location &loc,
@@ -1517,7 +1476,7 @@ genLoopAndReductionVars(
15171476
mlir::Location &loc,
15181477
llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
15191478
llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
1520-
llvm::SmallVectorImpl<mlir::Type> &reductionTypes) {
1479+
llvm::ArrayRef<mlir::Type> reductionTypes) {
15211480
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
15221481

15231482
llvm::SmallVector<mlir::Type> blockArgTypes;
@@ -1579,16 +1538,15 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
15791538
llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
15801539
llvm::SmallVector<mlir::Value> alignedVars, nontemporalVars;
15811540
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
1541+
llvm::SmallVector<mlir::Type> reductionTypes;
15821542
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
15831543
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
15841544
mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand;
1585-
std::size_t loopVarTypeSize;
15861545

15871546
ClauseProcessor cp(converter, semaCtx, loopOpClauseList);
1588-
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv,
1589-
loopVarTypeSize);
1547+
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
15901548
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
1591-
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
1549+
cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols);
15921550
cp.processIf(clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
15931551
cp.processSimdlen(simdlenClauseOperand);
15941552
cp.processSafelen(safelenClauseOperand);
@@ -1598,9 +1556,6 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
15981556
Fortran::parser::OmpClause::Nontemporal,
15991557
Fortran::parser::OmpClause::Order>(loc, ompDirective);
16001558

1601-
convertLoopBounds(converter, loc, lowerBound, upperBound, step,
1602-
loopVarTypeSize);
1603-
16041559
mlir::TypeRange resultType;
16051560
auto simdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
16061561
loc, resultType, lowerBound, upperBound, step, alignedVars,
@@ -1638,27 +1593,23 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
16381593
llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
16391594
llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
16401595
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
1596+
llvm::SmallVector<mlir::Type> reductionTypes;
16411597
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
16421598
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
16431599
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
16441600
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
16451601
mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
16461602
mlir::IntegerAttr orderedClauseOperand;
16471603
mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
1648-
std::size_t loopVarTypeSize;
16491604

16501605
ClauseProcessor cp(converter, semaCtx, beginClauseList);
1651-
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv,
1652-
loopVarTypeSize);
1606+
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
16531607
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
1654-
cp.processReduction(loc, reductionVars, reductionDeclSymbols,
1608+
cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols,
16551609
&reductionSymbols);
16561610
cp.processTODO<Fortran::parser::OmpClause::Linear,
16571611
Fortran::parser::OmpClause::Order>(loc, ompDirective);
16581612

1659-
convertLoopBounds(converter, loc, lowerBound, upperBound, step,
1660-
loopVarTypeSize);
1661-
16621613
if (ReductionProcessor::doReductionByRef(reductionVars))
16631614
byrefOperand = firOpBuilder.getUnitAttr();
16641615

@@ -1699,11 +1650,6 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
16991650
auto *nestedEval = getCollapsedLoopEval(
17001651
eval, Fortran::lower::getCollapseValue(beginClauseList));
17011652

1702-
llvm::SmallVector<mlir::Type> reductionTypes;
1703-
reductionTypes.reserve(reductionVars.size());
1704-
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
1705-
[](mlir::Value v) { return v.getType(); });
1706-
17071653
auto ivCallback = [&](mlir::Operation *op) {
17081654
return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols,
17091655
reductionTypes);

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include <flang/Lower/AbstractConverter.h>
1717
#include <flang/Lower/ConvertType.h>
18+
#include <flang/Optimizer/Builder/FIRBuilder.h>
1819
#include <flang/Parser/parse-tree.h>
1920
#include <flang/Parser/tools.h>
2021
#include <flang/Semantics/tools.h>
@@ -70,6 +71,24 @@ void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
7071
}
7172
}
7273

74+
mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
75+
std::size_t loopVarTypeSize) {
76+
// OpenMP runtime requires 32-bit or 64-bit loop variables.
77+
loopVarTypeSize = loopVarTypeSize * 8;
78+
if (loopVarTypeSize < 32) {
79+
loopVarTypeSize = 32;
80+
} else if (loopVarTypeSize > 64) {
81+
loopVarTypeSize = 64;
82+
mlir::emitWarning(converter.getCurrentLocation(),
83+
"OpenMP loop iteration variable cannot have more than 64 "
84+
"bits size and will be narrowed into 64 bits.");
85+
}
86+
assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
87+
"OpenMP loop iteration variable size must be transformed into 32-bit "
88+
"or 64-bit");
89+
return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
90+
}
91+
7392
void gatherFuncAndVarSyms(
7493
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
7594
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {

flang/lib/Lower/OpenMP/Utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
5151
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
5252
bool isVal = false);
5353

54+
mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
55+
std::size_t loopVarTypeSize);
56+
5457
void gatherFuncAndVarSyms(
5558
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
5659
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);

0 commit comments

Comments
 (0)