Skip to content

Commit aaf353d

Browse files
committed
ArithIndexingBuilder
1 parent c3fe42b commit aaf353d

File tree

3 files changed

+18
-7
lines changed

3 files changed

+18
-7
lines changed

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ Type getType(OpFoldResult ofr);
101101
/// Helper struct to build simple arithmetic quantities with minimal type
102102
/// inference support.
103103
struct ArithBuilder {
104-
ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
104+
ArithBuilder(
105+
OpBuilder &b, Location loc,
106+
arith::IntegerOverflowFlags ovf = arith::IntegerOverflowFlags::none)
107+
: b(b), loc(loc), ovf(ovf) {}
105108

106109
Value _and(Value lhs, Value rhs);
107110
Value add(Value lhs, Value rhs);
@@ -114,6 +117,15 @@ struct ArithBuilder {
114117
private:
115118
OpBuilder &b;
116119
Location loc;
120+
arith::IntegerOverflowFlags ovf;
121+
};
122+
123+
/// ArithBuilder specialized specifically for tensor/memref indexing
124+
/// calculations. Those calculations generally should never signed overflow, so
125+
/// we can set oveflow flags accordingly.
126+
struct ArithIndexingBuilder : public ArithBuilder {
127+
ArithIndexingBuilder(OpBuilder &b, Location loc)
128+
: ArithBuilder(b, loc, arith::IntegerOverflowFlags::nsw) {}
117129
};
118130

119131
namespace arith {

mlir/lib/Dialect/Arith/Utils/Utils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,17 +315,17 @@ Value ArithBuilder::_and(Value lhs, Value rhs) {
315315
Value ArithBuilder::add(Value lhs, Value rhs) {
316316
if (isa<FloatType>(lhs.getType()))
317317
return b.create<arith::AddFOp>(loc, lhs, rhs);
318-
return b.create<arith::AddIOp>(loc, lhs, rhs);
318+
return b.create<arith::AddIOp>(loc, lhs, rhs, ovf);
319319
}
320320
Value ArithBuilder::sub(Value lhs, Value rhs) {
321321
if (isa<FloatType>(lhs.getType()))
322322
return b.create<arith::SubFOp>(loc, lhs, rhs);
323-
return b.create<arith::SubIOp>(loc, lhs, rhs);
323+
return b.create<arith::SubIOp>(loc, lhs, rhs, ovf);
324324
}
325325
Value ArithBuilder::mul(Value lhs, Value rhs) {
326326
if (isa<FloatType>(lhs.getType()))
327327
return b.create<arith::MulFOp>(loc, lhs, rhs);
328-
return b.create<arith::MulIOp>(loc, lhs, rhs);
328+
return b.create<arith::MulIOp>(loc, lhs, rhs, ovf);
329329
}
330330
Value ArithBuilder::sgt(Value lhs, Value rhs) {
331331
if (isa<FloatType>(lhs.getType()))

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,15 +1162,14 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
11621162
OpBuilder::InsertionGuard g(rewriter);
11631163
rewriter.setInsertionPoint(loadOp);
11641164
Location loc = loadOp.getLoc();
1165+
ArithIndexingBuilder idxBuilderf(rewriter, loc);
11651166
for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
11661167
OpFoldResult pos = extractPos[i - rankOffset];
11671168
if (isConstantIntValue(pos, 0))
11681169
continue;
11691170

11701171
Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
1171-
1172-
auto ovf = arith::IntegerOverflowFlags::nsw;
1173-
indices[i] = rewriter.create<arith::AddIOp>(loc, indices[i], offset, ovf);
1172+
indices[i] = idxBuilderf.add(indices[i], offset);
11741173
}
11751174

11761175
Value base = loadOp.getBase();

0 commit comments

Comments
 (0)