|
15 | 15 | #include "mlir/Transforms/DialectConversion.h"
|
16 | 16 |
|
17 | 17 | #include <memory>
|
| 18 | +#include <optional> |
| 19 | +#include <type_traits> |
18 | 20 |
|
19 | 21 | namespace flangomp {
|
20 | 22 | #define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
|
@@ -77,9 +79,6 @@ class GenericLoopConversionPattern
|
77 | 79 | if (loopOp.getOrder())
|
78 | 80 | return todo("order");
|
79 | 81 |
|
80 |
| - if (!loopOp.getReductionVars().empty()) |
81 |
| - return todo("reduction"); |
82 |
| - |
83 | 82 | return mlir::success();
|
84 | 83 | }
|
85 | 84 |
|
@@ -213,6 +212,7 @@ class GenericLoopConversionPattern
|
213 | 212 |
|
214 | 213 | void rewriteToDistrbute(mlir::omp::LoopOp loopOp,
|
215 | 214 | mlir::ConversionPatternRewriter &rewriter) const {
|
| 215 | + assert(loopOp.getReductionVars().empty()); |
216 | 216 | rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
|
217 | 217 | mlir::omp::DistributeOperands>(loopOp, rewriter);
|
218 | 218 | }
|
@@ -246,6 +246,12 @@ class GenericLoopConversionPattern
|
246 | 246 | Fortran::common::openmp::EntryBlockArgs args;
|
247 | 247 | args.priv.vars = clauseOps.privateVars;
|
248 | 248 |
|
| 249 | + if constexpr (!std::is_same_v<OpOperandsTy, |
| 250 | + mlir::omp::DistributeOperands>) { |
| 251 | + populateReductionClauseOps(loopOp, clauseOps); |
| 252 | + args.reduction.vars = clauseOps.reductionVars; |
| 253 | + } |
| 254 | + |
249 | 255 | auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
|
250 | 256 | mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());
|
251 | 257 |
|
@@ -288,20 +294,51 @@ class GenericLoopConversionPattern
|
288 | 294 | rewriter.createBlock(&distributeOp.getRegion());
|
289 | 295 |
|
290 | 296 | mlir::omp::WsloopOperands wsloopClauseOps;
|
| 297 | + populateReductionClauseOps(loopOp, wsloopClauseOps); |
| 298 | + Fortran::common::openmp::EntryBlockArgs wsloopArgs; |
| 299 | + wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars; |
| 300 | + |
291 | 301 | auto wsloopOp =
|
292 | 302 | rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
|
293 | 303 | wsloopOp.setComposite(true);
|
294 |
| - rewriter.createBlock(&wsloopOp.getRegion()); |
| 304 | + mlir::Block *loopBlock = |
| 305 | + genEntryBlock(rewriter, wsloopArgs, wsloopOp.getRegion()); |
295 | 306 |
|
296 | 307 | mlir::IRMapping mapper;
|
297 |
| - mlir::Block &loopBlock = *loopOp.getRegion().begin(); |
298 | 308 |
|
299 |
| - for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal( |
300 |
| - loopBlock.getArguments(), parallelBlock->getArguments())) |
| 309 | + auto loopBlockInterface = |
| 310 | + llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp); |
| 311 | + |
| 312 | + for (auto [loopOpArg, parallelOpArg] : |
| 313 | + llvm::zip_equal(loopBlockInterface.getPrivateBlockArgs(), |
| 314 | + parallelBlock->getArguments())) |
301 | 315 | mapper.map(loopOpArg, parallelOpArg);
|
302 | 316 |
|
| 317 | + for (auto [loopOpArg, wsloopOpArg] : |
| 318 | + llvm::zip_equal(loopBlockInterface.getReductionBlockArgs(), |
| 319 | + loopBlock->getArguments())) |
| 320 | + mapper.map(loopOpArg, wsloopOpArg); |
| 321 | + |
303 | 322 | rewriter.clone(*loopOp.begin(), mapper);
|
304 | 323 | }
|
| 324 | + |
| 325 | + template <typename OpOperandsTy> |
| 326 | + void populateReductionClauseOps(mlir::omp::LoopOp loopOp, |
| 327 | + OpOperandsTy &clauseOps) const { |
| 328 | + clauseOps.reductionMod = loopOp.getReductionModAttr(); |
| 329 | + clauseOps.reductionVars = loopOp.getReductionVars(); |
| 330 | + |
| 331 | + std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms(); |
| 332 | + if (reductionSyms) |
| 333 | + clauseOps.reductionSyms.assign(reductionSyms->begin(), |
| 334 | + reductionSyms->end()); |
| 335 | + |
| 336 | + std::optional<llvm::ArrayRef<bool>> reductionByref = |
| 337 | + loopOp.getReductionByref(); |
| 338 | + if (reductionByref) |
| 339 | + clauseOps.reductionByref.assign(reductionByref->begin(), |
| 340 | + reductionByref->end()); |
| 341 | + } |
305 | 342 | };
|
306 | 343 |
|
307 | 344 | class GenericLoopConversionPass
|
|
0 commit comments