@@ -26,22 +26,9 @@ namespace mlir {
26
26
27
27
using namespace mlir ;
28
28
29
- static bool checkIndexType (arith::CmpIOp op, unsigned indexBitWidth) {
30
- auto type = op.getLhs ().getType ();
31
- if (isa<mlir::IndexType>(type))
32
- return true ;
33
-
34
- if (type.isSignlessInteger (indexBitWidth))
35
- return true ;
36
-
37
- return false ;
38
- }
39
-
40
29
namespace {
41
30
struct UpliftWhileOp : public OpRewritePattern <scf::WhileOp> {
42
- UpliftWhileOp (MLIRContext *context, unsigned indexBitWidth_)
43
- : OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
44
- }
31
+ using OpRewritePattern::OpRewritePattern;
45
32
46
33
LogicalResult matchAndRewrite (scf::WhileOp loop,
47
34
PatternRewriter &rewriter) const override {
@@ -71,11 +58,6 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
71
58
diag << " Expected 'slt' or 'sgt' predicate: " << *cmp;
72
59
});
73
60
74
- if (!checkIndexType (cmp, indexBitWidth))
75
- return rewriter.notifyMatchFailure (loop, [&](Diagnostic &diag) {
76
- diag << " Expected index-like type: " << *cmp;
77
- });
78
-
79
61
BlockArgument iterVar;
80
62
Value end;
81
63
DominanceInfo dom;
@@ -140,17 +122,9 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
140
122
141
123
auto begin = loop.getInits ()[argNumber];
142
124
143
- auto loc = loop.getLoc ();
144
- auto indexType = rewriter.getIndexType ();
145
- auto toIndex = [&](Value val) -> Value {
146
- if (val.getType () != indexType)
147
- return rewriter.create <arith::IndexCastOp>(loc, indexType, val);
148
-
149
- return val;
150
- };
151
- begin = toIndex (begin);
152
- end = toIndex (end);
153
- step = toIndex (step);
125
+ assert (begin.getType ().isIntOrIndex ());
126
+ assert (begin.getType () == end.getType ());
127
+ assert (begin.getType () == step.getType ());
154
128
155
129
llvm::SmallVector<Value> mapping;
156
130
mapping.reserve (loop.getInits ().size ());
@@ -161,6 +135,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
161
135
mapping.emplace_back (init);
162
136
}
163
137
138
+ auto loc = loop.getLoc ();
164
139
auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
165
140
auto newLoop = rewriter.create <scf::ForOp>(loc, begin, end, step, mapping,
166
141
emptyBuidler);
@@ -170,21 +145,14 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
170
145
OpBuilder::InsertionGuard g (rewriter);
171
146
rewriter.setInsertionPointToStart (newBody);
172
147
Value newIterVar = newBody->getArgument (0 );
173
- if (newIterVar.getType () != iterVar.getType ())
174
- newIterVar = rewriter.create <arith::IndexCastOp>(loc, iterVar.getType (),
175
- newIterVar);
176
148
177
149
mapping.clear ();
178
150
auto newArgs = newBody->getArguments ();
179
151
for (auto i : llvm::seq<size_t >(0 , newArgs.size ())) {
180
152
if (i < argNumber) {
181
153
mapping.emplace_back (newArgs[i + 1 ]);
182
154
} else if (i == argNumber) {
183
- Value arg = newArgs.front ();
184
- if (arg.getType () != iterVar.getType ())
185
- arg =
186
- rewriter.create <arith::IndexCastOp>(loc, iterVar.getType (), arg);
187
- mapping.emplace_back (arg);
155
+ mapping.emplace_back (newArgs.front ());
188
156
} else {
189
157
mapping.emplace_back (newArgs[i]);
190
158
}
@@ -207,26 +175,27 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
207
175
rewriter.replaceOpWithNewOp <scf::YieldOp>(term, mapping);
208
176
209
177
rewriter.setInsertionPointAfter (newLoop);
210
- Value one = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
178
+ Value one;
179
+ if (isa<IndexType>(step.getType ())) {
180
+ one = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
181
+ } else {
182
+ one = rewriter.create <arith::ConstantIntOp>(loc, 1 , step.getType ());
183
+ }
184
+
211
185
Value stepDec = rewriter.create <arith::SubIOp>(loc, step, one);
212
186
Value len = rewriter.create <arith::SubIOp>(loc, end, begin);
213
187
len = rewriter.create <arith::AddIOp>(loc, len, stepDec);
214
188
len = rewriter.create <arith::DivSIOp>(loc, len, step);
215
189
len = rewriter.create <arith::SubIOp>(loc, len, one);
216
190
Value res = rewriter.create <arith::MulIOp>(loc, len, step);
217
191
res = rewriter.create <arith::AddIOp>(loc, begin, res);
218
- if (res.getType () != iterVar.getType ())
219
- res = rewriter.create <arith::IndexCastOp>(loc, iterVar.getType (), res);
220
192
221
193
mapping.clear ();
222
194
llvm::append_range (mapping, newLoop.getResults ());
223
195
mapping.insert (mapping.begin () + argNumber, res);
224
196
rewriter.replaceOp (loop, mapping);
225
197
return success ();
226
198
}
227
-
228
- private:
229
- unsigned indexBitWidth = 0 ;
230
199
};
231
200
232
201
struct SCFUpliftWhileToFor final
@@ -237,14 +206,13 @@ struct SCFUpliftWhileToFor final
237
206
Operation *op = getOperation ();
238
207
MLIRContext *ctx = op->getContext ();
239
208
RewritePatternSet patterns (ctx);
240
- mlir::scf::populateUpliftWhileToForPatterns (patterns, this -> indexBitWidth );
209
+ mlir::scf::populateUpliftWhileToForPatterns (patterns);
241
210
if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns))))
242
211
signalPassFailure ();
243
212
}
244
213
};
245
214
} // namespace
246
215
247
- void mlir::scf::populateUpliftWhileToForPatterns (RewritePatternSet &patterns,
248
- unsigned indexBitwidth) {
249
- patterns.add <UpliftWhileOp>(patterns.getContext (), indexBitwidth);
216
+ void mlir::scf::populateUpliftWhileToForPatterns (RewritePatternSet &patterns) {
217
+ patterns.add <UpliftWhileOp>(patterns.getContext ());
250
218
}
0 commit comments