Skip to content

[mlir] transform.apply_patterns support more config options #88484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 17, 2024

Conversation

ftynse
Copy link
Member

@ftynse ftynse commented Apr 12, 2024

Greedy rewrite driver has options to control the number of rewrites applies. Expose those via the corresponding transform op.

@ftynse ftynse requested a review from wsmoses April 12, 2024 08:09
@llvmbot llvmbot added the mlir label Apr 12, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

Greedy rewrite driver has options to control the number of rewrites applies. Expose those via the corresponding transform op.


Full diff: https://github.com/llvm/llvm-project/pull/88484.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+4-1)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+7)
  • (modified) mlir/test/Dialect/Transform/test-pattern-application.mlir (+30)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 21c9595860d4c5..fbac1ffb621fd2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -331,7 +331,10 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
   }];
 
   let arguments = (ins
-    TransformHandleTypeInterface:$target, UnitAttr:$apply_cse);
+    TransformHandleTypeInterface:$target,
+    UnitAttr:$apply_cse,
+    DefaultValuedAttr<I64Attr, "static_cast<uint64_t>(-1)">:$max_iterations,
+    DefaultValuedAttr<I64Attr, "static_cast<uint64_t>(-1)">:$max_num_rewrites);
   let results = (outs);
   let regions = (region MaxSizedRegion<1>:$patterns);
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index dc19022219e5b2..53f958caa0bdb7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -396,6 +396,13 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
       static_cast<RewriterBase::Listener *>(rewriter.getListener());
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
 
+  config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
+                             ? GreedyRewriteConfig::kNoLimit
+                             : getMaxIterations();
+  config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
+                              ? GreedyRewriteConfig::kNoLimit
+                              : getMaxNumRewrites();
+
   // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
   // was requested, apply the greedy pattern rewrite only once. (The greedy
   // pattern rewrite driver already iterates to a fixpoint internally.)
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index fa8a555af92188..f36036a8855fd0 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -26,6 +26,36 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @limited_updates
+func.func @limited_updates() {
+  "test.container"() ({
+    // Only one is replaced.
+    // CHECK: "test.foo"() {replace_with_new_op = "test.foo"}
+    // CHECK: "test.foo"() {}
+    %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
+    %1 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // Pattern application will fail because of the upper limit, wrap in
+    // sequence to suppress the error message.
+    transform.sequence %arg0 : !transform.any_op failures(suppress) {
+    ^bb0(%arg1: !transform.any_op):
+      %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      transform.apply_patterns to %0 {
+        transform.apply_patterns.transform.test_patterns
+      }  {max_num_rewrites = 1} : !transform.any_op
+    }
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @replacement_op_not_found() {
   "test.container"() ({
     // expected-note @below {{[0] replaced op}}

Greedy rewrite driver has options to control the number of rewrites
applies. Expose those via the corresponding transform op.
@ftynse ftynse merged commit 37b26bf into llvm:main Apr 17, 2024
@ftynse ftynse deleted the td-patterns branch April 17, 2024 12:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants