Skip to content

Commit ca4dbc2

Browse files
authored
[Flang][OpenMP][Lower] Update workshare-loop lowering (5/5) (#89215)
This patch updates lowering from PFT to MLIR of workshare loops to follow the loop wrapper approach. Unit tests impacted by this change are also updated. As the last patch of the stack, this should compile and pass unit tests.
1 parent 2e37f28 commit ca4dbc2

File tree

97 files changed

+3543
-2943
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+3543
-2943
lines changed

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,12 @@ void DataSharingProcessor::insertBarrier() {
135135
}
136136

137137
void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
138+
mlir::omp::LoopNestOp loopOp;
139+
if (auto wrapper = mlir::dyn_cast<mlir::omp::LoopWrapperInterface>(op))
140+
loopOp = wrapper.isWrapper()
141+
? mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop())
142+
: nullptr;
143+
138144
bool cmpCreated = false;
139145
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
140146
for (const omp::Clause &clause : clauses) {
@@ -214,33 +220,34 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
214220
// Update the original variable just before exiting the worksharing
215221
// loop. Conversion as follows:
216222
//
217-
// omp.wsloop {
218-
// omp.wsloop { ...
219-
// ... store
220-
// store ===> %v = arith.addi %iv, %step
221-
// omp.yield %cmp = %step < 0 ? %v < %ub : %v > %ub
222-
// } fir.if %cmp {
223-
// fir.store %v to %loopIV
224-
// ^%lpv_update_blk:
225-
// }
226-
// omp.yield
227-
// }
228-
//
223+
// omp.wsloop { omp.wsloop {
224+
// omp.loop_nest { omp.loop_nest {
225+
// ... ...
226+
// store ===> store
227+
// omp.yield %v = arith.addi %iv, %step
228+
// } %cmp = %step < 0 ? %v < %ub : %v > %ub
229+
// omp.terminator fir.if %cmp {
230+
// } fir.store %v to %loopIV
231+
// ^%lpv_update_blk:
232+
// }
233+
// omp.yield
234+
// }
235+
// omp.terminator
236+
// }
229237

230238
// Only generate the compare once in presence of multiple LastPrivate
231239
// clauses.
232240
if (cmpCreated)
233241
continue;
234242
cmpCreated = true;
235243

236-
mlir::Location loc = op->getLoc();
237-
mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
244+
mlir::Location loc = loopOp.getLoc();
245+
mlir::Operation *lastOper = loopOp.getRegion().back().getTerminator();
238246
firOpBuilder.setInsertionPoint(lastOper);
239247

240-
mlir::Value iv = op->getRegion(0).front().getArguments()[0];
241-
mlir::Value ub =
242-
mlir::dyn_cast<mlir::omp::WsloopOp>(op).getUpperBound()[0];
243-
mlir::Value step = mlir::dyn_cast<mlir::omp::WsloopOp>(op).getStep()[0];
248+
mlir::Value iv = loopOp.getIVs()[0];
249+
mlir::Value ub = loopOp.getUpperBound()[0];
250+
mlir::Value step = loopOp.getStep()[0];
244251

245252
// v = iv + step
246253
// cmp = step < 0 ? v < ub : v > ub
@@ -259,7 +266,7 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
259266
auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
260267
firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
261268
assert(loopIV && "loopIV was not set");
262-
firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
269+
firOpBuilder.create<fir::StoreOp>(loopOp.getLoc(), v, loopIV);
263270
lastPrivIP = firOpBuilder.saveInsertionPoint();
264271
} else {
265272
TODO(converter.getCurrentLocation(),

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 61 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,29 @@ getDeclareTargetFunctionDevice(
366366
return std::nullopt;
367367
}
368368

369-
static llvm::SmallVector<const Fortran::semantics::Symbol *>
369+
/// Set up the entry block of the given `omp.loop_nest` operation, adding a
370+
/// block argument for each loop induction variable and allocating and
371+
/// initializing a private value to hold each of them.
372+
///
373+
/// This function can also bind the symbols of any variables that should match
374+
/// block arguments on parent loop wrapper operations attached to the same
375+
/// loop. This allows the introduction of any necessary `hlfir.declare`
376+
/// operations inside of the entry block of the `omp.loop_nest` operation and
377+
/// not directly under any of the wrappers, which would invalidate them.
378+
///
379+
/// \param [in] op - the loop nest operation.
380+
/// \param [in] converter - PFT to MLIR conversion interface.
381+
/// \param [in] loc - location.
382+
/// \param [in] args - symbols of induction variables.
383+
/// \param [in] wrapperSyms - symbols of variables to be mapped to loop wrapper
384+
/// entry block arguments.
385+
/// \param [in] wrapperArgs - entry block arguments of parent loop wrappers.
386+
static void
370387
genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
371388
mlir::Location &loc,
372-
llvm::ArrayRef<const Fortran::semantics::Symbol *> args) {
389+
llvm::ArrayRef<const Fortran::semantics::Symbol *> args,
390+
llvm::ArrayRef<const Fortran::semantics::Symbol *> wrapperSyms = {},
391+
llvm::ArrayRef<mlir::BlockArgument> wrapperArgs = {}) {
373392
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
374393
auto &region = op->getRegion(0);
375394

@@ -380,6 +399,12 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
380399
llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
381400
llvm::SmallVector<mlir::Location> locs(args.size(), loc);
382401
firOpBuilder.createBlock(&region, {}, tiv, locs);
402+
403+
// Bind the entry block arguments of parent wrappers to the corresponding
404+
// symbols.
405+
for (auto [arg, prv] : llvm::zip_equal(wrapperSyms, wrapperArgs))
406+
converter.bindSymbol(*arg, prv);
407+
383408
// The argument is not currently in memory, so make a temporary for the
384409
// argument, and store it there, then bind that location to the argument.
385410
mlir::Operation *storeOp = nullptr;
@@ -389,7 +414,6 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
389414
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
390415
}
391416
firOpBuilder.setInsertionPointAfter(storeOp);
392-
return llvm::SmallVector<const Fortran::semantics::Symbol *>(args);
393417
}
394418

395419
static void genReductionVars(
@@ -410,58 +434,6 @@ static void genReductionVars(
410434
}
411435
}
412436

413-
static llvm::SmallVector<const Fortran::semantics::Symbol *>
414-
genLoopAndReductionVars(
415-
mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
416-
mlir::Location &loc,
417-
llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
418-
llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
419-
llvm::ArrayRef<mlir::Type> reductionTypes) {
420-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
421-
422-
llvm::SmallVector<mlir::Type> blockArgTypes;
423-
llvm::SmallVector<mlir::Location> blockArgLocs;
424-
blockArgTypes.reserve(loopArgs.size() + reductionArgs.size());
425-
blockArgLocs.reserve(blockArgTypes.size());
426-
mlir::Block *entryBlock;
427-
428-
if (loopArgs.size()) {
429-
std::size_t loopVarTypeSize = 0;
430-
for (const Fortran::semantics::Symbol *arg : loopArgs)
431-
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
432-
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
433-
std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(),
434-
loopVarType);
435-
std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc);
436-
}
437-
if (reductionArgs.size()) {
438-
llvm::copy(reductionTypes, std::back_inserter(blockArgTypes));
439-
std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc);
440-
}
441-
entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes,
442-
blockArgLocs);
443-
// The argument is not currently in memory, so make a temporary for the
444-
// argument, and store it there, then bind that location to the argument.
445-
if (loopArgs.size()) {
446-
mlir::Operation *storeOp = nullptr;
447-
for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
448-
mlir::Value indexVal =
449-
fir::getBase(op->getRegion(0).front().getArgument(argIndex));
450-
storeOp =
451-
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
452-
}
453-
firOpBuilder.setInsertionPointAfter(storeOp);
454-
}
455-
// Bind the reduction arguments to their block arguments
456-
for (auto [arg, prv] : llvm::zip_equal(
457-
reductionArgs,
458-
llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) {
459-
converter.bindSymbol(*arg, prv);
460-
}
461-
462-
return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs);
463-
}
464-
465437
static void
466438
markDeclareTarget(mlir::Operation *op,
467439
Fortran::lower::AbstractConverter &converter,
@@ -1270,20 +1242,16 @@ static void genTeamsClauses(Fortran::lower::AbstractConverter &converter,
12701242
static void genWsloopClauses(
12711243
Fortran::lower::AbstractConverter &converter,
12721244
Fortran::semantics::SemanticsContext &semaCtx,
1273-
Fortran::lower::StatementContext &stmtCtx,
1274-
Fortran::lower::pft::Evaluation &eval, const List<Clause> &clauses,
1245+
Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
12751246
mlir::Location loc, mlir::omp::WsloopClauseOps &clauseOps,
1276-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
12771247
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
12781248
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
12791249
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
12801250
ClauseProcessor cp(converter, semaCtx, clauses);
1281-
cp.processCollapse(loc, eval, clauseOps, iv);
12821251
cp.processNowait(clauseOps);
12831252
cp.processOrdered(clauseOps);
12841253
cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
12851254
cp.processSchedule(stmtCtx, clauseOps);
1286-
clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
12871255
// TODO Support delayed privatization.
12881256

12891257
if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
@@ -1526,7 +1494,8 @@ genSimdOp(Fortran::lower::AbstractConverter &converter,
15261494
auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(clauses));
15271495

15281496
auto ivCallback = [&](mlir::Operation *op) {
1529-
return genLoopVars(op, converter, loc, iv);
1497+
genLoopVars(op, converter, loc, iv);
1498+
return iv;
15301499
};
15311500

15321501
createBodyOfOp(*loopOp,
@@ -1801,32 +1770,48 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
18011770
Fortran::semantics::SemanticsContext &semaCtx,
18021771
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
18031772
const List<Clause> &clauses) {
1773+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
18041774
DataSharingProcessor dsp(converter, semaCtx, clauses, eval);
18051775
dsp.processStep1();
18061776

18071777
Fortran::lower::StatementContext stmtCtx;
1808-
mlir::omp::WsloopClauseOps clauseOps;
1778+
mlir::omp::LoopNestClauseOps loopClauseOps;
1779+
mlir::omp::WsloopClauseOps wsClauseOps;
18091780
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
18101781
llvm::SmallVector<mlir::Type> reductionTypes;
18111782
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
1812-
genWsloopClauses(converter, semaCtx, stmtCtx, eval, clauses, loc, clauseOps,
1813-
iv, reductionTypes, reductionSyms);
1783+
genLoopNestClauses(converter, semaCtx, eval, clauses, loc, loopClauseOps, iv);
1784+
genWsloopClauses(converter, semaCtx, stmtCtx, clauses, loc, wsClauseOps,
1785+
reductionTypes, reductionSyms);
1786+
1787+
// Create omp.wsloop wrapper and populate entry block arguments with reduction
1788+
// variables.
1789+
auto wsloopOp = firOpBuilder.create<mlir::omp::WsloopOp>(loc, wsClauseOps);
1790+
llvm::SmallVector<mlir::Location> reductionLocs(reductionSyms.size(), loc);
1791+
mlir::Block *wsloopEntryBlock = firOpBuilder.createBlock(
1792+
&wsloopOp.getRegion(), {}, reductionTypes, reductionLocs);
1793+
firOpBuilder.setInsertionPoint(
1794+
Fortran::lower::genOpenMPTerminator(firOpBuilder, wsloopOp, loc));
1795+
1796+
// Create nested omp.loop_nest and fill body with loop contents.
1797+
auto loopOp = firOpBuilder.create<mlir::omp::LoopNestOp>(loc, loopClauseOps);
18141798

18151799
auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(clauses));
18161800

18171801
auto ivCallback = [&](mlir::Operation *op) {
1818-
return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms,
1819-
reductionTypes);
1802+
genLoopVars(op, converter, loc, iv, reductionSyms,
1803+
wsloopEntryBlock->getArguments());
1804+
return iv;
18201805
};
18211806

1822-
return genOpWithBody<mlir::omp::WsloopOp>(
1823-
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval,
1824-
llvm::omp::Directive::OMPD_do)
1825-
.setClauses(&clauses)
1826-
.setDataSharingProcessor(&dsp)
1827-
.setReductions(&reductionSyms, &reductionTypes)
1828-
.setGenRegionEntryCb(ivCallback),
1829-
clauseOps);
1807+
createBodyOfOp(*loopOp,
1808+
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval,
1809+
llvm::omp::Directive::OMPD_do)
1810+
.setClauses(&clauses)
1811+
.setDataSharingProcessor(&dsp)
1812+
.setReductions(&reductionSyms, &reductionTypes)
1813+
.setGenRegionEntryCb(ivCallback));
1814+
return wsloopOp;
18301815
}
18311816

18321817
//===----------------------------------------------------------------------===//
@@ -2482,8 +2467,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
24822467
mlir::Operation *Fortran::lower::genOpenMPTerminator(fir::FirOpBuilder &builder,
24832468
mlir::Operation *op,
24842469
mlir::Location loc) {
2485-
if (mlir::isa<mlir::omp::WsloopOp, mlir::omp::DeclareReductionOp,
2486-
mlir::omp::AtomicUpdateOp, mlir::omp::LoopNestOp>(op))
2470+
if (mlir::isa<mlir::omp::AtomicUpdateOp, mlir::omp::DeclareReductionOp,
2471+
mlir::omp::LoopNestOp>(op))
24872472
return builder.create<mlir::omp::YieldOp>(loc);
24882473
return builder.create<mlir::omp::TerminatorOp>(loc);
24892474
}

0 commit comments

Comments
 (0)