Skip to content

Commit 9543e9e

Browse files
authored
[flang][OpenMP] Handle pre-determined lastprivate for simd (#129507)
This PR tries to fix `lastprivate` update issues in composite constructs. In particular, pre-determined `lastprivate` symbols are attached to the wrong leaf of the composite construct (the outermost one). When using delayed privatization (should be the default mode in the future), this results in trying to update the `lastprivate` symbol in the wrong construct (outside the `omp.loop_nest` op). For example, given the following input: ```fortran !$omp target teams distribute parallel do simd collapse(2) private(y_max) do i=x_min,x_max do j=y_min,y_max enddo enddo ``` Without the fixes introduced in this PR, the `DataSharingProcessor` tries to generate the `lastprivate` update ops in the `parallel` op since this is the op for which the DSP instance is created. The fix consists of 2 main parts: 1. Instead of creating a single DSP instance, one instance is created for each of the leaf constructs that might need privatization (whether for explicit, implicit, or pre-determined symbols). 2. When generating the `lastprivate` comparison ops, we don't directly use the SSA values of the UBs and steps. Instead, we regenerate these SSA values from the original loop bounds' expressions. We have to do this to avoid using `host_eval` values in the `lastprivate` comparison logic which is illegal.
1 parent e15545c commit 9543e9e

File tree

11 files changed

+309
-167
lines changed

11 files changed

+309
-167
lines changed

flang/lib/Lower/OpenMP/ClauseFinder.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
//===-- Lower/OpenMP/ClauseFinder.h --------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10+
//
11+
//===----------------------------------------------------------------------===//
12+
#ifndef FORTRAN_LOWER_CLAUSEFINDER_H
13+
#define FORTRAN_LOWER_CLAUSEFINDER_H
14+
15+
#include "Clauses.h"
16+
17+
namespace Fortran {
18+
namespace lower {
19+
namespace omp {
20+
21+
/// Collection of static helpers for locating clauses of a given type `T` in a
/// `List<Clause>`; the class declares no data members.
class ClauseFinder {
22+
using ClauseIterator = List<Clause>::const_iterator;
23+
24+
public:
25+
/// Return an iterator to the first clause in [begin, end) whose variant `u`
/// holds a `T`, or `end` if no such clause exists.
26+
template <typename T>
27+
static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end) {
28+
for (ClauseIterator it = begin; it != end; ++it) {
29+
if (std::get_if<T>(&it->u))
30+
return it;
31+
}
32+
33+
return end;
34+
}
35+
36+
/// Return a pointer to the first instance of the given clause found in the
37+
/// clause list, or `nullptr` if not present. When a clause is found and
/// `source` is non-null, `*source` is set to point at that clause's `source`.
38+
/// If more than one instance is expected, use `findRepeatableClause` instead.
39+
template <typename T>
40+
static const T *findUniqueClause(const List<Clause> &clauses,
41+
const parser::CharBlock **source = nullptr) {
42+
ClauseIterator it = findClause<T>(clauses.begin(), clauses.end());
43+
if (it != clauses.end()) {
44+
if (source)
45+
*source = &it->source;
46+
return &std::get<T>(it->u);
47+
}
48+
return nullptr;
49+
}
50+
51+
/// Call `callbackFn` for each occurrence of the given clause, in list order,
/// passing the clause and its `source`. Return `true`
52+
/// if at least one instance was found.
53+
template <typename T>
54+
static bool findRepeatableClause(
55+
const List<Clause> &clauses,
56+
std::function<void(const T &, const parser::CharBlock &source)>
57+
callbackFn) {
58+
bool found = false;
59+
ClauseIterator nextIt, endIt = clauses.end();
60+
for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
61+
// Search forward from the current position; `nextIt == endIt` ends the loop.
nextIt = findClause<T>(it, endIt);
62+
63+
if (nextIt != endIt) {
64+
callbackFn(std::get<T>(nextIt->u), nextIt->source);
65+
found = true;
66+
// Resume the search just past the match that was reported.
++nextIt;
67+
}
68+
}
69+
return found;
70+
}
71+
};
72+
} // namespace omp
73+
} // namespace lower
74+
} // namespace Fortran
75+
76+
#endif // FORTRAN_LOWER_CLAUSEFINDER_H

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 3 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "ClauseProcessor.h"
1414
#include "Clauses.h"
15+
#include "Utils.h"
1516

1617
#include "flang/Lower/PFTBuilder.h"
1718
#include "flang/Parser/tools.h"
@@ -201,24 +202,6 @@ static void addUseDeviceClause(
201202
useDeviceSyms.push_back(object.sym());
202203
}
203204

204-
static void convertLoopBounds(lower::AbstractConverter &converter,
205-
mlir::Location loc,
206-
mlir::omp::LoopRelatedClauseOps &result,
207-
std::size_t loopVarTypeSize) {
208-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
209-
// The types of lower bound, upper bound, and step are converted into the
210-
// type of the loop variable if necessary.
211-
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
212-
for (unsigned it = 0; it < (unsigned)result.loopLowerBounds.size(); it++) {
213-
result.loopLowerBounds[it] = firOpBuilder.createConvert(
214-
loc, loopVarType, result.loopLowerBounds[it]);
215-
result.loopUpperBounds[it] = firOpBuilder.createConvert(
216-
loc, loopVarType, result.loopUpperBounds[it]);
217-
result.loopSteps[it] =
218-
firOpBuilder.createConvert(loc, loopVarType, result.loopSteps[it]);
219-
}
220-
}
221-
222205
//===----------------------------------------------------------------------===//
223206
// ClauseProcessor unique clauses
224207
//===----------------------------------------------------------------------===//
@@ -240,55 +223,8 @@ bool ClauseProcessor::processCollapse(
240223
mlir::Location currentLocation, lower::pft::Evaluation &eval,
241224
mlir::omp::LoopRelatedClauseOps &result,
242225
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
243-
bool found = false;
244-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
245-
246-
// Collect the loops to collapse.
247-
lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
248-
if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
249-
TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
250-
}
251-
252-
std::int64_t collapseValue = 1l;
253-
if (auto *clause = findUniqueClause<omp::clause::Collapse>()) {
254-
collapseValue = evaluate::ToInt64(clause->v).value();
255-
found = true;
256-
}
257-
258-
std::size_t loopVarTypeSize = 0;
259-
do {
260-
lower::pft::Evaluation *doLoop =
261-
&doConstructEval->getFirstNestedEvaluation();
262-
auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
263-
assert(doStmt && "Expected do loop to be in the nested evaluation");
264-
const auto &loopControl =
265-
std::get<std::optional<parser::LoopControl>>(doStmt->t);
266-
const parser::LoopControl::Bounds *bounds =
267-
std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
268-
assert(bounds && "Expected bounds for worksharing do loop");
269-
lower::StatementContext stmtCtx;
270-
result.loopLowerBounds.push_back(fir::getBase(
271-
converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)));
272-
result.loopUpperBounds.push_back(fir::getBase(
273-
converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)));
274-
if (bounds->step) {
275-
result.loopSteps.push_back(fir::getBase(
276-
converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)));
277-
} else { // If `step` is not present, assume it as `1`.
278-
result.loopSteps.push_back(firOpBuilder.createIntegerConstant(
279-
currentLocation, firOpBuilder.getIntegerType(32), 1));
280-
}
281-
iv.push_back(bounds->name.thing.symbol);
282-
loopVarTypeSize = std::max(loopVarTypeSize,
283-
bounds->name.thing.symbol->GetUltimate().size());
284-
collapseValue--;
285-
doConstructEval =
286-
&*std::next(doConstructEval->getNestedEvaluations().begin());
287-
} while (collapseValue > 0);
288-
289-
convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
290-
291-
return found;
226+
return collectLoopRelatedInfo(converter, currentLocation, eval, clauses,
227+
result, iv);
292228
}
293229

294230
bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#ifndef FORTRAN_LOWER_CLAUSEPROCESSOR_H
1313
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
1414

15+
#include "ClauseFinder.h"
1516
#include "Clauses.h"
1617
#include "ReductionProcessor.h"
1718
#include "Utils.h"
@@ -148,10 +149,6 @@ class ClauseProcessor {
148149
private:
149150
using ClauseIterator = List<Clause>::const_iterator;
150151

151-
/// Utility to find a clause within a range in the clause list.
152-
template <typename T>
153-
static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
154-
155152
/// Return the first instance of the given clause found in the clause list or
156153
/// `nullptr` if not present. If more than one instance is expected, use
157154
/// `findRepeatableClause` instead.
@@ -199,45 +196,17 @@ void ClauseProcessor::processTODO(mlir::Location currentLocation,
199196
(checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...);
200197
}
201198

202-
template <typename T>
203-
ClauseProcessor::ClauseIterator
204-
ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
205-
for (ClauseIterator it = begin; it != end; ++it) {
206-
if (std::get_if<T>(&it->u))
207-
return it;
208-
}
209-
210-
return end;
211-
}
212-
213199
template <typename T>
214200
const T *
215201
ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const {
216-
ClauseIterator it = findClause<T>(clauses.begin(), clauses.end());
217-
if (it != clauses.end()) {
218-
if (source)
219-
*source = &it->source;
220-
return &std::get<T>(it->u);
221-
}
222-
return nullptr;
202+
return ClauseFinder::findUniqueClause<T>(clauses, source);
223203
}
224204

225205
template <typename T>
226206
bool ClauseProcessor::findRepeatableClause(
227207
std::function<void(const T &, const parser::CharBlock &source)> callbackFn)
228208
const {
229-
bool found = false;
230-
ClauseIterator nextIt, endIt = clauses.end();
231-
for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
232-
nextIt = findClause<T>(it, endIt);
233-
234-
if (nextIt != endIt) {
235-
callbackFn(std::get<T>(nextIt->u), nextIt->source);
236-
found = true;
237-
++nextIt;
238-
}
239-
}
240-
return found;
209+
return ClauseFinder::findRepeatableClause<T>(clauses, callbackFn);
241210
}
242211

243212
template <typename T>

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,11 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
257257
return;
258258

259259
if (mlir::isa<mlir::omp::WsloopOp>(op) || mlir::isa<mlir::omp::SimdOp>(op)) {
260+
mlir::omp::LoopRelatedClauseOps result;
261+
llvm::SmallVector<const semantics::Symbol *> iv;
262+
collectLoopRelatedInfo(converter, converter.getCurrentLocation(), eval,
263+
clauses, result, iv);
264+
260265
// Update the original variable just before exiting the worksharing
261266
// loop. Conversion as follows:
262267
//
@@ -280,9 +285,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
280285
mlir::Value cmpOp;
281286
llvm::SmallVector<mlir::Value> vs;
282287
vs.reserve(loopOp.getIVs().size());
283-
for (auto [iv, ub, step] :
284-
llvm::zip_equal(loopOp.getIVs(), loopOp.getLoopUpperBounds(),
285-
loopOp.getLoopSteps())) {
288+
for (auto [iv, ub, step] : llvm::zip_equal(
289+
loopOp.getIVs(), result.loopUpperBounds, result.loopSteps)) {
286290
// v = iv + step
287291
// cmp = step < 0 ? v < ub : v > ub
288292
mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,27 +1208,27 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
12081208
if (privatize) {
12091209
// DataSharingProcessor::processStep2() may create operations before/after
12101210
// the one passed as argument. We need to treat loop wrappers and their
1211-
// nested loop as a unit, so we need to pass the top level wrapper (if
1211+
// nested loop as a unit, so we need to pass the bottom level wrapper (if
12121212
// present). Otherwise, these operations will be inserted within a
12131213
// wrapper region.
1214-
mlir::Operation *privatizationTopLevelOp = &op;
1214+
mlir::Operation *privatizationBottomLevelOp = &op;
12151215
if (auto loopNest = llvm::dyn_cast<mlir::omp::LoopNestOp>(op)) {
12161216
llvm::SmallVector<mlir::omp::LoopWrapperInterface> wrappers;
12171217
loopNest.gatherWrappers(wrappers);
12181218
if (!wrappers.empty())
1219-
privatizationTopLevelOp = &*wrappers.back();
1219+
privatizationBottomLevelOp = &*wrappers.front();
12201220
}
12211221

12221222
if (!info.dsp) {
12231223
assert(tempDsp.has_value());
1224-
tempDsp->processStep2(privatizationTopLevelOp, isLoop);
1224+
tempDsp->processStep2(privatizationBottomLevelOp, isLoop);
12251225
} else {
12261226
if (isLoop && regionArgs.size() > 0) {
12271227
for (const auto &regionArg : regionArgs) {
12281228
info.dsp->pushLoopIV(info.converter.getSymbolAddress(*regionArg));
12291229
}
12301230
}
1231-
info.dsp->processStep2(privatizationTopLevelOp, isLoop);
1231+
info.dsp->processStep2(privatizationBottomLevelOp, isLoop);
12321232
}
12331233
}
12341234
}
@@ -2741,18 +2741,20 @@ static void genCompositeDistributeParallelDoSimd(
27412741
genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
27422742
parallelClauseOps, parallelReductionSyms);
27432743

2744-
DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval,
2745-
/*shouldCollectPreDeterminedSymbols=*/true,
2746-
/*useDelayedPrivatization=*/true, symTable);
2747-
dsp.processStep1(&parallelClauseOps);
2744+
DataSharingProcessor parallelItemDSP(
2745+
converter, semaCtx, parallelItem->clauses, eval,
2746+
/*shouldCollectPreDeterminedSymbols=*/false,
2747+
/*useDelayedPrivatization=*/true, symTable);
2748+
parallelItemDSP.processStep1(&parallelClauseOps);
27482749

27492750
EntryBlockArgs parallelArgs;
2750-
parallelArgs.priv.syms = dsp.getDelayedPrivSymbols();
2751+
parallelArgs.priv.syms = parallelItemDSP.getDelayedPrivSymbols();
27512752
parallelArgs.priv.vars = parallelClauseOps.privateVars;
27522753
parallelArgs.reduction.syms = parallelReductionSyms;
27532754
parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
27542755
genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
2755-
parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true);
2756+
parallelClauseOps, parallelArgs, &parallelItemDSP,
2757+
/*isComposite=*/true);
27562758

27572759
// Clause processing.
27582760
mlir::omp::DistributeOperands distributeClauseOps;
@@ -2769,6 +2771,11 @@ static void genCompositeDistributeParallelDoSimd(
27692771
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
27702772
simdReductionSyms);
27712773

2774+
DataSharingProcessor simdItemDSP(converter, semaCtx, simdItem->clauses, eval,
2775+
/*shouldCollectPreDeterminedSymbols=*/true,
2776+
/*useDelayedPrivatization=*/true, symTable);
2777+
simdItemDSP.processStep1(&simdClauseOps);
2778+
27722779
mlir::omp::LoopNestOperands loopNestClauseOps;
27732780
llvm::SmallVector<const semantics::Symbol *> iv;
27742781
genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
@@ -2790,7 +2797,8 @@ static void genCompositeDistributeParallelDoSimd(
27902797
wsloopOp.setComposite(/*val=*/true);
27912798

27922799
EntryBlockArgs simdArgs;
2793-
// TODO: Add private syms and vars.
2800+
simdArgs.priv.syms = simdItemDSP.getDelayedPrivSymbols();
2801+
simdArgs.priv.vars = simdClauseOps.privateVars;
27942802
simdArgs.reduction.syms = simdReductionSyms;
27952803
simdArgs.reduction.vars = simdClauseOps.reductionVars;
27962804
auto simdOp =
@@ -2802,7 +2810,8 @@ static void genCompositeDistributeParallelDoSimd(
28022810
{{distributeOp, distributeArgs},
28032811
{wsloopOp, wsloopArgs},
28042812
{simdOp, simdArgs}},
2805-
llvm::omp::Directive::OMPD_distribute_parallel_do_simd, dsp);
2813+
llvm::omp::Directive::OMPD_distribute_parallel_do_simd,
2814+
simdItemDSP);
28062815
}
28072816

28082817
static void genCompositeDistributeSimd(lower::AbstractConverter &converter,

0 commit comments

Comments
 (0)