Skip to content

[MLIR] Add folding constants canonicalization for mlir::index::AddOp. #111084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 22, 2024

Conversation

weiweichen
Copy link
Contributor

  • Add a simple canonicalization for mlir::index::AddOp.

@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2024

@llvm/pr-subscribers-mlir-index

Author: weiwei chen (weiweichen)

Changes
  • Add a simple canonicalization for mlir::index::AddOp.

Full diff: https://github.com/llvm/llvm-project/pull/111084.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Index/IR/IndexOps.td (+2)
  • (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+39)
  • (modified) mlir/test/Dialect/Index/index-canonicalize.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
     %c = index.add %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 42401dae217ce1..ace9b43014a665 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+/// Canonicalize
+/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+
+  auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
+    v = op.getLhs();
+    if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
+      v = op.getRhs();
+      if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
+        return false;
+    }
+    return true;
+  };
+
+  IntegerAttr c1, c2;
+  Value v1, v2;
+
+  if (!matchConstant(op, v1, c1))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto add = v1.getDefiningOp<mlir::index::AddOp>();
+  if (!add)
+    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
+
+  if (!matchConstant(add, v2, c2))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+                                                    c1.getInt() + c2.getInt());
+  auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
+
+  rewriter.replaceOp(op, newAdd);
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
 // SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..256e327e83ea9c 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,21 @@ func.func @add_overflow() -> (index, index) {
   return %2, %3 : index, index
 }
 
+// CHECK-LABEL: @add
+func.func @add_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 1
+  %1 = index.constant 2
+  %2 = index.add %arg, %0
+  %3 = index.add %1, %2
+  %4 = index.add %3, %1
+  %5 = index.add %4, %0
+
+  // CHECK-DAG: [[A:%.*]] = index.constant 6
+  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+  // CHECK: return [[V0]]
+  return %5 : index
+}
+
 // CHECK-LABEL: @sub
 func.func @sub() -> index {
   %0 = index.constant -2000000000

@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2024

@llvm/pr-subscribers-mlir

Author: weiwei chen (weiweichen)

Changes
  • Add a simple canonicalization for mlir::index::AddOp.

Full diff: https://github.com/llvm/llvm-project/pull/111084.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Index/IR/IndexOps.td (+2)
  • (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+39)
  • (modified) mlir/test/Dialect/Index/index-canonicalize.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
     %c = index.add %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 42401dae217ce1..ace9b43014a665 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+/// Canonicalize
+/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+
+  auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
+    v = op.getLhs();
+    if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
+      v = op.getRhs();
+      if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
+        return false;
+    }
+    return true;
+  };
+
+  IntegerAttr c1, c2;
+  Value v1, v2;
+
+  if (!matchConstant(op, v1, c1))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto add = v1.getDefiningOp<mlir::index::AddOp>();
+  if (!add)
+    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
+
+  if (!matchConstant(add, v2, c2))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+                                                    c1.getInt() + c2.getInt());
+  auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
+
+  rewriter.replaceOp(op, newAdd);
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
 // SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..256e327e83ea9c 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,21 @@ func.func @add_overflow() -> (index, index) {
   return %2, %3 : index, index
 }
 
+// CHECK-LABEL: @add
+func.func @add_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 1
+  %1 = index.constant 2
+  %2 = index.add %arg, %0
+  %3 = index.add %1, %2
+  %4 = index.add %3, %1
+  %5 = index.add %4, %0
+
+  // CHECK-DAG: [[A:%.*]] = index.constant 6
+  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+  // CHECK: return [[V0]]
+  return %5 : index
+}
+
 // CHECK-LABEL: @sub
 func.func @sub() -> index {
   %0 = index.constant -2000000000

@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {

return {};
}
/// Canonicalize
/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only permutation of the pattern that you actually need to implement. The canonicalizer will make sure that constants for commutative ops are always on the RHS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, good to know, PR updated!

@weiweichen
Copy link
Contributor Author

@Mogball another 👀 please 🙏 ?

@weiweichen weiweichen merged commit 7191ced into llvm:main Oct 22, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants