Skip to content

[MLIR][Transform] apply_registered_op fixes: arg order & python options auto-conversion #143779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025

Conversation

rolfmorel
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jun 11, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/143779.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+3-3)
  • (modified) mlir/python/mlir/dialects/transform/init.py (+11-7)
  • (modified) mlir/test/Dialect/Transform/test-pass-application.mlir (+9-10)
  • (modified) mlir/test/python/dialects/transform.py (+5-5)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index f75ba27e58e76..0aa750e625436 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -434,10 +434,10 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
     of targeted ops.
   }];
 
-  let arguments = (ins StrAttr:$pass_name,
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       StrAttr:$pass_name,
                        DefaultValuedAttr<DictionaryAttr, "{}">:$options,
-                       Variadic<TransformParamTypeInterface>:$dynamic_options,
-                       TransformHandleTypeInterface:$target);
+                       Variadic<TransformParamTypeInterface>:$dynamic_options);
   let results = (outs TransformHandleTypeInterface:$result);
   let assemblyFormat = [{
     $pass_name (`with` `options` `=`
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 10a04b0cc14e0..bfe96b1b3e5d4 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -224,13 +224,13 @@ class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
     def __init__(
         self,
         result: Type,
-        pass_name: Union[str, StringAttr],
         target: Union[Operation, Value, OpView],
+        pass_name: Union[str, StringAttr],
         *,
         options: Optional[
             Dict[
                 Union[str, StringAttr],
-                Union[Attribute, Value, Operation, OpView],
+                Union[Attribute, Value, Operation, OpView, str, int, bool],
             ]
         ] = None,
         loc=None,
@@ -253,17 +253,21 @@ def __init__(
                 cur_param_operand_idx += 1
             elif isinstance(value, Attribute):
                 options_dict[key] = value
+            # The following cases auto-convert Python values to attributes.
+            elif isinstance(value, bool):
+                options_dict[key] = BoolAttr.get(value)
+            elif isinstance(value, int):
+                default_int_type = IntegerType.get_signless(64, context)
+                options_dict[key] = IntegerAttr.get(default_int_type, value)
             elif isinstance(value, str):
                 options_dict[key] = StringAttr.get(value)
             else:
                 raise TypeError(f"Unsupported option type: {type(value)}")
-        if len(options_dict) > 0:
-            print(options_dict, cur_param_operand_idx)
         super().__init__(
             result,
+            _get_op_result_or_value(target),
             pass_name,
             dynamic_options,
-            target=_get_op_result_or_value(target),
             options=DictAttr.get(options_dict),
             loc=loc,
             ip=ip,
@@ -272,13 +276,13 @@ def __init__(
 
 def apply_registered_pass(
     result: Type,
-    pass_name: Union[str, StringAttr],
     target: Union[Operation, Value, OpView],
+    pass_name: Union[str, StringAttr],
     *,
     options: Optional[
         Dict[
             Union[str, StringAttr],
-            Union[Attribute, Value, Operation, OpView],
+            Union[Attribute, Value, Operation, OpView, str, int, bool],
         ]
     ] = None,
     loc=None,
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 6e6d4eb7e249f..1d1be9eda3496 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} {
                          "test-convergence" = true,
                          "max-num-rewrites" =  %max_rewrites }
         to %1
-        : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+        : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
     transform.yield
   }
 }
@@ -171,7 +171,6 @@ func.func @invalid_options_as_str() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
     // expected-error @+2 {{expected '{' in options dictionary}}
     %2 = transform.apply_registered_pass "canonicalize"
         with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
@@ -256,7 +255,7 @@ module attributes {transform.with_named_sequence} {
     // expected-error @+2 {{expected '{' in options dictionary}}
     transform.apply_registered_pass "canonicalize"
         with options = %pass_options to %1
-        : (!transform.any_param, !transform.any_op) -> !transform.any_op
+        : (!transform.any_op, !transform.any_param) -> !transform.any_op
     transform.yield
   }
 }
@@ -276,7 +275,7 @@ module attributes {transform.with_named_sequence} {
     // expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
     transform.apply_registered_pass "canonicalize"
         with options = { "top-down" = %topdown_options } to %1
-        : (!transform.any_param, !transform.any_op) -> !transform.any_op
+        : (!transform.any_op, !transform.any_param) -> !transform.any_op
     transform.yield
   }
 }
@@ -316,12 +315,12 @@ module attributes {transform.with_named_sequence} {
     %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
     %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
     // expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
-    %2 = "transform.apply_registered_pass"(%1, %0) <{
+    %2 = "transform.apply_registered_pass"(%0, %1) <{
       options = {"max-iterations" = #transform.param_operand<index=1 : i64>,
                  "test-convergence" = true,
                  "top-down" = false},
       pass_name = "canonicalize"}>
-    : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    : (!transform.any_op, !transform.any_param) -> !transform.any_op
     "transform.yield"() : () -> ()
   }) : () -> ()
 }) {transform.with_named_sequence} : () -> ()
@@ -340,13 +339,13 @@ module attributes {transform.with_named_sequence} {
     %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
     %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
     // expected-error @below {{dynamic option index 0 is already used in options}}
-    %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+    %3 = "transform.apply_registered_pass"(%0, %1, %2) <{
       options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
                  "max-num-rewrites" = #transform.param_operand<index=0 : i64>,
                  "test-convergence" = true,
                  "top-down" = false},
       pass_name = "canonicalize"}>
-    : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+    : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
     "transform.yield"() : () -> ()
   }) : () -> ()
 }) {transform.with_named_sequence} : () -> ()
@@ -364,12 +363,12 @@ module attributes {transform.with_named_sequence} {
     %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
     %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
     // expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}}
-    %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+    %3 = "transform.apply_registered_pass"(%0, %1, %2) <{
       options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
                  "test-convergence" = true,
                  "top-down" = false},
       pass_name = "canonicalize"}>
-    : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+    : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
     "transform.yield"() : () -> ()
   }) : () -> ()
 }) {transform.with_named_sequence} : () -> ()
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 48bc9bad37a1e..eeb95605d7a9a 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -263,12 +263,12 @@ def testApplyRegisteredPassOp(module: Module):
     )
     with InsertionPoint(sequence.body):
         mod = transform.ApplyRegisteredPassOp(
-            transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+            transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
         )
         mod = transform.ApplyRegisteredPassOp(
             transform.AnyOpType.get(),
-            "canonicalize",
             mod.result,
+            "canonicalize",
             options={"top-down": BoolAttr.get(False)},
         )
         max_iter = transform.param_constant(
@@ -281,12 +281,12 @@ def testApplyRegisteredPassOp(module: Module):
         )
         transform.apply_registered_pass(
             transform.AnyOpType.get(),
-            "canonicalize",
             mod,
+            "canonicalize",
             options={
                 "top-down": BoolAttr.get(False),
                 "max-iterations": max_iter,
-                "test-convergence": BoolAttr.get(True),
+                "test-convergence": True,
                 "max-rewrites": max_rewrites,
             },
         )
@@ -305,4 +305,4 @@ def testApplyRegisteredPassOp(module: Module):
     # CHECK-SAME:                    "max-rewrites" =  %[[MAX_REWRITE]],
     # CHECK-SAME:                    "test-convergence" = true,
     # CHECK-SAME:                    "top-down" = false}
-    # CHECK-SAME:    to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+    # CHECK-SAME:    to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op

@rolfmorel
Copy link
Contributor Author

Merging without review so that downstream users don't need to deal with op's arg order having been different from what it was before: #142683

@rolfmorel rolfmorel merged commit fb761aa into llvm:main Jun 11, 2025
8 of 10 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jun 12, 2025

LLVM Buildbot has detected a new failure on builder ppc64le-flang-rhel-clang running on ppc64le-flang-rhel-test while building mlir at step 6 "test-build-unified-tree-check-flang".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/30544

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-flang) failure: test (failure)
******************** TEST 'Flang :: Semantics/modfile75.F90' FAILED ********************
Exit Code: 2

Command Output (stderr):
--
/home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=1 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 && /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=2 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 && /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -fc1 -fdebug-unparse /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 | /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/FileCheck /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 # RUN: at line 1
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=1 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=2 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -fc1 -fdebug-unparse /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/FileCheck /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
error: Semantic errors in /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
/home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90:15:11: error: Must be a constant value
    integer(c_int) n
            ^^^^^
FileCheck error: '<stdin>' is empty.
FileCheck command line:  /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/FileCheck /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90

--

********************


@joker-eph
Copy link
Collaborator

joker-eph commented Jun 12, 2025

Please keep the habit of providing a comprehensive description of the PR. The title barely gives an idea of the "what" but we're missing the context on this commit.

Also, skipping review when something is broken (or you're addressing trivial post-commit comments, or ...) is always OK, but your mention of a downstream convenience does not seem like an urgency to land to me: downstream need to be able to handle their integration with local patches or delays in fixes if an issue does not reproduce or show upstream (but here without the description I can't really say about the seriousness of the issue).

rolfmorel added a commit to libxsmm/tpp-mlir that referenced this pull request Jun 12, 2025
* llvm/llvm-project#139340
```
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

* llvm/llvm-project#141466 &
llvm/llvm-project#141019
  * Add `BufferizationState &state` to `bufferize` and `getBuffer` 

* llvm/llvm-project#143159 &
llvm/llvm-project#142683 &
llvm/llvm-project#143779
  * Updates to `transform.apply_registered_pass` and its Python-bindings

* llvm/llvm-project#143217
* `tilingResult->mergeResult.replacements` ->
`tilingResult->replacements`

* llvm/llvm-project#140559 &
llvm/llvm-project#143871
* Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s
& fix which enables conversion again.
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants