Skip to content

Commit 7191ced

Browse files
authored
[MLIR] Add folding constants canonicalization for mlir::index::AddOp. (#111084)
- [x] Add a simple canonicalization for `mlir::index::AddOp`.
1 parent 00b47b9 commit 7191ced

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

mlir/include/mlir/Dialect/Index/IR/IndexOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
5656
%c = index.add %a, %b
5757
```
5858
}];
59+
60+
let hasCanonicalizeMethod = 1;
5961
}
6062

6163
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,28 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
136136

137137
return {};
138138
}
139+
/// Canonicalize
140+
/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
141+
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
142+
IntegerAttr c1, c2;
143+
if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
144+
return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
145+
146+
auto add = op.getLhs().getDefiningOp<mlir::index::AddOp>();
147+
if (!add)
148+
return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
149+
150+
if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2)))
151+
return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
152+
153+
auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
154+
c1.getInt() + c2.getInt());
155+
auto newAdd =
156+
rewriter.create<mlir::index::AddOp>(op->getLoc(), add.getLhs(), c);
157+
158+
rewriter.replaceOp(op, newAdd);
159+
return success();
160+
}
139161

140162
//===----------------------------------------------------------------------===//
141163
// SubOp

mlir/test/Dialect/Index/index-canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,19 @@ func.func @add_overflow() -> (index, index) {
3232
return %2, %3 : index, index
3333
}
3434

35+
// CHECK-LABEL: @add
36+
func.func @add_fold_constants(%arg: index) -> (index) {
37+
%0 = index.constant 1
38+
%1 = index.constant 2
39+
%2 = index.add %arg, %0
40+
%3 = index.add %2, %1
41+
42+
// CHECK-DAG: [[C3:%.*]] = index.constant 3
43+
// CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[C3]]
44+
// CHECK: return [[V0]]
45+
return %3 : index
46+
}
47+
3548
// CHECK-LABEL: @sub
3649
func.func @sub() -> index {
3750
%0 = index.constant -2000000000

0 commit comments

Comments
 (0)