-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Add folding constants canonicalization for mlir::index::AddOp. #111084
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Add folding constants canonicalization for mlir::index::AddOp. #111084
Conversation
@llvm/pr-subscribers-mlir-index Author: weiwei chen (weiweichen) Changes
Full diff: https://github.com/llvm/llvm-project/pull/111084.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
%c = index.add %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 42401dae217ce1..ace9b43014a665 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Canonicalize
+/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+
+ auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
+ v = op.getLhs();
+ if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
+ v = op.getRhs();
+ if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
+ return false;
+ }
+ return true;
+ };
+
+ IntegerAttr c1, c2;
+ Value v1, v2;
+
+ if (!matchConstant(op, v1, c1))
+ return rewriter.notifyMatchFailure(op.getLoc(),
+ "neither LHS nor RHS is constant");
+
+ auto add = v1.getDefiningOp<mlir::index::AddOp>();
+ if (!add)
+ return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
+
+ if (!matchConstant(add, v2, c2))
+ return rewriter.notifyMatchFailure(op.getLoc(),
+ "neither LHS nor RHS is constant");
+
+ auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+ c1.getInt() + c2.getInt());
+ auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
+
+ rewriter.replaceOp(op, newAdd);
+ return success();
+}
//===----------------------------------------------------------------------===//
// SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..256e327e83ea9c 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,21 @@ func.func @add_overflow() -> (index, index) {
return %2, %3 : index, index
}
+// CHECK-LABEL: @add
+func.func @add_fold_constants(%arg: index) -> (index) {
+ %0 = index.constant 1
+ %1 = index.constant 2
+ %2 = index.add %arg, %0
+ %3 = index.add %1, %2
+ %4 = index.add %3, %1
+ %5 = index.add %4, %0
+
+ // CHECK-DAG: [[A:%.*]] = index.constant 6
+ // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+ // CHECK: return [[V0]]
+ return %5 : index
+}
+
// CHECK-LABEL: @sub
func.func @sub() -> index {
%0 = index.constant -2000000000
|
@llvm/pr-subscribers-mlir Author: weiwei chen (weiweichen) Changes
Full diff: https://github.com/llvm/llvm-project/pull/111084.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
%c = index.add %a, %b
```
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 42401dae217ce1..ace9b43014a665 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Canonicalize
+/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+
+ auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
+ v = op.getLhs();
+ if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
+ v = op.getRhs();
+ if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
+ return false;
+ }
+ return true;
+ };
+
+ IntegerAttr c1, c2;
+ Value v1, v2;
+
+ if (!matchConstant(op, v1, c1))
+ return rewriter.notifyMatchFailure(op.getLoc(),
+ "neither LHS nor RHS is constant");
+
+ auto add = v1.getDefiningOp<mlir::index::AddOp>();
+ if (!add)
+ return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
+
+ if (!matchConstant(add, v2, c2))
+ return rewriter.notifyMatchFailure(op.getLoc(),
+ "neither LHS nor RHS is constant");
+
+ auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+ c1.getInt() + c2.getInt());
+ auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
+
+ rewriter.replaceOp(op, newAdd);
+ return success();
+}
//===----------------------------------------------------------------------===//
// SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..256e327e83ea9c 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,21 @@ func.func @add_overflow() -> (index, index) {
return %2, %3 : index, index
}
+// CHECK-LABEL: @add
+func.func @add_fold_constants(%arg: index) -> (index) {
+ %0 = index.constant 1
+ %1 = index.constant 2
+ %2 = index.add %arg, %0
+ %3 = index.add %1, %2
+ %4 = index.add %3, %1
+ %5 = index.add %4, %0
+
+ // CHECK-DAG: [[A:%.*]] = index.constant 6
+ // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+ // CHECK: return [[V0]]
+ return %5 : index
+}
+
// CHECK-LABEL: @sub
func.func @sub() -> index {
%0 = index.constant -2000000000
|
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) { | |||
|
|||
return {}; | |||
} | |||
/// Canonicalize | |||
/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the only permutation of the pattern that you actually need to implement. The canonicalizer will make sure that constants for commutative ops are always on the RHS.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh, good to know, PR updated!
…eic/canonicalize-index-add
@Mogball another 👀 please 🙏 ? |
mlir::index::AddOp
.