|
| 1 | +//===- BufferDeallocationSimplification.cpp -------------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// This file implements logic for optimizing `bufferization.dealloc` operations |
| 10 | +// that requires more analysis than what can be supported by regular |
| 11 | +// canonicalization patterns. |
| 12 | +// |
| 13 | +//===----------------------------------------------------------------------===// |
| 14 | + |
| 15 | +#include "mlir/Analysis/AliasAnalysis.h" |
| 16 | +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 17 | +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
| 18 | +#include "mlir/Dialect/Func/IR/FuncOps.h" |
| 19 | +#include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 20 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 21 | + |
| 22 | +namespace mlir { |
| 23 | +namespace bufferization { |
| 24 | +#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION |
| 25 | +#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
| 26 | +} // namespace bufferization |
| 27 | +} // namespace mlir |
| 28 | + |
| 29 | +using namespace mlir; |
| 30 | +using namespace mlir::bufferization; |
| 31 | + |
| 32 | +//===----------------------------------------------------------------------===// |
| 33 | +// Helpers |
| 34 | +//===----------------------------------------------------------------------===// |
| 35 | + |
| 36 | +static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, |
| 37 | + ValueRange memrefs, |
| 38 | + ValueRange conditions, |
| 39 | + PatternRewriter &rewriter) { |
| 40 | + if (deallocOp.getMemrefs() == memrefs && |
| 41 | + deallocOp.getConditions() == conditions) |
| 42 | + return failure(); |
| 43 | + |
| 44 | + rewriter.updateRootInPlace(deallocOp, [&]() { |
| 45 | + deallocOp.getMemrefsMutable().assign(memrefs); |
| 46 | + deallocOp.getConditionsMutable().assign(conditions); |
| 47 | + }); |
| 48 | + return success(); |
| 49 | +} |
| 50 | + |
| 51 | +//===----------------------------------------------------------------------===// |
| 52 | +// Patterns |
| 53 | +//===----------------------------------------------------------------------===// |
| 54 | + |
| 55 | +namespace { |
| 56 | + |
| 57 | +/// Remove values from the `memref` operand list that are also present in the |
| 58 | +/// `retained` list since they will always alias and thus never actually be |
| 59 | +/// deallocated. However, we also need to be certain that no other value in the |
| 60 | +/// `retained` list can alias, for which we use a static alias analysis. This is |
| 61 | +/// necessary because the `dealloc` operation is defined to return one `i1` |
| 62 | +/// value per memref in the `retained` list which represents the disjunction of |
| 63 | +/// the condition values corresponding to all aliasing values in the `memref` |
| 64 | +/// list. In particular, this means that if there is some value R in the |
| 65 | +/// `retained` list which aliases with a value M in the `memref` list (but can |
| 66 | +/// only be staticaly determined to may-alias) and M is also present in the |
| 67 | +/// `retained` list, then it would be illegal to remove M because the result |
| 68 | +/// corresponding to R would be computed incorrectly afterwards. |
| 69 | +/// Because we require an alias analysis, this pattern cannot be applied as a |
| 70 | +/// regular canonicalization pattern. |
| 71 | +/// |
| 72 | +/// Example: |
| 73 | +/// ```mlir |
| 74 | +/// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0) |
| 75 | +/// retain (%m0, %r0, %r1 : ...) |
| 76 | +/// ``` |
| 77 | +/// is canonicalized to |
| 78 | +/// ```mlir |
| 79 | +/// // bufferization.dealloc without memrefs and conditions returns %false for |
| 80 | +/// // every retained value |
| 81 | +/// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...) |
| 82 | +/// %1 = arith.ori %0#0, %cond0 : i1 |
| 83 | +/// // replace %0#0 with %1 |
| 84 | +/// ``` |
| 85 | +/// given that `%r0` and `%r1` may not alias with `%m0`. |
| 86 | +struct DeallocRemoveDeallocMemrefsContainedInRetained |
| 87 | + : public OpRewritePattern<DeallocOp> { |
| 88 | + DeallocRemoveDeallocMemrefsContainedInRetained(MLIRContext *context, |
| 89 | + AliasAnalysis &aliasAnalysis) |
| 90 | + : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {} |
| 91 | + |
| 92 | + LogicalResult matchAndRewrite(DeallocOp deallocOp, |
| 93 | + PatternRewriter &rewriter) const override { |
| 94 | + // Unique memrefs to be deallocated. |
| 95 | + DenseMap<Value, unsigned> retained; |
| 96 | + for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained())) |
| 97 | + retained[ret] = i; |
| 98 | + |
| 99 | + // There must not be any duplicates in the retain list anymore because we |
| 100 | + // would miss updating one of the result values otherwise. |
| 101 | + if (retained.size() != deallocOp.getRetained().size()) |
| 102 | + return failure(); |
| 103 | + |
| 104 | + SmallVector<Value> newMemrefs, newConditions; |
| 105 | + for (auto memrefAndCond : |
| 106 | + llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { |
| 107 | + Value memref = std::get<0>(memrefAndCond); |
| 108 | + Value cond = std::get<1>(memrefAndCond); |
| 109 | + |
| 110 | + auto replaceResultsIfNoInvalidAliasing = [&](Value memref) -> bool { |
| 111 | + Value retainedMemref = deallocOp.getRetained()[retained[memref]]; |
| 112 | + // The current memref must not have a may-alias relation to any retained |
| 113 | + // memref, and exactly one must-alias relation. |
| 114 | + // TODO: it is possible to extend this pattern to allow an arbitrary |
| 115 | + // number of must-alias relations as long as there is no may-alias. If |
| 116 | + // it's no-alias, then just proceed (only supported case as of now), if |
| 117 | + // it's must-alias, we also need to update the condition for that alias. |
| 118 | + if (llvm::all_of(deallocOp.getRetained(), [&](Value mr) { |
| 119 | + return aliasAnalysis.alias(mr, memref).isNo() || |
| 120 | + mr == retainedMemref; |
| 121 | + })) { |
| 122 | + rewriter.setInsertionPointAfter(deallocOp); |
| 123 | + auto orOp = rewriter.create<arith::OrIOp>( |
| 124 | + deallocOp.getLoc(), |
| 125 | + deallocOp.getUpdatedConditions()[retained[memref]], cond); |
| 126 | + rewriter.replaceAllUsesExcept( |
| 127 | + deallocOp.getUpdatedConditions()[retained[memref]], |
| 128 | + orOp.getResult(), orOp); |
| 129 | + return true; |
| 130 | + } |
| 131 | + return false; |
| 132 | + }; |
| 133 | + |
| 134 | + if (retained.contains(memref) && |
| 135 | + replaceResultsIfNoInvalidAliasing(memref)) |
| 136 | + continue; |
| 137 | + |
| 138 | + auto extractOp = memref.getDefiningOp<memref::ExtractStridedMetadataOp>(); |
| 139 | + if (extractOp && retained.contains(extractOp.getOperand()) && |
| 140 | + replaceResultsIfNoInvalidAliasing(extractOp.getOperand())) |
| 141 | + continue; |
| 142 | + |
| 143 | + newMemrefs.push_back(memref); |
| 144 | + newConditions.push_back(cond); |
| 145 | + } |
| 146 | + |
| 147 | + // Return failure if we don't change anything such that we don't run into an |
| 148 | + // infinite loop of pattern applications. |
| 149 | + return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, |
| 150 | + rewriter); |
| 151 | + } |
| 152 | + |
| 153 | +private: |
| 154 | + AliasAnalysis &aliasAnalysis; |
| 155 | +}; |
| 156 | + |
| 157 | +} // namespace |
| 158 | + |
| 159 | +//===----------------------------------------------------------------------===// |
| 160 | +// BufferDeallocationSimplificationPass |
| 161 | +//===----------------------------------------------------------------------===// |
| 162 | + |
| 163 | +namespace { |
| 164 | + |
| 165 | +/// The actual buffer deallocation pass that inserts and moves dealloc nodes |
| 166 | +/// into the right positions. Furthermore, it inserts additional clones if |
| 167 | +/// necessary. It uses the algorithm described at the top of the file. |
| 168 | +struct BufferDeallocationSimplificationPass |
| 169 | + : public bufferization::impl::BufferDeallocationSimplificationBase< |
| 170 | + BufferDeallocationSimplificationPass> { |
| 171 | + void runOnOperation() override { |
| 172 | + AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>(); |
| 173 | + RewritePatternSet patterns(&getContext()); |
| 174 | + patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained>(&getContext(), |
| 175 | + aliasAnalysis); |
| 176 | + |
| 177 | + if (failed( |
| 178 | + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) |
| 179 | + signalPassFailure(); |
| 180 | + } |
| 181 | +}; |
| 182 | + |
| 183 | +} // namespace |
| 184 | + |
/// Creates a new instance of `BufferDeallocationSimplificationPass` and
/// returns it as an owning pointer to the generic `Pass` interface.
std::unique_ptr<Pass>
mlir::bufferization::createBufferDeallocationSimplificationPass() {
  return std::make_unique<BufferDeallocationSimplificationPass>();
}
0 commit comments