Skip to content

Commit 6902b39

Browse files
authored
[mlir] UnsignedWhenEquivalent: use greedy rewriter instead of dialect conversion (#112454)
`UnsignedWhenEquivalent` doesn't really need any dialect conversion features and switching it normal patterns makes it more composable with other patterns-based transformations (and probably faster).
1 parent 77ea619 commit 6902b39

File tree

3 files changed

+76
-46
lines changed

3 files changed

+76
-46
lines changed

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
7070
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
7171
DataFlowSolver &solver);
7272

73+
/// Replace signed ops with unsigned ones where they are proven equivalent.
74+
void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns,
75+
DataFlowSolver &solver);
76+
7377
/// Create a pass which do optimizations based on integer range analysis.
7478
std::unique_ptr<Pass> createIntRangeOptimizationsPass();
7579

mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
1414
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
16-
#include "mlir/Transforms/DialectConversion.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1718

1819
namespace mlir {
1920
namespace arith {
@@ -29,6 +30,9 @@ using namespace mlir::dataflow;
2930
/// Succeeds when a value is statically non-negative in that it has a lower
3031
/// bound on its value (if it is treated as signed) and that bound is
3132
/// non-negative.
33+
// TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern
34+
// relies on this. These transformations may not be valid for 32bit index,
35+
// need more investigation.
3236
static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
3337
auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
3438
if (!result || result->getValue().isUninitialized())
@@ -85,35 +89,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
8589
}
8690

8791
namespace {
92+
class DataFlowListener : public RewriterBase::Listener {
93+
public:
94+
DataFlowListener(DataFlowSolver &s) : s(s) {}
95+
96+
protected:
97+
void notifyOperationErased(Operation *op) override {
98+
s.eraseState(s.getProgramPointAfter(op));
99+
for (Value res : op->getResults())
100+
s.eraseState(res);
101+
}
102+
103+
DataFlowSolver &s;
104+
};
105+
88106
template <typename Signed, typename Unsigned>
89-
struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
90-
using OpConversionPattern<Signed>::OpConversionPattern;
107+
struct ConvertOpToUnsigned final : OpRewritePattern<Signed> {
108+
ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
109+
: OpRewritePattern<Signed>(context), solver(s) {}
91110

92-
LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
93-
ConversionPatternRewriter &rw) const override {
94-
rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
95-
adaptor.getOperands(), op->getAttrs());
111+
LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
112+
if (failed(
113+
staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
114+
return failure();
115+
116+
rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(),
117+
op->getAttrs());
96118
return success();
97119
}
120+
121+
private:
122+
DataFlowSolver &solver;
98123
};
99124

100-
struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
101-
using OpConversionPattern<CmpIOp>::OpConversionPattern;
125+
struct ConvertCmpIToUnsigned final : OpRewritePattern<CmpIOp> {
126+
ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
127+
: OpRewritePattern<CmpIOp>(context), solver(s) {}
128+
129+
LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
130+
if (failed(isCmpIConvertable(this->solver, op)))
131+
return failure();
102132

103-
LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
104-
ConversionPatternRewriter &rw) const override {
105133
rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
106134
op.getLhs(), op.getRhs());
107135
return success();
108136
}
137+
138+
private:
139+
DataFlowSolver &solver;
109140
};
110141

111142
struct ArithUnsignedWhenEquivalentPass
112143
: public arith::impl::ArithUnsignedWhenEquivalentBase<
113144
ArithUnsignedWhenEquivalentPass> {
114-
/// Implementation structure: first find all equivalent ops and collect them,
115-
/// then perform all the rewrites in a second pass over the target op. This
116-
/// ensures that analysis results are not invalidated during rewriting.
145+
117146
void runOnOperation() override {
118147
Operation *op = getOperation();
119148
MLIRContext *ctx = op->getContext();
@@ -123,35 +152,32 @@ struct ArithUnsignedWhenEquivalentPass
123152
if (failed(solver.initializeAndRun(op)))
124153
return signalPassFailure();
125154

126-
ConversionTarget target(*ctx);
127-
target.addLegalDialect<ArithDialect>();
128-
target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
129-
MinSIOp, MaxSIOp, ExtSIOp>(
130-
[&solver](Operation *op) -> std::optional<bool> {
131-
return failed(staticallyNonNegative(solver, op));
132-
});
133-
target.addDynamicallyLegalOp<CmpIOp>(
134-
[&solver](CmpIOp op) -> std::optional<bool> {
135-
return failed(isCmpIConvertable(solver, op));
136-
});
155+
DataFlowListener listener(solver);
137156

138157
RewritePatternSet patterns(ctx);
139-
patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
140-
ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
141-
ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
142-
ConvertOpToUnsigned<RemSIOp, RemUIOp>,
143-
ConvertOpToUnsigned<MinSIOp, MinUIOp>,
144-
ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
145-
ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
146-
ctx);
147-
148-
if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
158+
populateUnsignedWhenEquivalentPatterns(patterns, solver);
159+
160+
GreedyRewriteConfig config;
161+
config.listener = &listener;
162+
163+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
149164
signalPassFailure();
150-
}
151165
}
152166
};
153167
} // end anonymous namespace
154168

169+
void mlir::arith::populateUnsignedWhenEquivalentPatterns(
170+
RewritePatternSet &patterns, DataFlowSolver &solver) {
171+
patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
172+
ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
173+
ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
174+
ConvertOpToUnsigned<RemSIOp, RemUIOp>,
175+
ConvertOpToUnsigned<MinSIOp, MinUIOp>,
176+
ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
177+
ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
178+
patterns.getContext(), solver);
179+
}
180+
155181
std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() {
156182
return std::make_unique<ArithUnsignedWhenEquivalentPass>();
157183
}

mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// CHECK: arith.cmpi slt
1313
// CHECK: arith.cmpi sge
1414
// CHECK: arith.cmpi sgt
15-
func.func @not_with_maybe_overflow(%arg0 : i32) {
15+
func.func @not_with_maybe_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
1616
%ci32_smax = arith.constant 0x7fffffff : i32
1717
%c1 = arith.constant 1 : i32
1818
%c4 = arith.constant 4 : i32
@@ -29,7 +29,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
2929
%10 = arith.cmpi slt, %1, %c4 : i32
3030
%11 = arith.cmpi sge, %1, %c4 : i32
3131
%12 = arith.cmpi sgt, %1, %c4 : i32
32-
func.return
32+
func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
3333
}
3434

3535
// CHECK-LABEL: func @yes_with_no_overflow
@@ -44,7 +44,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
4444
// CHECK: arith.cmpi ult
4545
// CHECK: arith.cmpi uge
4646
// CHECK: arith.cmpi ugt
47-
func.func @yes_with_no_overflow(%arg0 : i32) {
47+
func.func @yes_with_no_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
4848
%ci32_almost_smax = arith.constant 0x7ffffffe : i32
4949
%c1 = arith.constant 1 : i32
5050
%c4 = arith.constant 4 : i32
@@ -61,7 +61,7 @@ func.func @yes_with_no_overflow(%arg0 : i32) {
6161
%10 = arith.cmpi slt, %1, %c4 : i32
6262
%11 = arith.cmpi sge, %1, %c4 : i32
6363
%12 = arith.cmpi sgt, %1, %c4 : i32
64-
func.return
64+
func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
6565
}
6666

6767
// CHECK-LABEL: func @preserves_structure
@@ -90,20 +90,20 @@ func.func @preserves_structure(%arg0 : memref<8xindex>) {
9090
func.func private @external() -> i8
9191

9292
// CHECK-LABEL: @dead_code
93-
func.func @dead_code() {
93+
func.func @dead_code() -> i8 {
9494
%0 = call @external() : () -> i8
9595
// CHECK: arith.floordivsi
9696
%1 = arith.floordivsi %0, %0 : i8
97-
return
97+
return %1 : i8
9898
}
9999

100100
// Make sure not crash.
101101
// CHECK-LABEL: @no_integer_or_index
102-
func.func @no_integer_or_index() {
102+
func.func @no_integer_or_index(%arg0: vector<1xi32>) -> vector<1xi1> {
103103
// CHECK: arith.cmpi
104104
%cst_0 = arith.constant dense<[0]> : vector<1xi32>
105-
%cmp = arith.cmpi slt, %cst_0, %cst_0 : vector<1xi32>
106-
return
105+
%cmp = arith.cmpi slt, %cst_0, %arg0 : vector<1xi32>
106+
return %cmp : vector<1xi1>
107107
}
108108

109109
// CHECK-LABEL: @gpu_func
@@ -113,4 +113,4 @@ func.func @gpu_func(%arg0: memref<2x32xf32>, %arg1: memref<2x32xf32>, %arg2: mem
113113
gpu.terminator
114114
}
115115
return %arg1 : memref<2x32xf32>
116-
}
116+
}

0 commit comments

Comments
 (0)