15
15
#include " mlir/Transforms/DialectConversion.h"
16
16
17
17
#include < memory>
18
+ #include < optional>
19
+ #include < type_traits>
18
20
19
21
namespace flangomp {
20
22
#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
@@ -58,7 +60,7 @@ class GenericLoopConversionPattern
58
60
if (teamsLoopCanBeParallelFor (loopOp))
59
61
rewriteToDistributeParallelDo (loopOp, rewriter);
60
62
else
61
- rewriteToDistrbute (loopOp, rewriter);
63
+ rewriteToDistribute (loopOp, rewriter);
62
64
break ;
63
65
}
64
66
@@ -77,9 +79,6 @@ class GenericLoopConversionPattern
77
79
if (loopOp.getOrder ())
78
80
return todo (" order" );
79
81
80
- if (!loopOp.getReductionVars ().empty ())
81
- return todo (" reduction" );
82
-
83
82
return mlir::success ();
84
83
}
85
84
@@ -168,7 +167,7 @@ class GenericLoopConversionPattern
168
167
case ClauseBindKind::Parallel:
169
168
return rewriteToWsloop (loopOp, rewriter);
170
169
case ClauseBindKind::Teams:
171
- return rewriteToDistrbute (loopOp, rewriter);
170
+ return rewriteToDistribute (loopOp, rewriter);
172
171
case ClauseBindKind::Thread:
173
172
return rewriteToSimdLoop (loopOp, rewriter);
174
173
}
@@ -211,8 +210,9 @@ class GenericLoopConversionPattern
211
210
loopOp, rewriter);
212
211
}
213
212
214
- void rewriteToDistrbute (mlir::omp::LoopOp loopOp,
215
- mlir::ConversionPatternRewriter &rewriter) const {
213
+ void rewriteToDistribute (mlir::omp::LoopOp loopOp,
214
+ mlir::ConversionPatternRewriter &rewriter) const {
215
+ assert (loopOp.getReductionVars ().empty ());
216
216
rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
217
217
mlir::omp::DistributeOperands>(loopOp, rewriter);
218
218
}
@@ -246,6 +246,12 @@ class GenericLoopConversionPattern
246
246
Fortran::common::openmp::EntryBlockArgs args;
247
247
args.priv .vars = clauseOps.privateVars ;
248
248
249
+ if constexpr (!std::is_same_v<OpOperandsTy,
250
+ mlir::omp::DistributeOperands>) {
251
+ populateReductionClauseOps (loopOp, clauseOps);
252
+ args.reduction .vars = clauseOps.reductionVars ;
253
+ }
254
+
249
255
auto wrapperOp = rewriter.create <OpTy>(loopOp.getLoc (), clauseOps);
250
256
mlir::Block *opBlock = genEntryBlock (rewriter, args, wrapperOp.getRegion ());
251
257
@@ -275,8 +281,7 @@ class GenericLoopConversionPattern
275
281
276
282
auto parallelOp = rewriter.create <mlir::omp::ParallelOp>(loopOp.getLoc (),
277
283
parallelClauseOps);
278
- mlir::Block *parallelBlock =
279
- genEntryBlock (rewriter, parallelArgs, parallelOp.getRegion ());
284
+ genEntryBlock (rewriter, parallelArgs, parallelOp.getRegion ());
280
285
parallelOp.setComposite (true );
281
286
rewriter.setInsertionPoint (
282
287
rewriter.create <mlir::omp::TerminatorOp>(loopOp.getLoc ()));
@@ -288,20 +293,54 @@ class GenericLoopConversionPattern
288
293
rewriter.createBlock (&distributeOp.getRegion ());
289
294
290
295
mlir::omp::WsloopOperands wsloopClauseOps;
296
+ populateReductionClauseOps (loopOp, wsloopClauseOps);
297
+ Fortran::common::openmp::EntryBlockArgs wsloopArgs;
298
+ wsloopArgs.reduction .vars = wsloopClauseOps.reductionVars ;
299
+
291
300
auto wsloopOp =
292
301
rewriter.create <mlir::omp::WsloopOp>(loopOp.getLoc (), wsloopClauseOps);
293
302
wsloopOp.setComposite (true );
294
- rewriter. createBlock (& wsloopOp.getRegion ());
303
+ genEntryBlock (rewriter, wsloopArgs, wsloopOp.getRegion ());
295
304
296
305
mlir::IRMapping mapper;
297
- mlir::Block &loopBlock = *loopOp.getRegion ().begin ();
298
306
299
- for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal (
300
- loopBlock.getArguments (), parallelBlock->getArguments ()))
307
+ auto loopBlockInterface =
308
+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp);
309
+ auto parallelBlockInterface =
310
+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*parallelOp);
311
+ auto wsloopBlockInterface =
312
+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*wsloopOp);
313
+
314
+ for (auto [loopOpArg, parallelOpArg] :
315
+ llvm::zip_equal (loopBlockInterface.getPrivateBlockArgs (),
316
+ parallelBlockInterface.getPrivateBlockArgs ()))
301
317
mapper.map (loopOpArg, parallelOpArg);
302
318
319
+ for (auto [loopOpArg, wsloopOpArg] :
320
+ llvm::zip_equal (loopBlockInterface.getReductionBlockArgs (),
321
+ wsloopBlockInterface.getReductionBlockArgs ()))
322
+ mapper.map (loopOpArg, wsloopOpArg);
323
+
303
324
rewriter.clone (*loopOp.begin (), mapper);
304
325
}
326
+
327
+ void
328
+ populateReductionClauseOps (mlir::omp::LoopOp loopOp,
329
+ mlir::omp::ReductionClauseOps &clauseOps) const {
330
+ clauseOps.reductionMod = loopOp.getReductionModAttr ();
331
+ clauseOps.reductionVars = loopOp.getReductionVars ();
332
+
333
+ std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms ();
334
+ if (reductionSyms)
335
+ clauseOps.reductionSyms .assign (reductionSyms->begin (),
336
+ reductionSyms->end ());
337
+
338
+ std::optional<llvm::ArrayRef<bool >> reductionByref =
339
+ loopOp.getReductionByref ();
340
+ if (reductionByref)
341
+ clauseOps.reductionByref .assign (reductionByref->begin (),
342
+ reductionByref->end ());
343
+ }
305
344
};
306
345
307
346
class GenericLoopConversionPass
0 commit comments