Skip to content

[mlir][transform] Support for multiple top-level transform ops #69615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

martin-luecke
Copy link
Contributor

This adds a flag to the TransformDialectInterpreter that relaxes the requirement for only a single top-level transform op.
This is useful for supporting transforms that take transform IR as payload.

This also aligns the function findTopLevelTransform here with its documentation:
In the presence of multiple top-level transform ops it now correctly returns the first of them after reporting the error instead of returning a nullptr.

@martin-luecke martin-luecke requested a review from ftynse October 19, 2023 16:50
@llvmbot llvmbot added the mlir label Oct 19, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 19, 2023

@llvm/pr-subscribers-mlir

Author: None (martin-luecke)

Changes

This adds a flag to the TransformDialectInterpreter that relaxes the requirement for only a single top-level transform op.
This is useful for supporting transforms that take transform IR as payload.

This also aligns the function findTopLevelTransform here with its documentation:
In the presence of multiple top-level transform ops it now correctly returns the first of them after reporting the error instead of returning a nullptr.


Full diff: https://github.com/llvm/llvm-project/pull/69615.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+13)
  • (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (+13-11)
  • (added) mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir (+26)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp (+6)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 7b37245fc3d117b..60eb48a764eb4b2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -95,11 +95,24 @@ class TransformOptions {
     return *this;
   }
 
+  // Ensures that only a single top-level transform op is present in the IR.
+  TransformOptions &
+  enableEnforceSingleToplevelTransformOp(bool enable = true) {
+    enforceSingleToplevelTransformOp = enable;
+    return *this;
+  }
+
   /// Returns true if the expensive checks are requested.
   bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
 
+  // Returns true if enforcing a single top-level transform op is requested.
+  bool getEnforceSingleToplevelTransformOp() const {
+    return enforceSingleToplevelTransformOp;
+  }
+
 private:
   bool expensiveChecksEnabled = true;
+  bool enforceSingleToplevelTransformOp = true;
 };
 
 /// Entry point to the Transform dialect infrastructure. Applies the
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 538c81fe39fddb2..741456e7ebbfb86 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -56,10 +56,11 @@ constexpr static llvm::StringLiteral
 /// Reports an error if there is more than one such operation and returns the
 /// first one found. Reports an error returns nullptr if no such operation
 /// found.
-static Operation *findTopLevelTransform(Operation *root,
-                                        StringRef filenameOption) {
+static Operation *
+findTopLevelTransform(Operation *root, StringRef filenameOption,
+                      mlir::transform::TransformOptions options) {
   ::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
-  WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
+  root->walk<WalkOrder::PreOrder>(
       [&](::mlir::transform::TransformOpInterface transformOp) {
         if (!transformOp
                  ->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
@@ -68,14 +69,15 @@ static Operation *findTopLevelTransform(Operation *root,
           topLevelTransform = transformOp;
           return WalkResult::skip();
         }
-        auto diag = transformOp.emitError()
-                    << "more than one top-level transform op";
-        diag.attachNote(topLevelTransform.getLoc())
-            << "previous top-level transform op";
-        return WalkResult::interrupt();
+        if (options.getEnforceSingleToplevelTransformOp()) {
+          auto diag = transformOp.emitError()
+                      << "more than one top-level transform op";
+          diag.attachNote(topLevelTransform.getLoc())
+              << "previous top-level transform op";
+          return WalkResult::interrupt();
+        }
+        return WalkResult::skip();
       });
-  if (walkResult.wasInterrupted())
-    return nullptr;
   if (!topLevelTransform) {
     auto diag = root->emitError()
                 << "could not find a nested top-level transform op";
@@ -310,7 +312,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   Operation *transformRoot =
       debugTransformRootTag.empty()
           ? findTopLevelTransform(transformContainer,
-                                  transformFileName.getArgStr())
+                                  transformFileName.getArgStr(), options)
           : findOpWithTag(transformContainer, kTransformDialectTagAttrName,
                           debugTransformRootTag);
   if (!transformRoot)
diff --git a/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
new file mode 100644
index 000000000000000..db7fecdf753e984
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='enforce-single-top-level-transform-op=0' -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %match = transform.structured.match ops{["transform.get_parent_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %match, "found get_parent_op" : !transform.any_op
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %op = transform.structured.match ops{[]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-remark @below{{found get_parent_op}}
+  %1 = transform.get_parent_op %op : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index c60b21c918338b4..756b7f669b0c5bf 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -158,6 +158,8 @@ class TestTransformDialectInterpreterPass
     }
 
     options = options.enableExpensiveChecks(enableExpensiveChecks);
+    options = options.enableEnforceSingleToplevelTransformOp(
+        enforceSingleToplevelTransformOp);
     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
             getOperation(), getArgument(), getSharedTransformModule(),
             getTransformLibraryModule(), extraMapping, options,
@@ -170,6 +172,10 @@ class TestTransformDialectInterpreterPass
       *this, "enable-expensive-checks", llvm::cl::init(false),
       llvm::cl::desc("perform expensive checks to better report errors in the "
                      "transform IR")};
+  Option<bool> enforceSingleToplevelTransformOp{
+      *this, "enforce-single-top-level-transform-op", llvm::cl::init(true),
+      llvm::cl::desc("Ensure that only a single top-level transform op is "
+                     "present in the IR.")};
 
   Option<std::string> bindFirstExtraToOps{
       *this, "bind-first-extra-to-ops",

@github-actions
Copy link

github-actions bot commented Oct 19, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's okay to have, but we are generally moving away from the model where a top-level op is executed and towards a model where there is a named entry point.

@martin-luecke martin-luecke force-pushed the transform_interpreter_multiple_top_level_ops branch from 7b0f4c9 to 23add3c Compare October 19, 2023 17:22
@martin-luecke martin-luecke merged commit 9ccf01f into llvm:main Oct 20, 2023
@martin-luecke martin-luecke deleted the transform_interpreter_multiple_top_level_ops branch October 20, 2023 09:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants