Skip to content

Commit a1d5b9d

Browse files
[mlir][affine] Wrap SimplifyAffineMinMax in a pass (#145741)
This revision adds a pass working on FunctionOpInterface to connect recently introduced AffineMin/Max simplification patterns. Additionally fixes some minor issues that have surfaced upon larger scale testing.
1 parent 90f3147 commit a1d5b9d

File tree

5 files changed

+181
-3
lines changed

5 files changed

+181
-3
lines changed

mlir/include/mlir/Dialect/Affine/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,23 @@ def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"
414414
let constructor = "mlir::affine::createSimplifyAffineStructuresPass()";
415415
}
416416

417+
def SimplifyAffineMinMax : InterfacePass<"affine-simplify-min-max", "FunctionOpInterface"> {
418+
let summary = "Simplify affine min/max/apply";
419+
let description = [{
420+
Apply the SimplifyAffineMaxOp, SimplifyAffineMinOp and SimplifyAffineApplyOp
421+
patterns in addition to AffineMin/Max canonicalization patterns until a
422+
fixed point is reached.
423+
These patterns apply ValueBoundsOp interface on AffineMin/Max ops and
424+
additional simplifications such as:
425+
```
426+
min(x, y, cst) / cst -> 1
427+
```
428+
when x, y, cst are all >= 0.
429+
This is typically useful to extract more static informationfrom IR after
430+
tiling but can also come at a cost due to Presburger-style analysis.
431+
}];
432+
}
433+
417434
def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
418435
let summary = "Lower affine operations operating on indices into more fundamental operations";
419436
let constructor = "mlir::affine::createAffineExpandIndexOpsPass()";

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ using llvm::divideFloorSigned;
4242
using llvm::mod;
4343

4444
#define DEBUG_TYPE "affine-ops"
45+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
4546

4647
#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
4748

@@ -1065,6 +1066,10 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
10651066
ValueRange syms) {
10661067
AffineMap affineMinMap = minOp.getAffineMap();
10671068

1069+
LLVM_DEBUG({
1070+
DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
1071+
});
1072+
10681073
// Check the value is positive.
10691074
for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
10701075
// Compare each expression in the minimum against 0.
@@ -1263,6 +1268,12 @@ void mlir::affine::fullyComposeAffineMapAndOperands(
12631268
})) {
12641269
composeAffineMapAndOperands(map, operands, composeAffineMin);
12651270
}
1271+
// Additional trailing step for AffineMinOps in case no chains of AffineApply.
1272+
if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
1273+
return isa_and_nonnull<AffineMinOp>(v.getDefiningOp());
1274+
})) {
1275+
composeAffineMapAndOperands(map, operands, composeAffineMin);
1276+
}
12661277
}
12671278

12681279
AffineApplyOp

mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,18 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "mlir/Dialect/Affine/Passes.h"
14+
1315
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1416
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
1517
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Interfaces/FunctionInterfaces.h"
1619
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
20+
#include "mlir/Pass/Pass.h"
1721
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1822
#include "llvm/ADT/IntEqClasses.h"
1923
#include "llvm/Support/Debug.h"
24+
#include "llvm/Support/InterleavedRange.h"
2025

2126
#define DEBUG_TYPE "affine-min-max"
2227
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -44,6 +49,12 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
4449
[&](unsigned i) {
4550
return Variable(affineMap.getSliceMap(i, 1), operands);
4651
});
52+
LLVM_DEBUG({
53+
DBGS() << "- constructed variables are: "
54+
<< llvm::interleaved_array(llvm::map_range(
55+
variables, [](const Variable &v) { return v.getMap(); }))
56+
<< "`\n";
57+
});
4758

4859
// Get the comparison operation.
4960
ComparisonOperator cmpOp =
@@ -125,8 +136,17 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
125136
for (auto [k, bound] : bounds)
126137
results.push_back(bound->getMap().getResult(0));
127138

128-
affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
129-
results, rewriter.getContext());
139+
LLVM_DEBUG({
140+
DBGS() << "- starting from map: " << affineMap << "\n";
141+
DBGS() << "- creating new map with: \n";
142+
DBGS() << "--- dims: " << affineMap.getNumDims() << "\n";
143+
DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n";
144+
DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n";
145+
});
146+
147+
affineMap =
148+
AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(),
149+
results, rewriter.getContext());
130150

131151
// Update the affine op.
132152
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
@@ -172,3 +192,73 @@ LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
172192
*modified = changed;
173193
return success();
174194
}
195+
196+
namespace {
197+
198+
struct SimplifyAffineMaxOp : public OpRewritePattern<AffineMaxOp> {
199+
using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
200+
201+
LogicalResult matchAndRewrite(AffineMaxOp affineOp,
202+
PatternRewriter &rewriter) const override {
203+
return success(simplifyAffineMaxOp(rewriter, affineOp));
204+
}
205+
};
206+
207+
struct SimplifyAffineMinOp : public OpRewritePattern<AffineMinOp> {
208+
using OpRewritePattern<AffineMinOp>::OpRewritePattern;
209+
210+
LogicalResult matchAndRewrite(AffineMinOp affineOp,
211+
PatternRewriter &rewriter) const override {
212+
return success(simplifyAffineMinOp(rewriter, affineOp));
213+
}
214+
};
215+
216+
struct SimplifyAffineApplyOp : public OpRewritePattern<AffineApplyOp> {
217+
using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
218+
219+
LogicalResult matchAndRewrite(AffineApplyOp affineOp,
220+
PatternRewriter &rewriter) const override {
221+
AffineMap map = affineOp.getAffineMap();
222+
SmallVector<Value> operands{affineOp->getOperands().begin(),
223+
affineOp->getOperands().end()};
224+
fullyComposeAffineMapAndOperands(&map, &operands,
225+
/*composeAffineMin=*/true);
226+
227+
// No change => failure to apply.
228+
if (map == affineOp.getAffineMap())
229+
return failure();
230+
231+
rewriter.modifyOpInPlace(affineOp, [&]() {
232+
affineOp.setMap(map);
233+
affineOp->setOperands(operands);
234+
});
235+
return success();
236+
}
237+
};
238+
239+
} // namespace
240+
241+
namespace mlir {
242+
namespace affine {
243+
#define GEN_PASS_DEF_SIMPLIFYAFFINEMINMAX
244+
#include "mlir/Dialect/Affine/Passes.h.inc"
245+
} // namespace affine
246+
} // namespace mlir
247+
248+
/// Creates a simplification pass for affine min/max/apply.
249+
struct SimplifyAffineMinMaxPass
250+
: public affine::impl::SimplifyAffineMinMaxBase<SimplifyAffineMinMaxPass> {
251+
void runOnOperation() override;
252+
};
253+
254+
void SimplifyAffineMinMaxPass::runOnOperation() {
255+
FunctionOpInterface func = getOperation();
256+
RewritePatternSet patterns(func.getContext());
257+
AffineMaxOp::getCanonicalizationPatterns(patterns, func.getContext());
258+
AffineMinOp::getCanonicalizationPatterns(patterns, func.getContext());
259+
patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>(
260+
func.getContext());
261+
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
262+
if (failed(applyPatternsGreedily(func, std::move(frozenPatterns))))
263+
return signalPassFailure();
264+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(affine-simplify-min-max))" %s | FileCheck %s
2+
3+
// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
4+
// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
5+
// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)>
6+
7+
// CHECK: @min_max_full_simplify
8+
func.func @min_max_full_simplify() -> (index, index) {
9+
%0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
10+
%1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
11+
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
12+
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
13+
// CHECK-NOT: affine.min
14+
// CHECK-NOT: affine.max
15+
// CHECK: return %[[V0]], %[[V1]]
16+
%r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
17+
%r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
18+
return %r0, %r1 : index, index
19+
}
20+
21+
// CHECK: @min_only_simplify
22+
func.func @min_only_simplify() -> (index, index) {
23+
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
24+
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
25+
// CHECK: affine.min #[[MAP_0]]()[%[[V0]]]
26+
// CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]]
27+
%0 = test.value_with_bounds {max = 512 : index, min = 0 : index}
28+
%1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
29+
%r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
30+
%r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
31+
return %r0, %r1 : index, index
32+
}
33+
34+
// CHECK: @max_only_simplify
35+
func.func @max_only_simplify() -> (index, index) {
36+
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
37+
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
38+
// CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
39+
// CHECK: affine.max #[[MAP_2]]()[%[[V1]]]
40+
%0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
41+
%1 = test.value_with_bounds {max = 512 : index, min = 0 : index}
42+
%r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
43+
%r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
44+
return %r0, %r1 : index, index
45+
}
46+
47+
// CHECK: @overlapping_constraints
48+
func.func @overlapping_constraints() -> (index, index) {
49+
%0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
50+
%1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
51+
%2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
52+
// CHECK: %[[V0:.*]] = test.value_with_bounds {max = 192 : index, min = 0 : index}
53+
// CHECK: %[[V1:.*]] = test.value_with_bounds {max = 384 : index, min = 128 : index}
54+
// CHECK: %[[V2:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
55+
// CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
56+
// CHECK: affine.max #[[MAP_1]]()[%[[V1]], %[[V2]]]
57+
%r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
58+
%r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
59+
return %r0, %r1 : index, index
60+
}

mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
1+
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
22

33
// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
44
// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>

0 commit comments

Comments
 (0)