Skip to content

[mlir,python] Expose replaceAllUsesExcept to Python bindings #115850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 20, 2024

Conversation

Wheest
Copy link
Contributor

@Wheest Wheest commented Nov 12, 2024

Problem originally described in the forums here.

Using the MLIR Python bindings, the method replaceAllUsesWith for Value is exposed, e.g.,

orig_value.replace_all_uses_with(
    new_value               
)

However, in my use-case I am separating a block into multiple blocks, so thus want to exclude certain Operations from having their Values replaced (since I want them to diverge).

Within Value, we have replaceAllUsesExcept, where we can pass the Operations which should be skipped.

This is not currently exposed in the Python bindings: this PR fixes this. Adds replace_all_uses_except, which works with individual Operations, and lists of Operations.

@llvmbot llvmbot added the mlir label Nov 12, 2024
@Wheest Wheest changed the title Expose replaceAllUsesExcept to Python bindings [mlir] Expose replaceAllUsesExcept to Python bindings Nov 12, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 12, 2024

@llvm/pr-subscribers-mlir

Author: Perry Gibson (Wheest)

Changes

Problem originally described in the forums here.

Using the MLIR Python bindings, the method replaceAllUsesWith for Value is exposed, e.g.,

orig_value.replace_all_uses_with(
    new_value               
)

However, in my use-case I am separating a block into multiple blocks, so thus want to exclude certain Operations from having their Values replaced (since I want them to diverge).

Within Value, we have replaceAllUsesExcept, where we can pass the Operations which should be skipped.

This is not currently exposed in the Python bindings: this PR fixes this. Adds replace_all_uses_except, which works with individual Operations, and lists of Operations.


Full diff: https://github.com/llvm/llvm-project/pull/115850.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/IR.h (+16)
  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+32)
  • (modified) mlir/lib/CAPI/IR/IR.cpp (+26)
  • (modified) mlir/test/python/ir/value.py (+71)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b8a6f08b159817..012353993c341a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -956,6 +956,22 @@ MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
 MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
                                                       MlirValue with);
 
+/// Replace all uses of 'of' value with 'with' value, updating anything in the
+/// IR that uses 'of' to use 'with' instead, except if the user is listed in
+/// 'exceptions'. The 'exceptions' parameter is an array of MlirOperation
+/// pointers with a length of 'numExceptions'.
+MLIR_CAPI_EXPORTED void
+mlirValueReplaceAllUsesExceptWithSet(MlirValue of, MlirValue with,
+                                     MlirOperation *exceptions,
+                                     intptr_t numExceptions);
+
+/// Replace all uses of 'of' value with 'with' value, updating anything in the
+/// IR that uses 'of' to use 'with' instead, except if the user is
+/// 'exceptedUser'.
+MLIR_CAPI_EXPORTED void
+mlirValueReplaceAllUsesExceptWithSingle(MlirValue of, MlirValue with,
+                                        MlirOperation exceptedUser);
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 3562ff38201dc3..4bddcab8ccda6d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -178,6 +178,12 @@ static const char kValueReplaceAllUsesWithDocstring[] =
 the IR that uses 'self' to use the other value instead.
 )";
 
+static const char kValueReplaceAllUsesExceptDocstring[] =
+    R"("Replace all uses of this value with the 'with' value, except for those
+in 'exceptions'. 'exceptions' can be either a single operation or a list of
+operations.
+)";
+
 //------------------------------------------------------------------------------
 // Utilities.
 //------------------------------------------------------------------------------
@@ -3718,6 +3724,32 @@ void mlir::python::populateIRCore(py::module &m) {
             mlirValueReplaceAllUsesOfWith(self.get(), with.get());
           },
           kValueReplaceAllUsesWithDocstring)
+      .def(
+          "replace_all_uses_except",
+          [](PyValue &self, PyValue &with, py::object exceptions) {
+            MlirValue selfValue = self.get();
+            MlirValue withValue = with.get();
+
+            // Check if 'exceptions' is a list
+            if (py::isinstance<py::list>(exceptions)) {
+              // Convert Python list to a vector of MlirOperations
+              std::vector<MlirOperation> exceptionOps;
+              for (py::handle exception : exceptions) {
+                exceptionOps.push_back(exception.cast<PyOperation &>().get());
+              }
+              mlirValueReplaceAllUsesExceptWithSet(
+                  selfValue, withValue, exceptionOps.data(),
+                  static_cast<intptr_t>(exceptionOps.size()));
+            } else {
+              // Assume 'exceptions' is a single Operation
+              MlirOperation exceptedUser =
+                  exceptions.cast<PyOperation &>().get();
+              mlirValueReplaceAllUsesExceptWithSingle(selfValue, withValue,
+                                                      exceptedUser);
+            }
+          },
+          py::arg("with"), py::arg("exceptions"),
+          kValueReplaceAllUsesExceptDocstring)
       .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
            [](PyValue &self) { return self.maybeDownCast(); });
   PyBlockArgument::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e7e6b11c81b9d3..5fd5f0a8f36457 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/ThreadPool.h"
 
 #include <cstddef>
@@ -1009,6 +1010,31 @@ void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
   unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
 }
 
+void mlirValueReplaceAllUsesExceptWithSet(MlirValue oldValue,
+                                          MlirValue newValue,
+                                          MlirOperation *exceptions,
+                                          intptr_t numExceptions) {
+  auto oldValueCpp = unwrap(oldValue);
+  auto newValueCpp = unwrap(newValue);
+
+  llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet;
+  for (intptr_t i = 0; i < numExceptions; ++i) {
+    exceptionSet.insert(unwrap(exceptions[i]));
+  }
+
+  oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
+}
+
+void mlirValueReplaceAllUsesExceptWithSingle(MlirValue oldValue,
+                                             MlirValue newValue,
+                                             MlirOperation exceptedUser) {
+  auto oldValueCpp = unwrap(oldValue);
+  auto newValueCpp = unwrap(newValue);
+  auto exceptedUserCpp = unwrap(exceptedUser);
+
+  oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptedUserCpp);
+}
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 50b0e8403a7f21..0991d71151c894 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -148,6 +148,77 @@ def testValueReplaceAllUsesWith():
         print(f"Use operand_number: {use.operand_number}")
 
 
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept
+@run
+def testValueReplaceAllUsesWithExcept():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            op1 = Operation.create("custom.op1", operands=[value])
+            op2 = Operation.create("custom.op2", operands=[value])
+            value2 = Operation.create("custom.op3", results=[i32]).results[0]
+            value.replace_all_uses_except(value2, [op1])
+
+    assert len(list(value.uses)) == 1
+
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    for use in value2.uses:
+        assert use.owner in [op2]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+    # CHECK: Use owner: "custom.op1"
+    # CHECK: Use operand_number: 0
+    for use in value.uses:
+        assert use.owner in [op1]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions
+@run
+def testValueReplaceAllUsesWithMultipleExceptions():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            op1 = Operation.create("custom.op1", operands=[value])
+            op2 = Operation.create("custom.op2", operands=[value])
+            op3 = Operation.create("custom.op3", operands=[value])
+            value2 = Operation.create("custom.op4", results=[i32]).results[0]
+
+            # Replace all uses of `value` with `value2`, except for `op1` and `op2`.
+            value.replace_all_uses_except(value2, [op1, op2])
+
+    # After replacement, only `op3` should use `value2`, while `op1` and `op2` should still use `value`.
+    assert len(list(value.uses)) == 2
+    assert len(list(value2.uses)) == 1
+
+    # CHECK: Use owner: "custom.op3"
+    # CHECK: Use operand_number: 0
+    for use in value2.uses:
+        assert use.owner in [op3]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+    # CHECK: Use owner: "custom.op1"
+    # CHECK: Use operand_number: 0
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    for use in value.uses:
+        assert use.owner in [op1, op2]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+
 # CHECK-LABEL: TEST: testValuePrintAsOperand
 @run
 def testValuePrintAsOperand():

@Wheest Wheest changed the title [mlir] Expose replaceAllUsesExcept to Python bindings [mlir,python] Expose replaceAllUsesExcept to Python bindings Nov 12, 2024
@Wheest Wheest force-pushed the mlir-python-replaceAllUsesExcept branch from 533600d to 7f37794 Compare November 12, 2024 12:07
Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the feature. Let me know when you're ready for me to merge.

@Wheest
Copy link
Contributor Author

Wheest commented Nov 12, 2024

Cheers both! I'm happy to get it merged now if there are no further comments.

I've implemented @ftynse's suggestions.

For @makslevental, I added the PybindAdaptor, stuff, worth it I reckon.

Only things to point out are:

  1. If you note d9c3cd7, I had to revert MlirOperation exception back to PyOperation &exception, as this was not compatible with the ADL:

Calling value.replace_all_uses_except(value2, [op1, op2]) raised the error

# | TypeError: Expected an MLIR object (got [<mlir._mlir_libs._mlir.ir.Operation object at 0x71b9f1341270>, <mlir._mlir_libs._mlir.ir.Operation object at 0x71b9f1342730>]).
  1. I fixed the first test testValueReplaceAllUsesWithExcept
- value.replace_all_uses_except(value2, [op1])
+ value.replace_all_uses_except(value2, op1)

As you can see, before I was testing passing a list of an Operation, which doesn't cover the behaviour I wanted to check.

Btw, thanks for running https://github.com/makslevental/mlir-wheels/, it has been a boon to my hacking!

@makslevental
Copy link
Contributor

makslevental commented Nov 12, 2024

cool - let's let @ftynse approve so everyone is happy 😄

@Wheest Wheest requested a review from ftynse November 18, 2024 09:07
@makslevental makslevental merged commit 21df325 into llvm:main Nov 20, 2024
8 checks passed
@ftynse
Copy link
Member

ftynse commented Nov 26, 2024

cool - let's let @ftynse approve so everyone is happy

Don't worry too much about me, if there are things that make me want a second look, I put a blocker.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants