11
11
#include " flang/Optimizer/OpenMP/Utils.h"
12
12
#include " mlir/Analysis/SliceAnalysis.h"
13
13
#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
14
+ #include " mlir/IR/IRMapping.h"
14
15
#include " mlir/Transforms/DialectConversion.h"
15
16
#include " mlir/Transforms/RegionUtils.h"
16
17
@@ -24,8 +25,126 @@ namespace flangomp {
24
25
25
26
namespace {
26
27
namespace looputils {
27
- using LoopNest = llvm::SetVector<fir::DoLoopOp>;
28
+ // / Stores info needed about the induction/iteration variable for each `do
29
+ // / concurrent` in a loop nest. This includes:
30
+ // / * the operation allocating memory for iteration variable,
31
+ // / * the operation(s) updating the iteration variable with the current
32
+ // / iteration number.
33
+ struct InductionVariableInfo {
34
+ mlir::Operation *iterVarMemDef;
35
+ llvm::SetVector<mlir::Operation *> indVarUpdateOps;
36
+ };
37
+
38
+ using LoopNestToIndVarMap =
39
+ llvm::MapVector<fir::DoLoopOp, InductionVariableInfo>;
40
+
41
+ // / Given an operation `op`, this returns true if one of `op`'s operands is
42
+ // / "ultimately" the loop's induction variable. This helps in cases where the
43
+ // / induction variable's use is "hidden" behind a convert/cast.
44
+ // /
45
+ // / For example, give the following loop:
46
+ // / ```
47
+ // / fir.do_loop %ind_var = %lb to %ub step %s unordered {
48
+ // / %ind_var_conv = fir.convert %ind_var : (index) -> i32
49
+ // / fir.store %ind_var_conv to %i#1 : !fir.ref<i32>
50
+ // / ...
51
+ // / }
52
+ // / ```
53
+ // /
54
+ // / If \p op is the `fir.store` operation, then this function will return true
55
+ // / since the IV is the "ultimate" opeerand to the `fir.store` op through the
56
+ // / `%ind_var_conv` -> `%ind_var` conversion sequence.
57
+ // /
58
+ // / For why this is useful, see its use in `findLoopIndVarMemDecl`.
59
+ bool isIndVarUltimateOperand (mlir::Operation *op, fir::DoLoopOp doLoop) {
60
+ while (op != nullptr && op->getNumOperands () > 0 ) {
61
+ auto ivIt = llvm::find_if (op->getOperands (), [&](mlir::Value operand) {
62
+ return operand == doLoop.getInductionVar ();
63
+ });
64
+
65
+ if (ivIt != op->getOperands ().end ())
66
+ return true ;
67
+
68
+ op = op->getOperand (0 ).getDefiningOp ();
69
+ }
70
+
71
+ return false ;
72
+ }
73
+
74
+ // / For the \p doLoop parameter, find the operation that declares its iteration
75
+ // / variable or allocates memory for it.
76
+ // /
77
+ // / For example, give the following loop:
78
+ // / ```
79
+ // / ...
80
+ // / %i:2 = hlfir.declare %0 {uniq_name = "_QFEi"} : ...
81
+ // / ...
82
+ // / fir.do_loop %ind_var = %lb to %ub step %s unordered {
83
+ // / %ind_var_conv = fir.convert %ind_var : (index) -> i32
84
+ // / fir.store %ind_var_conv to %i#1 : !fir.ref<i32>
85
+ // / ...
86
+ // / }
87
+ // / ```
88
+ // /
89
+ // / This function returns the `hlfir.declare` op for `%i`.
90
+ mlir::Operation *findLoopIterationVarMemDecl (fir::DoLoopOp doLoop) {
91
+ mlir::Value result = nullptr ;
92
+ mlir::visitUsedValuesDefinedAbove (
93
+ doLoop.getRegion (), [&](mlir::OpOperand *operand) {
94
+ if (result)
95
+ return ;
96
+
97
+ if (isIndVarUltimateOperand (operand->getOwner (), doLoop)) {
98
+ assert (result == nullptr &&
99
+ " loop can have only one induction variable" );
100
+ result = operand->get ();
101
+ }
102
+ });
103
+
104
+ assert (result != nullptr && result.getDefiningOp () != nullptr );
105
+ return result.getDefiningOp ();
106
+ }
28
107
108
+ // / Collects the op(s) responsible for updating a loop's iteration variable with
109
+ // / the current iteration number. For example, for the input IR:
110
+ // / ```
111
+ // / %i = fir.alloca i32 {bindc_name = "i"}
112
+ // / %i_decl:2 = hlfir.declare %i ...
113
+ // / ...
114
+ // / fir.do_loop %i_iv = %lb to %ub step %step unordered {
115
+ // / %1 = fir.convert %i_iv : (index) -> i32
116
+ // / fir.store %1 to %i_decl#1 : !fir.ref<i32>
117
+ // / ...
118
+ // / }
119
+ // / ```
120
+ // / this function would return the first 2 ops in the `fir.do_loop`'s region.
121
+ llvm::SetVector<mlir::Operation *>
122
+ extractIndVarUpdateOps (fir::DoLoopOp doLoop) {
123
+ mlir::Value indVar = doLoop.getInductionVar ();
124
+ llvm::SetVector<mlir::Operation *> indVarUpdateOps;
125
+
126
+ llvm::SmallVector<mlir::Value> toProcess;
127
+ toProcess.push_back (indVar);
128
+
129
+ llvm::DenseSet<mlir::Value> done;
130
+
131
+ while (!toProcess.empty ()) {
132
+ mlir::Value val = toProcess.back ();
133
+ toProcess.pop_back ();
134
+
135
+ if (!done.insert (val).second )
136
+ continue ;
137
+
138
+ for (mlir::Operation *user : val.getUsers ()) {
139
+ indVarUpdateOps.insert (user);
140
+
141
+ for (mlir::Value result : user->getResults ())
142
+ toProcess.push_back (result);
143
+ }
144
+ }
145
+
146
+ return std::move (indVarUpdateOps);
147
+ }
29
148
// / Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
30
149
// / there are no operations in \p outerloop's body other than:
31
150
// /
@@ -93,11 +212,16 @@ bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
93
212
// / recognize a certain nested loop as part of the nest it just returns the
94
213
// / parent loops it discovered before.
95
214
mlir::LogicalResult collectLoopNest (fir::DoLoopOp currentLoop,
96
- LoopNest &loopNest) {
215
+ LoopNestToIndVarMap &loopNest) {
97
216
assert (currentLoop.getUnordered ());
98
217
99
218
while (true ) {
100
- loopNest.insert (currentLoop);
219
+ loopNest.try_emplace (
220
+ currentLoop,
221
+ InductionVariableInfo{
222
+ findLoopIterationVarMemDecl (currentLoop),
223
+ std::move (looputils::extractIndVarUpdateOps (currentLoop))});
224
+
101
225
auto directlyNestedLoops = currentLoop.getRegion ().getOps <fir::DoLoopOp>();
102
226
llvm::SmallVector<fir::DoLoopOp> unorderedLoops;
103
227
@@ -127,26 +251,136 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
127
251
public:
128
252
using mlir::OpConversionPattern<fir::DoLoopOp>::OpConversionPattern;
129
253
130
- DoConcurrentConversion (mlir::MLIRContext *context, bool mapToDevice)
131
- : OpConversionPattern(context), mapToDevice(mapToDevice) {}
254
+ DoConcurrentConversion (mlir::MLIRContext *context, bool mapToDevice,
255
+ llvm::DenseSet<fir::DoLoopOp> &concurrentLoopsToSkip)
256
+ : OpConversionPattern(context), mapToDevice(mapToDevice),
257
+ concurrentLoopsToSkip (concurrentLoopsToSkip) {}
132
258
133
259
// Rewrites the `fir.do_loop` nest rooted at `doLoop` into an
// `omp.parallel` / `omp.wsloop` / `omp.loop_nest` structure and erases the
// original loop. The `adaptor` parameter is not used. Always returns success.
mlir::LogicalResult
134
260
matchAndRewrite (fir::DoLoopOp doLoop, OpAdaptor adaptor,
135
261
mlir::ConversionPatternRewriter &rewriter) const override {
136
- looputils::LoopNest loopNest;
262
+ looputils::LoopNestToIndVarMap loopNest;
137
263
bool hasRemainingNestedLoops =
138
264
failed (looputils::collectLoopNest (doLoop, loopNest));
139
265
// A failed collection is not fatal: only a warning is emitted and the
// conversion continues with the loops gathered in `loopNest`.
if (hasRemainingNestedLoops)
140
266
mlir::emitWarning (doLoop.getLoc (),
141
267
" Some `do concurent` loops are not perfectly-nested. "
142
268
" These will be serialzied." );
143
269
144
- // TODO This will be filled in with the next PRs that upstreams the rest of
145
- // the ROCm implementaion.
270
// `mapper` is shared by all the generation steps below, so values cloned
// while creating the parallel region are visible when the loop body is
// cloned.
+ mlir::IRMapping mapper;
271
// The returned `omp.parallel` op is not needed here; the call is made for
// its effect on the rewriter's insertion point and on `mapper`.
+ genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
272
+ mlir::omp::LoopNestOperands loopNestClauseOps;
273
+ genLoopNestClauseOps (doLoop.getLoc (), rewriter, loopNest, mapper,
274
+ loopNestClauseOps);
275
+
276
// The body is taken from the last loop recorded in `loopNest`.
+ mlir::omp::LoopNestOp ompLoopNest =
277
+ genWsLoopOp (rewriter, loopNest.back ().first , mapper, loopNestClauseOps,
278
+ /* isComposite=*/ mapToDevice);
279
+
280
+ rewriter.eraseOp (doLoop);
281
+
282
+ // Mark `unordered` loops that are not perfectly nested to be skipped from
283
+ // the legality check of the `ConversionTarget` since we are not interested
284
+ // in mapping them to OpenMP.
285
// NOTE: the lambda parameter below shadows the (already erased) `doLoop`
// function parameter.
+ ompLoopNest->walk ([&](fir::DoLoopOp doLoop) {
286
+ if (doLoop.getUnordered ()) {
287
+ concurrentLoopsToSkip.insert (doLoop);
288
+ }
289
+ });
290
+
146
291
return mlir::success ();
147
292
}
148
293
294
+ private:
295
+ mlir::omp::ParallelOp genParallelOp (mlir::Location loc,
296
+ mlir::ConversionPatternRewriter &rewriter,
297
+ looputils::LoopNestToIndVarMap &loopNest,
298
+ mlir::IRMapping &mapper) const {
299
+ auto parallelOp = rewriter.create <mlir::omp::ParallelOp>(loc);
300
+ rewriter.createBlock (¶llelOp.getRegion ());
301
+ rewriter.setInsertionPoint (rewriter.create <mlir::omp::TerminatorOp>(loc));
302
+
303
+ genLoopNestIndVarAllocs (rewriter, loopNest, mapper);
304
+ return parallelOp;
305
+ }
306
+
307
+ void genLoopNestIndVarAllocs (mlir::ConversionPatternRewriter &rewriter,
308
+ looputils::LoopNestToIndVarMap &loopNest,
309
+ mlir::IRMapping &mapper) const {
310
+
311
+ for (auto &[_, indVarInfo] : loopNest)
312
+ genInductionVariableAlloc (rewriter, indVarInfo.iterVarMemDef , mapper);
313
+ }
314
+
315
+ mlir::Operation *
316
+ genInductionVariableAlloc (mlir::ConversionPatternRewriter &rewriter,
317
+ mlir::Operation *indVarMemDef,
318
+ mlir::IRMapping &mapper) const {
319
+ assert (
320
+ indVarMemDef != nullptr &&
321
+ " Induction variable memdef is expected to have a defining operation." );
322
+
323
+ llvm::SmallSetVector<mlir::Operation *, 2 > indVarDeclareAndAlloc;
324
+ for (auto operand : indVarMemDef->getOperands ())
325
+ indVarDeclareAndAlloc.insert (operand.getDefiningOp ());
326
+ indVarDeclareAndAlloc.insert (indVarMemDef);
327
+
328
+ mlir::Operation *result;
329
+ for (mlir::Operation *opToClone : indVarDeclareAndAlloc)
330
+ result = rewriter.clone (*opToClone, mapper);
331
+
332
+ return result;
333
+ }
334
+
335
+ void genLoopNestClauseOps (
336
+ mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
337
+ looputils::LoopNestToIndVarMap &loopNest, mlir::IRMapping &mapper,
338
+ mlir::omp::LoopNestOperands &loopNestClauseOps) const {
339
+ assert (loopNestClauseOps.loopLowerBounds .empty () &&
340
+ " Loop nest bounds were already emitted!" );
341
+
342
+ auto populateBounds = [&](mlir::Value var,
343
+ llvm::SmallVectorImpl<mlir::Value> &bounds) {
344
+ bounds.push_back (var.getDefiningOp ()->getResult (0 ));
345
+ };
346
+
347
+ for (auto &[doLoop, _] : loopNest) {
348
+ populateBounds (doLoop.getLowerBound (), loopNestClauseOps.loopLowerBounds );
349
+ populateBounds (doLoop.getUpperBound (), loopNestClauseOps.loopUpperBounds );
350
+ populateBounds (doLoop.getStep (), loopNestClauseOps.loopSteps );
351
+ }
352
+
353
+ loopNestClauseOps.loopInclusive = rewriter.getUnitAttr ();
354
+ }
355
+
356
+ mlir::omp::LoopNestOp
357
+ genWsLoopOp (mlir::ConversionPatternRewriter &rewriter, fir::DoLoopOp doLoop,
358
+ mlir::IRMapping &mapper,
359
+ const mlir::omp::LoopNestOperands &clauseOps,
360
+ bool isComposite) const {
361
+
362
+ auto wsloopOp = rewriter.create <mlir::omp::WsloopOp>(doLoop.getLoc ());
363
+ wsloopOp.setComposite (isComposite);
364
+ rewriter.createBlock (&wsloopOp.getRegion ());
365
+
366
+ auto loopNestOp =
367
+ rewriter.create <mlir::omp::LoopNestOp>(doLoop.getLoc (), clauseOps);
368
+
369
+ // Clone the loop's body inside the loop nest construct using the
370
+ // mapped values.
371
+ rewriter.cloneRegionBefore (doLoop.getRegion (), loopNestOp.getRegion (),
372
+ loopNestOp.getRegion ().begin (), mapper);
373
+
374
+ mlir::Operation *terminator = loopNestOp.getRegion ().back ().getTerminator ();
375
+ rewriter.setInsertionPointToEnd (&loopNestOp.getRegion ().back ());
376
+ rewriter.create <mlir::omp::YieldOp>(terminator->getLoc ());
377
+ rewriter.eraseOp (terminator);
378
+
379
+ return loopNestOp;
380
+ }
381
+
149
382
/// Set from the constructor argument of the same name; `matchAndRewrite`
/// forwards it as the `isComposite` argument of `genWsLoopOp`.
bool mapToDevice;
383
/// Set held by reference (bound in the constructor). `matchAndRewrite`
/// inserts into it every `unordered` `fir.do_loop` found inside the generated
/// loop nest.
+ llvm::DenseSet<fir::DoLoopOp> &concurrentLoopsToSkip;
150
384
};
151
385
152
386
class DoConcurrentConversionPass
@@ -175,16 +409,18 @@ class DoConcurrentConversionPass
175
409
return ;
176
410
}
177
411
412
+ llvm::DenseSet<fir::DoLoopOp> concurrentLoopsToSkip;
178
413
mlir::RewritePatternSet patterns (context);
179
414
patterns.insert <DoConcurrentConversion>(
180
- context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device);
415
+ context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
416
+ concurrentLoopsToSkip);
181
417
mlir::ConversionTarget target (*context);
182
418
target.addDynamicallyLegalOp <fir::DoLoopOp>([&](fir::DoLoopOp op) {
183
419
// The goal is to handle constructs that eventually get lowered to
184
420
// `fir.do_loop` with the `unordered` attribute (e.g. array expressions).
185
421
// Currently, this is only enabled for the `do concurrent` construct since
186
422
// the pass runs early in the pipeline.
187
- return !op.getUnordered ();
423
+ return !op.getUnordered () || concurrentLoopsToSkip. contains (op) ;
188
424
});
189
425
target.markUnknownOpDynamicallyLegal (
190
426
[](mlir::Operation *) { return true ; });
0 commit comments