|
| 1 | +//===- InlineHLFIRCopyIn.cpp - Inline hlfir.copy_in ops -------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// Transform hlfir.copy_in array operations into loop nests performing element |
| 9 | +// per element assignments. For simplicity, the inlining is done for trivial |
| 10 | +// data types when the copy_in does not require a corresponding copy_out and |
| 11 | +// when the input array is not behind a pointer. This may change in the future. |
| 12 | +//===----------------------------------------------------------------------===// |
| 13 | + |
| 14 | +#include "flang/Optimizer/Builder/FIRBuilder.h" |
| 15 | +#include "flang/Optimizer/Builder/HLFIRTools.h" |
| 16 | +#include "flang/Optimizer/Dialect/FIRType.h" |
| 17 | +#include "flang/Optimizer/HLFIR/HLFIROps.h" |
| 18 | +#include "flang/Optimizer/OpenMP/Passes.h" |
| 19 | +#include "mlir/IR/PatternMatch.h" |
| 20 | +#include "mlir/Support/LLVM.h" |
| 21 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 22 | + |
| 23 | +namespace hlfir { |
| 24 | +#define GEN_PASS_DEF_INLINEHLFIRCOPYIN |
| 25 | +#include "flang/Optimizer/HLFIR/Passes.h.inc" |
| 26 | +} // namespace hlfir |
| 27 | + |
| 28 | +#define DEBUG_TYPE "inline-hlfir-copy-in" |
| 29 | + |
| 30 | +static llvm::cl::opt<bool> noInlineHLFIRCopyIn( |
| 31 | + "no-inline-hlfir-copy-in", |
| 32 | + llvm::cl::desc("Do not inline hlfir.copy_in operations"), |
| 33 | + llvm::cl::init(false)); |
| 34 | + |
| 35 | +namespace { |
| 36 | +class InlineCopyInConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> { |
| 37 | +public: |
| 38 | + using mlir::OpRewritePattern<hlfir::CopyInOp>::OpRewritePattern; |
| 39 | + |
| 40 | + llvm::LogicalResult |
| 41 | + matchAndRewrite(hlfir::CopyInOp copyIn, |
| 42 | + mlir::PatternRewriter &rewriter) const override; |
| 43 | +}; |
| 44 | + |
| 45 | +llvm::LogicalResult |
| 46 | +InlineCopyInConversion::matchAndRewrite(hlfir::CopyInOp copyIn, |
| 47 | + mlir::PatternRewriter &rewriter) const { |
| 48 | + fir::FirOpBuilder builder(rewriter, copyIn.getOperation()); |
| 49 | + mlir::Location loc = copyIn.getLoc(); |
| 50 | + hlfir::Entity inputVariable{copyIn.getVar()}; |
| 51 | + mlir::Type resultAddrType = copyIn.getCopiedIn().getType(); |
| 52 | + if (!fir::isa_trivial(inputVariable.getFortranElementType())) |
| 53 | + return rewriter.notifyMatchFailure(copyIn, |
| 54 | + "CopyInOp's data type is not trivial"); |
| 55 | + |
| 56 | + // There should be exactly one user of WasCopied - the corresponding |
| 57 | + // CopyOutOp. |
| 58 | + if (!copyIn.getWasCopied().hasOneUse()) |
| 59 | + return rewriter.notifyMatchFailure( |
| 60 | + copyIn, "CopyInOp's WasCopied has no single user"); |
| 61 | + // The copy out should always be present, either to actually copy or just |
| 62 | + // deallocate memory. |
| 63 | + auto copyOut = mlir::dyn_cast<hlfir::CopyOutOp>( |
| 64 | + copyIn.getWasCopied().user_begin().getCurrent().getUser()); |
| 65 | + |
| 66 | + if (!copyOut) |
| 67 | + return rewriter.notifyMatchFailure(copyIn, |
| 68 | + "CopyInOp has no direct CopyOut"); |
| 69 | + |
| 70 | + if (mlir::cast<fir::BaseBoxType>(resultAddrType).isAssumedRank()) |
| 71 | + return rewriter.notifyMatchFailure(copyIn, |
| 72 | + "The result array is assumed-rank"); |
| 73 | + |
| 74 | + // Only inline the copy_in when copy_out does not need to be done, i.e. in |
| 75 | + // case of intent(in). |
| 76 | + if (copyOut.getVar()) |
| 77 | + return rewriter.notifyMatchFailure(copyIn, "CopyIn needs a copy-out"); |
| 78 | + |
| 79 | + inputVariable = |
| 80 | + hlfir::derefPointersAndAllocatables(loc, builder, inputVariable); |
| 81 | + mlir::Type sequenceType = |
| 82 | + hlfir::getFortranElementOrSequenceType(inputVariable.getType()); |
| 83 | + fir::BoxType resultBoxType = fir::BoxType::get(sequenceType); |
| 84 | + mlir::Value isContiguous = |
| 85 | + builder.create<fir::IsContiguousBoxOp>(loc, inputVariable); |
| 86 | + mlir::Operation::result_range results = |
| 87 | + builder |
| 88 | + .genIfOp(loc, {resultBoxType, builder.getI1Type()}, isContiguous, |
| 89 | + /*withElseRegion=*/true) |
| 90 | + .genThen([&]() { |
| 91 | + mlir::Value result = inputVariable; |
| 92 | + if (fir::isPointerType(inputVariable.getType())) { |
| 93 | + result = builder.create<fir::ReboxOp>( |
| 94 | + loc, resultBoxType, inputVariable, mlir::Value{}, |
| 95 | + mlir::Value{}); |
| 96 | + } |
| 97 | + builder.create<fir::ResultOp>( |
| 98 | + loc, mlir::ValueRange{result, builder.createBool(loc, false)}); |
| 99 | + }) |
| 100 | + .genElse([&] { |
| 101 | + mlir::Value shape = hlfir::genShape(loc, builder, inputVariable); |
| 102 | + llvm::SmallVector<mlir::Value> extents = |
| 103 | + hlfir::getIndexExtents(loc, builder, shape); |
| 104 | + llvm::StringRef tmpName{".tmp.copy_in"}; |
| 105 | + llvm::SmallVector<mlir::Value> lenParams; |
| 106 | + mlir::Value alloc = builder.createHeapTemporary( |
| 107 | + loc, sequenceType, tmpName, extents, lenParams); |
| 108 | + |
| 109 | + auto declareOp = builder.create<hlfir::DeclareOp>( |
| 110 | + loc, alloc, tmpName, shape, lenParams, |
| 111 | + /*dummy_scope=*/nullptr); |
| 112 | + hlfir::Entity temp{declareOp.getBase()}; |
| 113 | + hlfir::LoopNest loopNest = |
| 114 | + hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, |
| 115 | + flangomp::shouldUseWorkshareLowering(copyIn), |
| 116 | + /*couldVectorize=*/false); |
| 117 | + builder.setInsertionPointToStart(loopNest.body); |
| 118 | + hlfir::Entity elem = hlfir::getElementAt( |
| 119 | + loc, builder, inputVariable, loopNest.oneBasedIndices); |
| 120 | + elem = hlfir::loadTrivialScalar(loc, builder, elem); |
| 121 | + hlfir::Entity tempElem = hlfir::getElementAt( |
| 122 | + loc, builder, temp, loopNest.oneBasedIndices); |
| 123 | + builder.create<hlfir::AssignOp>(loc, elem, tempElem); |
| 124 | + builder.setInsertionPointAfter(loopNest.outerOp); |
| 125 | + |
| 126 | + mlir::Value result; |
| 127 | + // Make sure the result is always a boxed array by boxing it |
| 128 | + // ourselves if need be. |
| 129 | + if (mlir::isa<fir::BaseBoxType>(temp.getType())) { |
| 130 | + result = temp; |
| 131 | + } else { |
| 132 | + fir::ReferenceType refTy = |
| 133 | + fir::ReferenceType::get(temp.getElementOrSequenceType()); |
| 134 | + mlir::Value refVal = builder.createConvert(loc, refTy, temp); |
| 135 | + result = builder.create<fir::EmboxOp>(loc, resultBoxType, refVal, |
| 136 | + shape); |
| 137 | + } |
| 138 | + |
| 139 | + builder.create<fir::ResultOp>( |
| 140 | + loc, mlir::ValueRange{result, builder.createBool(loc, true)}); |
| 141 | + }) |
| 142 | + .getResults(); |
| 143 | + |
| 144 | + mlir::OpResult resultBox = results[0]; |
| 145 | + mlir::OpResult needsCleanup = results[1]; |
| 146 | + |
| 147 | + // Prepare the corresponding copyOut to free the temporary if it is required |
| 148 | + auto alloca = builder.create<fir::AllocaOp>(loc, resultBox.getType()); |
| 149 | + auto store = builder.create<fir::StoreOp>(loc, resultBox, alloca); |
| 150 | + rewriter.startOpModification(copyOut); |
| 151 | + copyOut->setOperand(0, store.getMemref()); |
| 152 | + copyOut->setOperand(1, needsCleanup); |
| 153 | + rewriter.finalizeOpModification(copyOut); |
| 154 | + |
| 155 | + rewriter.replaceOp(copyIn, {resultBox, builder.genNot(loc, isContiguous)}); |
| 156 | + return mlir::success(); |
| 157 | +} |
| 158 | + |
| 159 | +class InlineHLFIRCopyInPass |
| 160 | + : public hlfir::impl::InlineHLFIRCopyInBase<InlineHLFIRCopyInPass> { |
| 161 | +public: |
| 162 | + void runOnOperation() override { |
| 163 | + mlir::MLIRContext *context = &getContext(); |
| 164 | + |
| 165 | + mlir::GreedyRewriteConfig config; |
| 166 | + // Prevent the pattern driver from merging blocks. |
| 167 | + config.setRegionSimplificationLevel( |
| 168 | + mlir::GreedySimplifyRegionLevel::Disabled); |
| 169 | + |
| 170 | + mlir::RewritePatternSet patterns(context); |
| 171 | + if (!noInlineHLFIRCopyIn) { |
| 172 | + patterns.insert<InlineCopyInConversion>(context); |
| 173 | + } |
| 174 | + |
| 175 | + if (mlir::failed(mlir::applyPatternsGreedily( |
| 176 | + getOperation(), std::move(patterns), config))) { |
| 177 | + mlir::emitError(getOperation()->getLoc(), |
| 178 | + "failure in hlfir.copy_in inlining"); |
| 179 | + signalPassFailure(); |
| 180 | + } |
| 181 | + } |
| 182 | +}; |
| 183 | +} // namespace |
0 commit comments