//===- DoConcurrentConversion.cpp -- map `DO CONCURRENT` to OpenMP loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

#include <memory>

namespace fir {
#define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "fopenmp-do-concurrent-conversion"

namespace {
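// Rewrites an unordered `fir.do_loop` (the FIR encoding of `do concurrent`)
// into an `omp.parallel` region wrapping an `omp.wsloop`. Schematically (the
// IR below is illustrative only; exact assembly syntax may differ):
//
//   fir.do_loop %i = %lb to %ub step %st unordered {
//     <body>
//   }
//
// becomes roughly:
//
//   omp.parallel {
//     <privatized alloca/declare ops for the induction variable>
//     <cloned ops defining %lb, %ub, and %st>
//     omp.wsloop for (%i) : index = (%lb) to (%ub) inclusive step (%st) {
//       <cloned body>
//       omp.yield
//     }
//     omp.terminator
//   }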
class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
public:
  using mlir::OpConversionPattern<fir::DoLoopOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(fir::DoLoopOp doLoop, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::omp::ParallelOp parallelOp =
        rewriter.create<mlir::omp::ParallelOp>(doLoop.getLoc());

    rewriter.createBlock(&parallelOp.getRegion());
    mlir::Block &block = parallelOp.getRegion().back();

    rewriter.setInsertionPointToEnd(&block);
    rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc());

    rewriter.setInsertionPointToStart(&block);

    // Clone the LB, UB, and step defining ops into the parallel region.
    llvm::SmallVector<mlir::Value> lowerBound, upperBound, step;
    lowerBound.push_back(
        rewriter.clone(*doLoop.getLowerBound().getDefiningOp())->getResult(0));
    upperBound.push_back(
        rewriter.clone(*doLoop.getUpperBound().getDefiningOp())->getResult(0));
    step.push_back(
        rewriter.clone(*doLoop.getStep().getDefiningOp())->getResult(0));

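    // `fir.do_loop` iterates over an inclusive range, so mark the worksharing
    // loop's upper bound as inclusive as well.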
    auto wsLoopOp = rewriter.create<mlir::omp::WsLoopOp>(
        doLoop.getLoc(), lowerBound, upperBound, step);
    wsLoopOp.setInclusive(true);

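    // Insert privatized allocations into the parallel region's alloca block so
    // that they dominate all of their uses inside the region.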
    auto outlineableOp =
        mlir::dyn_cast<mlir::omp::OutlineableOpenMPOpInterface>(*parallelOp);
    assert(outlineableOp);
    rewriter.setInsertionPointToStart(outlineableOp.getAllocaBlock());

    // For the induction variable, we need to privatize its allocation and
    // binding inside the parallel region.
    llvm::SmallSetVector<mlir::Operation *, 2> workList;
    // To find that allocation, start from the users of the loop's block
    // argument and look for the `fir.store` ops that write the induction value
    // to memory.
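    // For example, the input typically looks schematically like this (the IR
    // is illustrative only; SSA names are made up and the syntax simplified):
    //
    //   %iv_ref  = fir.alloca i32 {bindc_name = "i"}
    //   %iv_decl = hlfir.declare %iv_ref ...
    //   fir.do_loop %arg0 = %lb to %ub step %st unordered {
    //     %iv_val = fir.convert %arg0 : (index) -> i32
    //     fir.store %iv_val to %iv_decl : !fir.ref<i32>
    //     ...
    //   }
    //
    // Following the users of %arg0 (through %iv_val) leads to the `fir.store`,
    // and the store's memref operand leads to the `hlfir.declare` and
    // `fir.alloca` ops that must be cloned inside the parallel region.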
    workList.insert(doLoop.getInductionVar().getUsers().begin(),
                    doLoop.getInductionVar().getUsers().end());
    llvm::SmallSetVector<fir::StoreOp, 2> inductionVarTargetStores;

    // Follow the uses of the loop's block argument until we hit `fir.store`
    // ops.
    while (!workList.empty()) {
      mlir::Operation *item = workList.front();

      if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(item)) {
        inductionVarTargetStores.insert(storeOp);
      } else {
        workList.insert(item->getUsers().begin(), item->getUsers().end());
      }

      workList.remove(item);
    }

    // For each collected `fir.store`, find the `fir.alloca` and `hlfir.declare`
    // ops that define the store's target memref.
    llvm::SmallSetVector<mlir::Operation *, 4> declareAndAllocasToClone;
    for (auto storeOp : inductionVarTargetStores) {
      mlir::Operation *storeTarget = storeOp.getMemref().getDefiningOp();

      for (auto operand : storeTarget->getOperands()) {
        declareAndAllocasToClone.insert(operand.getDefiningOp());
      }
      declareAndAllocasToClone.insert(storeTarget);
    }

    mlir::IRMapping mapper;

    // Clone the memref-defining ops into the parallel region's alloca block.
    for (mlir::Operation *opToClone : declareAndAllocasToClone) {
      rewriter.clone(*opToClone, mapper);
    }

    // Clone the loop's body into the worksharing construct, remapping uses of
    // the original memrefs to their privatized clones.
    rewriter.cloneRegionBefore(doLoop.getRegion(), wsLoopOp.getRegion(),
                               wsLoopOp.getRegion().begin(), mapper);

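    // Replace the cloned `fir.result` terminator with the `omp.yield`
    // terminator required by `omp.wsloop`.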
    mlir::Operation *terminator = wsLoopOp.getRegion().back().getTerminator();
    rewriter.setInsertionPointToEnd(&wsLoopOp.getRegion().back());
    rewriter.create<mlir::omp::YieldOp>(terminator->getLoc());
    rewriter.eraseOp(terminator);

    rewriter.eraseOp(doLoop);

    return mlir::success();
  }
};

class DoConcurrentConversionPass
    : public fir::impl::DoConcurrentConversionPassBase<
          DoConcurrentConversionPass> {
public:
  void runOnOperation() override {
    mlir::func::FuncOp func = getOperation();

    if (func.isDeclaration()) {
      return;
    }

    auto *context = &getContext();
    mlir::RewritePatternSet patterns(context);
    patterns.insert<DoConcurrentConversion>(context);
    mlir::ConversionTarget target(*context);
    target.addLegalDialect<fir::FIROpsDialect, hlfir::hlfirDialect,
                           mlir::arith::ArithDialect, mlir::func::FuncDialect,
                           mlir::omp::OpenMPDialect>();

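    // Only unordered `fir.do_loop`s (i.e. `do concurrent` loops) are marked
    // illegal and therefore converted; ordered loops are left untouched.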
    target.addDynamicallyLegalOp<fir::DoLoopOp>(
        [](fir::DoLoopOp op) { return !op.getUnordered(); });

    if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
                                               std::move(patterns)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "error in converting do-concurrent op");
      signalPassFailure();
    }
  }
};
} // namespace

std::unique_ptr<mlir::Pass> fir::createDoConcurrentConversionPass() {
  return std::make_unique<DoConcurrentConversionPass>();
}
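
// Illustrative usage only (an assumption, not part of this file): since the
// pass runs on `func.func`, it would typically be nested when building a
// pipeline, e.g.:
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       fir::createDoConcurrentConversionPass());
//   if (mlir::failed(pm.run(module)))
//     /* handle the failure */;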