Skip to content

Commit 08da5b8

Browse files
committed
Support non-index types
1 parent 3883cc8 commit 08da5b8

File tree

4 files changed

+31
-75
lines changed

4 files changed

+31
-75
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,6 @@ def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
161161
compatible form. `scf.while` are left unchanged if uplifting is not
162162
possible.
163163
}];
164-
165-
let options = [
166-
Option<"indexBitWidth", "index-bitwidth", "unsigned",
167-
/*default=*/"64",
168-
"Bitwidth of index type.">,
169-
];
170164
}
171165

172166

mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
8080
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
8181

8282
/// Populate patterns to uplift `scf.while` ops to `scf.for`.
83-
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
84-
unsigned indexBitwidth);
83+
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
8584

8685
} // namespace scf
8786
} // namespace mlir

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 16 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,9 @@ namespace mlir {
2626

2727
using namespace mlir;
2828

29-
static bool checkIndexType(arith::CmpIOp op, unsigned indexBitWidth) {
30-
auto type = op.getLhs().getType();
31-
if (isa<mlir::IndexType>(type))
32-
return true;
33-
34-
if (type.isSignlessInteger(indexBitWidth))
35-
return true;
36-
37-
return false;
38-
}
39-
4029
namespace {
4130
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
42-
UpliftWhileOp(MLIRContext *context, unsigned indexBitWidth_)
43-
: OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
44-
}
31+
using OpRewritePattern::OpRewritePattern;
4532

4633
LogicalResult matchAndRewrite(scf::WhileOp loop,
4734
PatternRewriter &rewriter) const override {
@@ -71,11 +58,6 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
7158
diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
7259
});
7360

74-
if (!checkIndexType(cmp, indexBitWidth))
75-
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
76-
diag << "Expected index-like type: " << *cmp;
77-
});
78-
7961
BlockArgument iterVar;
8062
Value end;
8163
DominanceInfo dom;
@@ -140,17 +122,9 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
140122

141123
auto begin = loop.getInits()[argNumber];
142124

143-
auto loc = loop.getLoc();
144-
auto indexType = rewriter.getIndexType();
145-
auto toIndex = [&](Value val) -> Value {
146-
if (val.getType() != indexType)
147-
return rewriter.create<arith::IndexCastOp>(loc, indexType, val);
148-
149-
return val;
150-
};
151-
begin = toIndex(begin);
152-
end = toIndex(end);
153-
step = toIndex(step);
125+
assert(begin.getType().isIntOrIndex());
126+
assert(begin.getType() == end.getType());
127+
assert(begin.getType() == step.getType());
154128

155129
llvm::SmallVector<Value> mapping;
156130
mapping.reserve(loop.getInits().size());
@@ -161,6 +135,7 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
161135
mapping.emplace_back(init);
162136
}
163137

138+
auto loc = loop.getLoc();
164139
auto emptyBuidler = [](OpBuilder &, Location, Value, ValueRange) {};
165140
auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
166141
emptyBuidler);
@@ -170,21 +145,14 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
170145
OpBuilder::InsertionGuard g(rewriter);
171146
rewriter.setInsertionPointToStart(newBody);
172147
Value newIterVar = newBody->getArgument(0);
173-
if (newIterVar.getType() != iterVar.getType())
174-
newIterVar = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(),
175-
newIterVar);
176148

177149
mapping.clear();
178150
auto newArgs = newBody->getArguments();
179151
for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
180152
if (i < argNumber) {
181153
mapping.emplace_back(newArgs[i + 1]);
182154
} else if (i == argNumber) {
183-
Value arg = newArgs.front();
184-
if (arg.getType() != iterVar.getType())
185-
arg =
186-
rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), arg);
187-
mapping.emplace_back(arg);
155+
mapping.emplace_back(newArgs.front());
188156
} else {
189157
mapping.emplace_back(newArgs[i]);
190158
}
@@ -207,26 +175,27 @@ struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
207175
rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
208176

209177
rewriter.setInsertionPointAfter(newLoop);
210-
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
178+
Value one;
179+
if (isa<IndexType>(step.getType())) {
180+
one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
181+
} else {
182+
one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
183+
}
184+
211185
Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
212186
Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
213187
len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
214188
len = rewriter.create<arith::DivSIOp>(loc, len, step);
215189
len = rewriter.create<arith::SubIOp>(loc, len, one);
216190
Value res = rewriter.create<arith::MulIOp>(loc, len, step);
217191
res = rewriter.create<arith::AddIOp>(loc, begin, res);
218-
if (res.getType() != iterVar.getType())
219-
res = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), res);
220192

221193
mapping.clear();
222194
llvm::append_range(mapping, newLoop.getResults());
223195
mapping.insert(mapping.begin() + argNumber, res);
224196
rewriter.replaceOp(loop, mapping);
225197
return success();
226198
}
227-
228-
private:
229-
unsigned indexBitWidth = 0;
230199
};
231200

232201
struct SCFUpliftWhileToFor final
@@ -237,14 +206,13 @@ struct SCFUpliftWhileToFor final
237206
Operation *op = getOperation();
238207
MLIRContext *ctx = op->getContext();
239208
RewritePatternSet patterns(ctx);
240-
mlir::scf::populateUpliftWhileToForPatterns(patterns, this->indexBitWidth);
209+
mlir::scf::populateUpliftWhileToForPatterns(patterns);
241210
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
242211
signalPassFailure();
243212
}
244213
};
245214
} // namespace
246215

247-
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
248-
unsigned indexBitwidth) {
249-
patterns.add<UpliftWhileOp>(patterns.getContext(), indexBitwidth);
216+
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
217+
patterns.add<UpliftWhileOp>(patterns.getContext());
250218
}

mlir/test/Dialect/SCF/uplift-while.mlir

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for{index-bitwidth=64}))' -split-input-file -allow-unregistered-dialect | FileCheck %s
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
22

33
func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
44
%0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
@@ -141,22 +141,17 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
141141
}
142142

143143
// CHECK-LABEL: func @uplift_while
144-
// CHECK-SAME: (%[[BEGINI:.*]]: i64, %[[ENDI:.*]]: i64, %[[STEPI:.*]]: i64) -> i64
145-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
146-
// CHECK: %[[BEGIN:.*]] = arith.index_cast %[[BEGINI]] : i64 to index
147-
// CHECK: %[[END:.*]] = arith.index_cast %[[ENDI]] : i64 to index
148-
// CHECK: %[[STEP:.*]] = arith.index_cast %[[STEPI]] : i64 to index
149-
// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
150-
// CHECK: %[[II:.*]] = arith.index_cast %[[I]] : index to i64
151-
// CHECK: "test.test1"(%[[II]]) : (i64) -> ()
152-
// CHECK: %[[INC:.*]] = arith.addi %[[II]], %[[STEPI]] : i64
144+
// CHECK-SAME: (%[[BEGIN:.*]]: i64, %[[END:.*]]: i64, %[[STEP:.*]]: i64) -> i64
145+
// CHECK: %[[C1:.*]] = arith.constant 1 : i64
146+
// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] : i64 {
147+
// CHECK: "test.test1"(%[[I]]) : (i64) -> ()
148+
// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : i64
153149
// CHECK: "test.test2"(%[[INC]]) : (i64) -> ()
154-
// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
155-
// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
156-
// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
157-
// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
158-
// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
159-
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
160-
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
161-
// CHECK: %[[RES:.*]] = arith.index_cast %[[R7]] : index to i64
162-
// CHECK: return %[[RES]] : i64
150+
// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : i64
151+
// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : i64
152+
// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : i64
153+
// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : i64
154+
// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : i64
155+
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
156+
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
157+
// CHECK: return %[[R7]] : i64

0 commit comments

Comments
 (0)