9 | 9 | #include "mlir/Dialect/SCF/IR/SCF.h"
10 | 10 | #include "mlir/Dialect/Arith/IR/Arith.h"
11 | 11 | #include "mlir/Dialect/Arith/Utils/Utils.h"
12 | | -#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 | 12 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
14 | 13 | #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 | 14 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
@@ -1082,139 +1081,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1082 | 1081 | }
1083 | 1082 | };
1084 | 1083 |
1085 | | -/// Canonicalize the iter_args of an scf::ForOp that involve a
1086 | | -/// `bufferization.to_tensor` and for which only the last loop iteration is
1087 | | -/// actually visible outside of the loop. The canonicalization looks for a
1088 | | -/// pattern such as:
1089 | | -/// ```
1090 | | -///    %t0 = ... : tensor_type
1091 | | -///    %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
1092 | | -///      ...
1093 | | -///      // %m is either buffer_cast(%bb00) or defined above the loop
1094 | | -///      %m... : memref_type
1095 | | -///      ... // uses of %m with potential inplace updates
1096 | | -///      %new_tensor = bufferization.to_tensor %m : memref_type
1097 | | -///      ...
1098 | | -///      scf.yield %new_tensor : tensor_type
1099 | | -///    }
1100 | | -/// ```
1101 | | -///
1102 | | -/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
1103 | | -/// `%m = buffer_cast %bb0` op that feeds into the yielded
1104 | | -/// `bufferization.to_tensor` op.
1105 | | -///
1106 | | -/// If no aliasing write to the memref `%m`, from which `%new_tensor`is loaded,
1107 | | -/// occurs between `bufferization.to_tensor and yield then the value %0
1108 | | -/// visible outside of the loop is the last `bufferization.to_tensor`
1109 | | -/// produced in the loop.
1110 | | -///
1111 | | -/// For now, we approximate the absence of aliasing by only supporting the case
1112 | | -/// when the bufferization.to_tensor is the operation immediately preceding
1113 | | -/// the yield.
1114 | | -//
1115 | | -/// The canonicalization rewrites the pattern as:
1116 | | -/// ```
1117 | | -///    // %m is either a buffer_cast or defined above
1118 | | -///    %m... : memref_type
1119 | | -///    scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
1120 | | -///      ... // uses of %m with potential inplace updates
1121 | | -///      scf.yield %bb0: tensor_type
1122 | | -///    }
1123 | | -///    %0 = bufferization.to_tensor %m : memref_type
1124 | | -/// ```
1125 | | -///
1126 | | -/// A later bbArg canonicalization will further rewrite as:
1127 | | -/// ```
1128 | | -///    // %m is either a buffer_cast or defined above
1129 | | -///    %m... : memref_type
1130 | | -///    scf.for ... { // no iter_args
1131 | | -///      ... // uses of %m with potential inplace updates
1132 | | -///    }
1133 | | -///    %0 = bufferization.to_tensor %m : memref_type
1134 | | -/// ```
1135 | | -struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
1136 | | -  using OpRewritePattern<ForOp>::OpRewritePattern;
1137 | | -
1138 | | -  LogicalResult matchAndRewrite(ForOp forOp,
1139 | | -                                PatternRewriter &rewriter) const override {
1140 | | -    assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
1141 | | -           "unexpected multiple blocks");
1142 | | -
1143 | | -    Location loc = forOp.getLoc();
1144 | | -    DenseMap<Value, Value> replacements;
1145 | | -    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
1146 | | -      unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
1147 | | -      auto yieldOp =
1148 | | -          cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
1149 | | -      Value yieldVal = yieldOp->getOperand(idx);
1150 | | -      auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
1151 | | -      bool isTensor = llvm::isa<TensorType>(bbArg.getType());
1152 | | -
1153 | | -      bufferization::ToMemrefOp tensorToMemref;
1154 | | -      // Either bbArg has no use or it has a single buffer_cast use.
1155 | | -      if (bbArg.hasOneUse())
1156 | | -        tensorToMemref =
1157 | | -            dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
1158 | | -      if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
1159 | | -        continue;
1160 | | -      // If tensorToMemref is present, it must feed into the `ToTensorOp`.
1161 | | -      if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
1162 | | -        continue;
1163 | | -      // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
1164 | | -      // must be before `ToTensorOp` in the block so that the lastWrite
1165 | | -      // property is not subject to additional side-effects.
1166 | | -      // For now, we only support the case when ToTensorOp appears
1167 | | -      // immediately before the terminator.
1168 | | -      if (tensorLoadOp->getNextNode() != yieldOp)
1169 | | -        continue;
1170 | | -
1171 | | -      // Clone the optional tensorToMemref before forOp.
1172 | | -      if (tensorToMemref) {
1173 | | -        rewriter.setInsertionPoint(forOp);
1174 | | -        rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
1175 | | -            tensorToMemref, tensorToMemref.getMemref().getType(),
1176 | | -            tensorToMemref.getTensor());
1177 | | -      }
1178 | | -
1179 | | -      // Clone the tensorLoad after forOp.
1180 | | -      rewriter.setInsertionPointAfter(forOp);
1181 | | -      Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
1182 | | -          loc, tensorLoadOp.getMemref());
1183 | | -      Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
1184 | | -      replacements.insert(std::make_pair(forOpResult, newTensorLoad));
1185 | | -
1186 | | -      // Make the terminator just yield the bbArg, the old tensorLoadOp + the
1187 | | -      // old bbArg (that is now directly yielded) will canonicalize away.
1188 | | -      rewriter.startRootUpdate(yieldOp);
1189 | | -      yieldOp.setOperand(idx, bbArg);
1190 | | -      rewriter.finalizeRootUpdate(yieldOp);
1191 | | -    }
1192 | | -    if (replacements.empty())
1193 | | -      return failure();
1194 | | -
1195 | | -    // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
1196 | | -    // replaces the whole op and erase it unconditionally. This is wrong for
1197 | | -    // `forOp` as it generally contains ops with side effects.
1198 | | -    // Instead, use `rewriter.replaceOpWithIf`.
1199 | | -    SmallVector<Value> newResults;
1200 | | -    newResults.reserve(forOp.getNumResults());
1201 | | -    for (Value v : forOp.getResults()) {
1202 | | -      auto it = replacements.find(v);
1203 | | -      newResults.push_back((it != replacements.end()) ? it->second : v);
1204 | | -    }
1205 | | -    unsigned idx = 0;
1206 | | -    rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
1207 | | -      return op.get() != newResults[idx++];
1208 | | -    });
1209 | | -    return success();
1210 | | -  }
1211 | | -};
1212 | 1084 | } // namespace
1213 | 1085 |
1214 | 1086 | void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1215 | 1087 |                                         MLIRContext *context) {
1216 | | -  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1217 | | -              LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
| 1088 | +  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
| 1089 | +      context);
1218 | 1090 | }
1219 | 1091 |
1220 | 1092 | std::optional<APInt> ForOp::getConstantStep() {
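With `LastTensorLoadCanonicalization` removed, `ForOp::getCanonicalizationPatterns` now registers only `ForOpIterArgsFolder`, `SimplifyTrivialLoops`, and `ForOpTensorCastFolder`. For context, here is a minimal sketch of how these op canonicalization patterns are typically collected and applied with the greedy rewrite driver outside the generic `-canonicalize` pass; the `canonicalizeForOps` helper is hypothetical, and it assumes an MLIR revision where `applyPatternsAndFoldGreedily` is the driver entry point:

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: collect the scf.for canonicalization patterns
// (after this change: ForOpIterArgsFolder, SimplifyTrivialLoops, and
// ForOpTensorCastFolder) and apply them greedily over `module`.
static LogicalResult canonicalizeForOps(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  return applyPatternsAndFoldGreedily(module.getOperation(),
                                      std::move(patterns));
}
```

The practical consequence of the diff is that the rewrite performed by the removed `LastTensorLoadCanonicalization` no longer runs as part of `scf.for` canonicalization; a pipeline that still wants an equivalent rewrite would have to add such a pattern to its pattern set explicitly.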