|
| 1 | +//===- InlineHLFIRAssign.cpp - Inline hlfir.assign ops --------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// Transform hlfir.assign array operations into loop nests performing element |
| 9 | +// per element assignments. The inlining is always done for trivial data |
| 10 | +// types, though we may add performance/code-size heuristics in the future. |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "flang/Optimizer/Analysis/AliasAnalysis.h" |
| 14 | +#include "flang/Optimizer/Builder/FIRBuilder.h" |
| 15 | +#include "flang/Optimizer/Builder/HLFIRTools.h" |
| 16 | +#include "flang/Optimizer/HLFIR/HLFIROps.h" |
| 17 | +#include "flang/Optimizer/HLFIR/Passes.h" |
| 18 | +#include "flang/Optimizer/OpenMP/Passes.h" |
| 19 | +#include "mlir/IR/PatternMatch.h" |
| 20 | +#include "mlir/Pass/Pass.h" |
| 21 | +#include "mlir/Support/LLVM.h" |
| 22 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 23 | + |
| 24 | +namespace hlfir { |
| 25 | +#define GEN_PASS_DEF_INLINEHLFIRASSIGN |
| 26 | +#include "flang/Optimizer/HLFIR/Passes.h.inc" |
| 27 | +} // namespace hlfir |
| 28 | + |
| 29 | +#define DEBUG_TYPE "inline-hlfir-assign" |
| 30 | + |
| 31 | +namespace { |
| 32 | +/// Expand hlfir.assign of array RHS to array LHS into a loop nest |
| 33 | +/// of element-by-element assignments: |
| 34 | +/// hlfir.assign %4 to %5 : !fir.ref<!fir.array<3x3xf32>>, |
| 35 | +/// !fir.ref<!fir.array<3x3xf32>> |
| 36 | +/// into: |
| 37 | +/// fir.do_loop %arg1 = %c1 to %c3 step %c1 unordered { |
| 38 | +/// fir.do_loop %arg2 = %c1 to %c3 step %c1 unordered { |
| 39 | +/// %6 = hlfir.designate %4 (%arg2, %arg1) : |
| 40 | +/// (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32> |
| 41 | +/// %7 = fir.load %6 : !fir.ref<f32> |
| 42 | +/// %8 = hlfir.designate %5 (%arg2, %arg1) : |
| 43 | +/// (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32> |
| 44 | +/// hlfir.assign %7 to %8 : f32, !fir.ref<f32> |
| 45 | +/// } |
| 46 | +/// } |
| 47 | +/// |
| 48 | +/// The transformation is correct only when LHS and RHS do not alias. |
| 49 | +/// When RHS is an array expression, then there is no aliasing. |
| 50 | +/// This transformation does not support runtime checking for |
| 51 | +/// non-conforming LHS/RHS arrays' shapes currently. |
| 52 | +class InlineHLFIRAssignConversion |
| 53 | + : public mlir::OpRewritePattern<hlfir::AssignOp> { |
| 54 | +public: |
| 55 | + using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern; |
| 56 | + |
| 57 | + llvm::LogicalResult |
| 58 | + matchAndRewrite(hlfir::AssignOp assign, |
| 59 | + mlir::PatternRewriter &rewriter) const override { |
| 60 | + if (assign.isAllocatableAssignment()) |
| 61 | + return rewriter.notifyMatchFailure(assign, |
| 62 | + "AssignOp may imply allocation"); |
| 63 | + |
| 64 | + hlfir::Entity rhs{assign.getRhs()}; |
| 65 | + |
| 66 | + if (!rhs.isArray()) |
| 67 | + return rewriter.notifyMatchFailure(assign, |
| 68 | + "AssignOp's RHS is not an array"); |
| 69 | + |
| 70 | + mlir::Type rhsEleTy = rhs.getFortranElementType(); |
| 71 | + if (!fir::isa_trivial(rhsEleTy)) |
| 72 | + return rewriter.notifyMatchFailure( |
| 73 | + assign, "AssignOp's RHS data type is not trivial"); |
| 74 | + |
| 75 | + hlfir::Entity lhs{assign.getLhs()}; |
| 76 | + if (!lhs.isArray()) |
| 77 | + return rewriter.notifyMatchFailure(assign, |
| 78 | + "AssignOp's LHS is not an array"); |
| 79 | + |
| 80 | + mlir::Type lhsEleTy = lhs.getFortranElementType(); |
| 81 | + if (!fir::isa_trivial(lhsEleTy)) |
| 82 | + return rewriter.notifyMatchFailure( |
| 83 | + assign, "AssignOp's LHS data type is not trivial"); |
| 84 | + |
| 85 | + if (lhsEleTy != rhsEleTy) |
| 86 | + return rewriter.notifyMatchFailure(assign, |
| 87 | + "RHS/LHS element types mismatch"); |
| 88 | + |
| 89 | + if (!mlir::isa<hlfir::ExprType>(rhs.getType())) { |
| 90 | + // If RHS is not an hlfir.expr, then we should prove that |
| 91 | + // LHS and RHS do not alias. |
| 92 | + // TODO: if they may alias, we can insert hlfir.as_expr for RHS, |
| 93 | + // and proceed with the inlining. |
| 94 | + fir::AliasAnalysis aliasAnalysis; |
| 95 | + mlir::AliasResult aliasRes = aliasAnalysis.alias(lhs, rhs); |
| 96 | + // TODO: use areIdenticalOrDisjointSlices() from |
| 97 | + // OptimizedBufferization.cpp to check if we can still do the expansion. |
| 98 | + if (!aliasRes.isNo()) { |
| 99 | + LLVM_DEBUG(llvm::dbgs() << "InlineHLFIRAssign:\n" |
| 100 | + << "\tLHS: " << lhs << "\n" |
| 101 | + << "\tRHS: " << rhs << "\n" |
| 102 | + << "\tALIAS: " << aliasRes << "\n"); |
| 103 | + return rewriter.notifyMatchFailure(assign, "RHS/LHS may alias"); |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + mlir::Location loc = assign->getLoc(); |
| 108 | + fir::FirOpBuilder builder(rewriter, assign.getOperation()); |
| 109 | + builder.setInsertionPoint(assign); |
| 110 | + rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs); |
| 111 | + lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); |
| 112 | + mlir::Value shape = hlfir::genShape(loc, builder, lhs); |
| 113 | + llvm::SmallVector<mlir::Value> extents = |
| 114 | + hlfir::getIndexExtents(loc, builder, shape); |
| 115 | + hlfir::LoopNest loopNest = |
| 116 | + hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, |
| 117 | + flangomp::shouldUseWorkshareLowering(assign)); |
| 118 | + builder.setInsertionPointToStart(loopNest.body); |
| 119 | + auto rhsArrayElement = |
| 120 | + hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices); |
| 121 | + rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement); |
| 122 | + auto lhsArrayElement = |
| 123 | + hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); |
| 124 | + builder.create<hlfir::AssignOp>(loc, rhsArrayElement, lhsArrayElement); |
| 125 | + rewriter.eraseOp(assign); |
| 126 | + return mlir::success(); |
| 127 | + } |
| 128 | +}; |
| 129 | + |
| 130 | +class InlineHLFIRAssignPass |
| 131 | + : public hlfir::impl::InlineHLFIRAssignBase<InlineHLFIRAssignPass> { |
| 132 | +public: |
| 133 | + void runOnOperation() override { |
| 134 | + mlir::MLIRContext *context = &getContext(); |
| 135 | + |
| 136 | + mlir::GreedyRewriteConfig config; |
| 137 | + // Prevent the pattern driver from merging blocks. |
| 138 | + config.enableRegionSimplification = |
| 139 | + mlir::GreedySimplifyRegionLevel::Disabled; |
| 140 | + |
| 141 | + mlir::RewritePatternSet patterns(context); |
| 142 | + patterns.insert<InlineHLFIRAssignConversion>(context); |
| 143 | + |
| 144 | + if (mlir::failed(mlir::applyPatternsGreedily( |
| 145 | + getOperation(), std::move(patterns), config))) { |
| 146 | + mlir::emitError(getOperation()->getLoc(), |
| 147 | + "failure in hlfir.assign inlining"); |
| 148 | + signalPassFailure(); |
| 149 | + } |
| 150 | + } |
| 151 | +}; |
| 152 | +} // namespace |
0 commit comments