|
10 | 10 | //
|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 |
|
| 13 | +#include "mlir/Dialect/Affine/Passes.h" |
| 14 | + |
13 | 15 | #include "mlir/Dialect/Affine/IR/AffineOps.h"
|
14 | 16 | #include "mlir/Dialect/Affine/Transforms/Transforms.h"
|
15 | 17 | #include "mlir/IR/PatternMatch.h"
|
| 18 | +#include "mlir/Interfaces/FunctionInterfaces.h" |
16 | 19 | #include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
| 20 | +#include "mlir/Pass/Pass.h" |
17 | 21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
18 | 22 | #include "llvm/ADT/IntEqClasses.h"
|
19 | 23 | #include "llvm/Support/Debug.h"
|
| 24 | +#include "llvm/Support/InterleavedRange.h" |
20 | 25 |
|
21 | 26 | #define DEBUG_TYPE "affine-min-max"
|
22 | 27 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
|
@@ -44,6 +49,12 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
|
44 | 49 | [&](unsigned i) {
|
45 | 50 | return Variable(affineMap.getSliceMap(i, 1), operands);
|
46 | 51 | });
|
| 52 | + LLVM_DEBUG({ |
| 53 | + DBGS() << "- constructed variables are: " |
| 54 | + << llvm::interleaved_array(llvm::map_range( |
| 55 | + variables, [](const Variable &v) { return v.getMap(); })) |
| 56 | + << "`\n"; |
| 57 | + }); |
47 | 58 |
|
48 | 59 | // Get the comparison operation.
|
49 | 60 | ComparisonOperator cmpOp =
|
@@ -125,8 +136,17 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
|
125 | 136 | for (auto [k, bound] : bounds)
|
126 | 137 | results.push_back(bound->getMap().getResult(0));
|
127 | 138 |
|
128 |
| - affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(), |
129 |
| - results, rewriter.getContext()); |
| 139 | + LLVM_DEBUG({ |
| 140 | + DBGS() << "- starting from map: " << affineMap << "\n"; |
| 141 | + DBGS() << "- creating new map with: \n"; |
| 142 | + DBGS() << "--- dims: " << affineMap.getNumDims() << "\n"; |
| 143 | + DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n"; |
| 144 | + DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n"; |
| 145 | + }); |
| 146 | + |
| 147 | + affineMap = |
| 148 | + AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(), |
| 149 | + results, rewriter.getContext()); |
130 | 150 |
|
131 | 151 | // Update the affine op.
|
132 | 152 | rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
|
@@ -172,3 +192,73 @@ LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
|
172 | 192 | *modified = changed;
|
173 | 193 | return success();
|
174 | 194 | }
|
| 195 | + |
| 196 | +namespace { |
| 197 | + |
| 198 | +struct SimplifyAffineMaxOp : public OpRewritePattern<AffineMaxOp> { |
| 199 | + using OpRewritePattern<AffineMaxOp>::OpRewritePattern; |
| 200 | + |
| 201 | + LogicalResult matchAndRewrite(AffineMaxOp affineOp, |
| 202 | + PatternRewriter &rewriter) const override { |
| 203 | + return success(simplifyAffineMaxOp(rewriter, affineOp)); |
| 204 | + } |
| 205 | +}; |
| 206 | + |
| 207 | +struct SimplifyAffineMinOp : public OpRewritePattern<AffineMinOp> { |
| 208 | + using OpRewritePattern<AffineMinOp>::OpRewritePattern; |
| 209 | + |
| 210 | + LogicalResult matchAndRewrite(AffineMinOp affineOp, |
| 211 | + PatternRewriter &rewriter) const override { |
| 212 | + return success(simplifyAffineMinOp(rewriter, affineOp)); |
| 213 | + } |
| 214 | +}; |
| 215 | + |
| 216 | +struct SimplifyAffineApplyOp : public OpRewritePattern<AffineApplyOp> { |
| 217 | + using OpRewritePattern<AffineApplyOp>::OpRewritePattern; |
| 218 | + |
| 219 | + LogicalResult matchAndRewrite(AffineApplyOp affineOp, |
| 220 | + PatternRewriter &rewriter) const override { |
| 221 | + AffineMap map = affineOp.getAffineMap(); |
| 222 | + SmallVector<Value> operands{affineOp->getOperands().begin(), |
| 223 | + affineOp->getOperands().end()}; |
| 224 | + fullyComposeAffineMapAndOperands(&map, &operands, |
| 225 | + /*composeAffineMin=*/true); |
| 226 | + |
| 227 | + // No change => failure to apply. |
| 228 | + if (map == affineOp.getAffineMap()) |
| 229 | + return failure(); |
| 230 | + |
| 231 | + rewriter.modifyOpInPlace(affineOp, [&]() { |
| 232 | + affineOp.setMap(map); |
| 233 | + affineOp->setOperands(operands); |
| 234 | + }); |
| 235 | + return success(); |
| 236 | + } |
| 237 | +}; |
| 238 | + |
| 239 | +} // namespace |
| 240 | + |
| 241 | +namespace mlir { |
| 242 | +namespace affine { |
| 243 | +#define GEN_PASS_DEF_SIMPLIFYAFFINEMINMAX |
| 244 | +#include "mlir/Dialect/Affine/Passes.h.inc" |
| 245 | +} // namespace affine |
| 246 | +} // namespace mlir |
| 247 | + |
| 248 | +/// Creates a simplification pass for affine min/max/apply. |
| 249 | +struct SimplifyAffineMinMaxPass |
| 250 | + : public affine::impl::SimplifyAffineMinMaxBase<SimplifyAffineMinMaxPass> { |
| 251 | + void runOnOperation() override; |
| 252 | +}; |
| 253 | + |
| 254 | +void SimplifyAffineMinMaxPass::runOnOperation() { |
| 255 | + FunctionOpInterface func = getOperation(); |
| 256 | + RewritePatternSet patterns(func.getContext()); |
| 257 | + AffineMaxOp::getCanonicalizationPatterns(patterns, func.getContext()); |
| 258 | + AffineMinOp::getCanonicalizationPatterns(patterns, func.getContext()); |
| 259 | + patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>( |
| 260 | + func.getContext()); |
| 261 | + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
| 262 | + if (failed(applyPatternsGreedily(func, std::move(frozenPatterns)))) |
| 263 | + return signalPassFailure(); |
| 264 | +} |
0 commit comments