13
13
#include " mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
14
14
#include " mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
- #include " mlir/Transforms/DialectConversion.h"
16
+ #include " mlir/IR/PatternMatch.h"
17
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
17
18
18
19
namespace mlir {
19
20
namespace arith {
@@ -29,6 +30,9 @@ using namespace mlir::dataflow;
29
30
// / Succeeds when a value is statically non-negative in that it has a lower
30
31
// / bound on its value (if it is treated as signed) and that bound is
31
32
// / non-negative.
33
+ // TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern
34
+ // relies on this. These transformations may not be valid for 32bit index,
35
+ // need more investigation.
32
36
static LogicalResult staticallyNonNegative (DataFlowSolver &solver, Value v) {
33
37
auto *result = solver.lookupState <IntegerValueRangeLattice>(v);
34
38
if (!result || result->getValue ().isUninitialized ())
@@ -85,35 +89,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
85
89
}
86
90
87
91
namespace {
92
+ class DataFlowListener : public RewriterBase ::Listener {
93
+ public:
94
+ DataFlowListener (DataFlowSolver &s) : s(s) {}
95
+
96
+ protected:
97
+ void notifyOperationErased (Operation *op) override {
98
+ s.eraseState (s.getProgramPointAfter (op));
99
+ for (Value res : op->getResults ())
100
+ s.eraseState (res);
101
+ }
102
+
103
+ DataFlowSolver &s;
104
+ };
105
+
88
106
template <typename Signed, typename Unsigned>
89
- struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
90
- using OpConversionPattern<Signed>::OpConversionPattern;
107
+ struct ConvertOpToUnsigned final : OpRewritePattern<Signed> {
108
+ ConvertOpToUnsigned (MLIRContext *context, DataFlowSolver &s)
109
+ : OpRewritePattern<Signed>(context), solver(s) {}
91
110
92
- LogicalResult matchAndRewrite (Signed op, typename Signed::Adaptor adaptor,
93
- ConversionPatternRewriter &rw) const override {
94
- rw.replaceOpWithNewOp <Unsigned>(op, op->getResultTypes (),
95
- adaptor.getOperands (), op->getAttrs ());
111
+ LogicalResult matchAndRewrite (Signed op, PatternRewriter &rw) const override {
112
+ if (failed (
113
+ staticallyNonNegative (this ->solver , static_cast <Operation *>(op))))
114
+ return failure ();
115
+
116
+ rw.replaceOpWithNewOp <Unsigned>(op, op->getResultTypes (), op->getOperands (),
117
+ op->getAttrs ());
96
118
return success ();
97
119
}
120
+
121
+ private:
122
+ DataFlowSolver &solver;
98
123
};
99
124
100
- struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
101
- using OpConversionPattern<CmpIOp>::OpConversionPattern;
125
+ struct ConvertCmpIToUnsigned final : OpRewritePattern<CmpIOp> {
126
+ ConvertCmpIToUnsigned (MLIRContext *context, DataFlowSolver &s)
127
+ : OpRewritePattern<CmpIOp>(context), solver(s) {}
128
+
129
+ LogicalResult matchAndRewrite (CmpIOp op, PatternRewriter &rw) const override {
130
+ if (failed (isCmpIConvertable (this ->solver , op)))
131
+ return failure ();
102
132
103
- LogicalResult matchAndRewrite (CmpIOp op, CmpIOpAdaptor adaptor,
104
- ConversionPatternRewriter &rw) const override {
105
133
rw.replaceOpWithNewOp <CmpIOp>(op, toUnsignedPred (op.getPredicate ()),
106
134
op.getLhs (), op.getRhs ());
107
135
return success ();
108
136
}
137
+
138
+ private:
139
+ DataFlowSolver &solver;
109
140
};
110
141
111
142
struct ArithUnsignedWhenEquivalentPass
112
143
: public arith::impl::ArithUnsignedWhenEquivalentBase<
113
144
ArithUnsignedWhenEquivalentPass> {
114
- // / Implementation structure: first find all equivalent ops and collect them,
115
- // / then perform all the rewrites in a second pass over the target op. This
116
- // / ensures that analysis results are not invalidated during rewriting.
145
+
117
146
void runOnOperation () override {
118
147
Operation *op = getOperation ();
119
148
MLIRContext *ctx = op->getContext ();
@@ -123,35 +152,32 @@ struct ArithUnsignedWhenEquivalentPass
123
152
if (failed (solver.initializeAndRun (op)))
124
153
return signalPassFailure ();
125
154
126
- ConversionTarget target (*ctx);
127
- target.addLegalDialect <ArithDialect>();
128
- target.addDynamicallyLegalOp <DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
129
- MinSIOp, MaxSIOp, ExtSIOp>(
130
- [&solver](Operation *op) -> std::optional<bool > {
131
- return failed (staticallyNonNegative (solver, op));
132
- });
133
- target.addDynamicallyLegalOp <CmpIOp>(
134
- [&solver](CmpIOp op) -> std::optional<bool > {
135
- return failed (isCmpIConvertable (solver, op));
136
- });
155
+ DataFlowListener listener (solver);
137
156
138
157
RewritePatternSet patterns (ctx);
139
- patterns.add <ConvertOpToUnsigned<DivSIOp, DivUIOp>,
140
- ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
141
- ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
142
- ConvertOpToUnsigned<RemSIOp, RemUIOp>,
143
- ConvertOpToUnsigned<MinSIOp, MinUIOp>,
144
- ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
145
- ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
146
- ctx);
147
-
148
- if (failed (applyPartialConversion (op, target, std::move (patterns)))) {
158
+ populateUnsignedWhenEquivalentPatterns (patterns, solver);
159
+
160
+ GreedyRewriteConfig config;
161
+ config.listener = &listener;
162
+
163
+ if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns), config)))
149
164
signalPassFailure ();
150
- }
151
165
}
152
166
};
153
167
} // end anonymous namespace
154
168
169
+ void mlir::arith::populateUnsignedWhenEquivalentPatterns (
170
+ RewritePatternSet &patterns, DataFlowSolver &solver) {
171
+ patterns.add <ConvertOpToUnsigned<DivSIOp, DivUIOp>,
172
+ ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
173
+ ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
174
+ ConvertOpToUnsigned<RemSIOp, RemUIOp>,
175
+ ConvertOpToUnsigned<MinSIOp, MinUIOp>,
176
+ ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
177
+ ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
178
+ patterns.getContext (), solver);
179
+ }
180
+
155
181
std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass () {
156
182
return std::make_unique<ArithUnsignedWhenEquivalentPass>();
157
183
}
0 commit comments