//===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::bufferization;

LogicalResult mlir::bufferization::insertTensorCopies(
    Operation *op, const OneShotBufferizationOptions &options) {
  OneShotAnalysisState state(op, options);
  // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
  // analysis, depending on whether function boundary bufferization is enabled.
  if (options.bufferizeFunctionBoundaries) {
    if (failed(analyzeModuleOp(cast<ModuleOp>(op), state)))
      return failure();
  } else {
    if (failed(analyzeOp(op, state)))
      return failure();
  }

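  // Stop after the analysis when only testing was requested; leave the IR
  // unchanged.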
  if (options.testAnalysisOnly)
    return success();

  return insertTensorCopies(op, state);
}

LogicalResult
mlir::bufferization::insertTensorCopies(Operation *op,
                                        const AnalysisState &state) {
  OpBuilder builder(op->getContext());
  WalkResult result = op->walk([&](Operation *op) {
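    // Ops that do not implement BufferizableOpInterface are not analyzed;
    // there is nothing to resolve for them.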
    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return WalkResult::skip();

    // Find AllocTensorOps without an `escape` attribute and add the attribute
    // based on analysis results.
    if (auto allocTensorOp = dyn_cast<AllocTensorOp>(op)) {
      if (allocTensorOp.escape())
        return WalkResult::advance();
      bool escape = state.isTensorYielded(allocTensorOp.result());
      allocTensorOp.escapeAttr(builder.getBoolAttr(escape));
      return WalkResult::advance();
    }

    // Find out-of-place tensor OpOperands and resolve them with an explicit
    // tensor copy in the form of an AllocTensorOp.
    builder.setInsertionPoint(op);
    for (OpOperand &opOperand : op->getOpOperands()) {
      if (opOperand.get().getType().isa<UnrankedTensorType>()) {
        op->emitError("copies of unranked tensors are not supported");
        return WalkResult::interrupt();
      }
      auto tensorType = opOperand.get().getType().dyn_cast<RankedTensorType>();
      if (!tensorType)
        continue;
      if (state.isInPlace(opOperand))
        continue;
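      // The operand is bufferized out-of-place, so a copy must be inserted.
      // The copy escapes if any op result aliasing with this operand is
      // yielded from a block.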
      SmallVector<OpResult> aliasingOpResults =
          state.getAliasingOpResult(opOperand);
      bool escape = llvm::any_of(
          aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
      Value copy = builder.create<AllocTensorOp>(
          op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
      opOperand.set(copy);
    }

    return WalkResult::advance();
  });

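  // The walk is interrupted only if an unsupported (unranked) tensor copy was
  // requested.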
  return failure(result.wasInterrupted());
}

namespace {
struct TensorCopyInsertionPass
    : TensorCopyInsertionBase<TensorCopyInsertionPass> {
  TensorCopyInsertionPass()
      : TensorCopyInsertionBase<TensorCopyInsertionPass>(),
        options(llvm::None) {}
  TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
      : TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}

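  // This pass creates AllocTensorOps, so the bufferization dialect must be
  // declared as a dependency.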
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<bufferization::BufferizationDialect>();
  }

  void runOnOperation() override {
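    // Use the options given to the constructor if available; otherwise, build
    // the options from the pass flags.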
    if (options.hasValue()) {
      if (failed(insertTensorCopies(getOperation(), *options)))
        signalPassFailure();
    } else {
      OneShotBufferizationOptions options;
      options.allowReturnAllocs = allowReturnAllocs;
      options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      if (failed(insertTensorCopies(getOperation(), options)))
        signalPassFailure();
    }
  }

private:
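  // Bufferization options provided at pass construction time. If not set, the
  // options are constructed from the pass flags in `runOnOperation`.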
  Optional<OneShotBufferizationOptions> options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() {
  return std::make_unique<TensorCopyInsertionPass>();
}

std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<TensorCopyInsertionPass>(options);
}